diff --git a/.gitignore b/.gitignore index 09734fe4974935956fd599f7f86cd5c4d195d5e2..9ae0d9c96f188bc6357832f22b4125694302b104 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,8 @@ cmake_build/ .idea/** /build/ /tensorflow/core/util/version_info.cc +/tensorflow/python/framework/fast_tensor_util.cpp +Pods +Podfile.lock +*.pbxproj +*.xcworkspacedata diff --git a/CODEOWNERS b/CODEOWNERS index 6e4b4f5f3f751ca9ab39a5772458349b00f06d57..57a4df40e651f45dc03493af631d73332e46c182 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -11,6 +11,7 @@ # NEED OWNER: tensorflow/contrib/avro/* #tensorflow/contrib/batching/* @alextp @chrisolston #tensorflow/contrib/bayesflow/* @ebrevdo @rsepassi @jvdillon +#tensorflow/contrib/boosted_trees/* @sshrdp @yk5 @nataliaponomareva #tensorflow/contrib/cmake/* @mrry @benoitsteiner #tensorflow/contrib/copy_graph/* @tucker @poxvoculi #tensorflow/contrib/crf/* @kentonl diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 10fd595fec7f240c3fdc871e1f32cc83f2ffd46d..cfc45049f7088e95059d2e07d5c8ce98f32def93 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -42,7 +42,7 @@ The Code of Conduct also applies within project spaces and in public spaces when Conflicts in an open source project can take many forms, from someone having a bad day and using harsh and hurtful language in the issue queue, to more serious instances such as sexist/racist statements or threats of violence, and everything in between. -If the behaviour is threatening or harassing, or for other reasons requires immediate escalation, please see below. +If the behavior is threatening or harassing, or for other reasons requires immediate escalation, please see below. However, for the vast majority of issues, we aim to empower individuals to first resolve conflicts themselves, asking for help when needed, and only after that fails to escalate further. This approach gives people more control over the outcome of their dispute. diff --git a/ISSUE_TEMPLATE.md b/ISSUE_TEMPLATE.md index 2bf2c754cf64ec3bac22a22fbafcebbd4dc54bf4..1a401997c649518766acb2ebb0dea1c128bd0ba4 100644 --- a/ISSUE_TEMPLATE.md +++ b/ISSUE_TEMPLATE.md @@ -19,6 +19,7 @@ If you open a GitHub issue, here is our policy: - **TensorFlow version (use command below)**: - **Python version**: - **Bazel version (if compiling from source)**: +- **GCC/Compiler version (if compiling from source)**: - **CUDA/cuDNN version**: - **GPU model and memory**: - **Exact command to reproduce**: diff --git a/README.md b/README.md index febd76f73f8160dd3b1fcb7a4ceeadd8d273b1d4..aff3427bddb307aea6d6c2466eac14c9edffcc32 100644 --- a/README.md +++ b/README.md @@ -38,19 +38,20 @@ People who are a little more adventurous can also try our nightly binaries: **Nightly pip packages** * We are pleased to announce that TensorFlow now offers nightly pip packages -under the [tf-nightly](https://pypi.python.org/pypi/tf-nightly) project on pypi. -Simply run `pip install tf-nightly` in a clean environment to install the nightly -tensorflow build. We currently only support CPU packages on Linux, Mac, and Windows. -GPU packages on all platforms will arrive soon! +under the [tf-nightly](https://pypi.python.org/pypi/tf-nightly) and +[tf-nightly-gpu](https://pypi.python.org/pypi/tf-nightly-gpu) project on pypi. +Simply run `pip install tf-nightly` or `pip install tf-nightly-gpu` in a clean +environment to install the nightly TensorFlow build. We support CPU and GPU +packages on Linux, Mac, and Windows. **Individual whl files** * Linux CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=cpu-slave/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=cpu-slave/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=cpu-slave/)) * Linux GPU: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/42/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp27-none-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=gpu-linux/)) / [Python 3.4](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp34-cp34m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=gpu-linux/)) / [Python 3.5](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly_gpu-1.head-cp35-cp35m-linux_x86_64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-linux/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3.5,label=gpu-linux/)) * Mac CPU-only: [Python 2](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py2-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON2,label=mac-slave/)) / [Python 3](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/lastSuccessfulBuild/artifact/pip_test/whl/tf_nightly-1.head-py3-none-any.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-mac/TF_BUILD_IS_OPT=OPT,TF_BUILD_IS_PIP=PIP,TF_BUILD_PYTHON_VERSION=PYTHON3,label=mac-slave/)) -* Windows CPU-only: [Python 3.5 64-bit](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/)) -* Windows GPU: [Python 3.5 64-bit](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/)) -* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](http://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) +* Windows CPU-only: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly-1.head-cp36-cp36m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows,PY=36/)) +* Windows GPU: [Python 3.5 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp35-cp35m-win_amd64.whl) ([build history](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=35/)) / [Python 3.6 64-bit](https://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/lastSuccessfulBuild/artifact/cmake_build/tf_python/dist/tf_nightly_gpu-1.head-cp36-cp36m-win_amd64.whl) ([build history](http://ci.tensorflow.org/view/tf-nightly/job/tf-nightly-windows/M=windows-gpu,PY=36/)) +* Android: [demo APK](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/tensorflow_demo.apk), [native libs](https://ci.tensorflow.org/view/Nightly/job/nightly-android/lastSuccessfulBuild/artifact/out/native/) ([build history](https://ci.tensorflow.org/view/Nightly/job/nightly-android/)) #### *Try your first TensorFlow program* @@ -72,11 +73,11 @@ $ python ## For more information -* [TensorFlow website](https://www.tensorflow.org) +* [TensorFlow Website](https://www.tensorflow.org) * [TensorFlow White Papers](https://www.tensorflow.org/about/bib) * [TensorFlow Model Zoo](https://github.com/tensorflow/models) * [TensorFlow MOOC on Udacity](https://www.udacity.com/course/deep-learning--ud730) -* [TensorFlow course at Stanford](https://web.stanford.edu/class/cs20si) +* [TensorFlow Course at Stanford](https://web.stanford.edu/class/cs20si) Learn more about the TensorFlow community at the [community page of tensorflow.org](https://www.tensorflow.org/community) for a few ways to participate. diff --git a/RELEASE.md b/RELEASE.md index 3d497dbaa965d2cf239cab8360109bf5804b6f6e..d8db1f72004b5d944e3035a0f33dfc34a674b7ee 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,6 +1,59 @@ # Release 1.4.0 ## Major Features And Improvements +* `tf.keras` is now part of the core TensorFlow API. +* [`tf.data`](http://tensorflow.org/programmers_guide/datasets) is now part of + the core TensorFlow API. + * The API is now subject to backwards compatibility guarantees. + * For a guide to migrating from the `tf.contrib.data` API, see the + [README](https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/contrib/data/README.md). + * Major new features include `Dataset.from_generator()` (for building an input + pipeline from a Python generator), and the `Dataset.apply()` method for + applying custom transformation functions. + * Several custom transformation functions have been added, including + `tf.contrib.data.batch_and_drop_remainder()` and + `tf.contrib.data.sloppy_interleave()`. +* Add `train_and_evaluate` for simple distributed `Estimator` training. +* Add `tf.spectral.dct` for computing the DCT-II. +* Add Mel-Frequency Cepstral Coefficient support to `tf.contrib.signal` + (with GPU and gradient support). +* Add a self-check on `import tensorflow` for Windows DLL issues. +* Add NCHW support to `tf.depth_to_space` on GPU. +* TensorFlow Debugger (tfdbg): + * Add `eval` command to allow evaluation of arbitrary Python/numpy expressions + in tfdbg command-line interface. See + [Debugging TensorFlow Programs](https://www.tensorflow.org/programmers_guide/debugger) + for more details. + * Usability improvement: The frequently used tensor filter `has_inf_or_nan` is + now added to `Session` wrappers and hooks by default. So there is no need + for clients to call `.add_tensor_filter(tf_debug.has_inf_or_nan)` anymore. +* SinhArcsinh (scalar) distribution added to `contrib.distributions`. +* Make `GANEstimator` opensource. +* `Estimator.export_savedmodel()` now includes all valid serving signatures + that can be constructed from the Serving Input Receiver and all available + ExportOutputs. For instance, a classifier may provide regression- and + prediction-flavored outputs, in addition to the classification-flavored one. + Building signatures from these allows TF Serving to honor requests using the + different APIs (Classify, Regress, and Predict). Furthermore, + `serving_input_receiver_fn()` may now specify alternative subsets of nodes + that may act as inputs. This allows, for instance, producing a prediction + signature for a classifier that accepts raw `Tensors` instead of a serialized + `tf.Example`. +* Add `tf.contrib.bayesflow.hmc`. +* Add `tf.contrib.distributions.MixtureSameFamily`. +* Make `Dataset.shuffle()` always reshuffles after each iteration by default. +* Add `tf.contrib.bayesflow.metropolis_hastings`. +* Add `log_rate` parameter to `tf.contrib.distributions.Poisson`. +* Extend `tf.contrib.distributions.bijector` API to handle some non-injective + transforms. +* Java: + * Generics (e.g., `Tensor`) for improved type-safety + (courtesy @andrewcmyers). + * Support for multi-dimensional string tensors. + * Support loading of custom operations (e.g. many in `tf.contrib`) on Linux + and OS X +* All our prebuilt binaries have been built with CUDA 8 and cuDNN 6. + We anticipate releasing TensorFlow 1.5 with CUDA 9 and cuDNN 7. ## Bug Fixes and Other Changes * `tf.nn.rnn_cell.DropoutWrapper` is now more careful about dropping out LSTM @@ -12,6 +65,66 @@ * Removed `tf.contrib.training.python_input`. The same behavior, in a more flexible and reproducible package, is available via the new `tf.contrib.data.Dataset.from_generator` method! +* Fix `tf.contrib.distributions.Affine` incorrectly computing log-det-jacobian. +* Fix `tf.random_gamma` incorrectly handling non-batch, scalar draws. +* Resolved a race condition in TensorForest TreePredictionsV4Op. +* Google Cloud Storage file system, Amazon S3 file system, and Hadoop file + system support are now default build options. +* Custom op libraries must link against libtensorflow_framework.so + (installed at `tf.sysconfig.get_lib()`). +* Change `RunConfig` default behavior to not set a random seed, making random + behavior independently random on distributed workers. We expect this to + generally improve training performance. Models that do rely on determinism + should set a random seed explicitly. + +## Breaking Changes to the API +* The signature of the `tf.contrib.data.rejection_resample()` function has been + changed. It now returns a function that can be used as an argument to + `Dataset.apply()`. +* Remove `tf.contrib.data.Iterator.from_dataset()` method. Use + `Dataset.make_initializable_iterator()` instead. +* Remove seldom used and unnecessary `tf.contrib.data.Iterator.dispose_op()`. +* Reorder some TFGAN loss functions in a non-backwards compatible way. + +## Known Issues +* In Python 3, `Dataset.from_generator()` does not support Unicode strings. + You must convert any strings to bytes objects before yielding them from + the generator. + +## Thanks to our Contributors + +This release contains contributions from many people at Google, as well as: + +4d55397500, Abdullah Alrasheed, abenmao, Adam Salvail, Aditya Dhulipala, Ag Ramesh, +Akimasa Kimura, Alan Du, Alan Yee, Alexander, Amit Kushwaha, Amy, Andrei Costinescu, +Andrei Nigmatulin, Andrew Erlichson, Andrew Myers, Andrew Stepanov, Androbin, AngryPowman, +Anish Shah, Anton Daitche, Artsiom Chapialiou, asdf2014, Aseem Raj Baranwal, Ash Hall, +Bart Kiers, Batchu Venkat Vishal, ben, Ben Barsdell, Bill Piel, Carl Thomé, Catalin Voss, +Changming Sun, Chengzhi Chen, Chi Zeng, Chris Antaki, Chris Donahue, Chris Oelmueller, +Chris Tava, Clayne Robison, Codrut, Courtial Florian, Dalmo Cirne, Dan J, Darren Garvey, +David Kristoffersson, David Norman, David RöThlisberger, DavidNorman, Dhruv, DimanNe, +Dorokhov, Duncan Mac-Vicar P, EdwardDixon, EMCP, error.d, FAIJUL, Fan Xia, +Francois Xavier, Fred Reiss, Freedom" Koan-Sin Tan, Fritz Obermeyer, Gao, Xiang, +Guenther Schmuelling, Guo Yejun (郭叶军), Hans Gaiser, HectorSVC, Hyungsuk Yoon, +James Pruegsanusak, Jay Young, Jean Wanka, Jeff Carpenter, Jeremy Rutman, Jeroen BéDorf, +Jett Jones, Jimmy Jia, jinghuangintel, jinze1994, JKurland, Joel Hestness, joetoth, +John B Nelson, John Impallomeni, John Lawson, Jonas, Jonathan Dekhtiar, joshkyh, Jun Luan, +Jun Mei, Kai Sasaki, Karl Lessard, karl@kubx.ca, Kb Sriram, Kenichi Ueno, Kevin Slagle, +Kongsea, Lakshay Garg, lhlmgr, Lin Min, liu.guangcong, Loki Der Quaeler, Louie Helm, +lucasmoura, Luke Iwanski, Lyndon White, Mahmoud Abuzaina, Marcel Puyat, Mark Aaron Shirley, +Michele Colombo, MtDersvan, Namrata-Ibm, Nathan Luehr, Naurril, Nayana Thorat, Nicolas Lopez, +Niranjan Hasabnis, Nolan Liu, Nouce, Oliver Hennigh, osdamv, Patrik Erdes, +Patryk Chrabaszcz, Pavel Christof, Penghao Cen, postBG, Qingqing Cao, Qingying Chen, qjivy, +Raphael, Rasmi, raymondxyang, Renze Yu, resec, Roffel, Ruben Vereecken, Ryohei Kuroki, +sandipmgiri, Santiago Castro, Scott Kirkland, Sean Vig, Sebastian Raschka, Sebastian Weiss, +Sergey Kolesnikov, Sergii Khomenko, Shahid, Shivam Kotwalia, Stuart Berg, Sumit Gouthaman, +superzerg, Sven Mayer, tetris, Ti Zhou, Tiago Freitas Pereira, Tian Jin, Tomoaki Oiki, +Vaibhav Sood, vfdev, Vivek Rane, Vladimir Moskva, wangqr, Weber Xie, Will Frey, +Yan Facai (颜发才), yanivbl6, Yaroslav Bulatov, Yixing Lao, Yong Tang, youkaichao, +Yuan (Terry) Tang, Yue Zhang, Yuxin Wu, Ziming Dong, ZxYuan, 黄璞 + +We are also grateful to all who filed issues or helped resolve them, asked and +answered questions, and were part of inspiring discussions. # Release 1.3.0 diff --git a/WORKSPACE b/WORKSPACE index 32d3d94ec232a5bf0eb0092b5d04df8440127408..b40913801ba8e3c8ee73f7ba69540b520ad698a6 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -2,11 +2,11 @@ workspace(name = "org_tensorflow") http_archive( name = "io_bazel_rules_closure", - sha256 = "25f5399f18d8bf9ce435f85c6bbf671ec4820bc4396b3022cc5dc4bc66303609", - strip_prefix = "rules_closure-0.4.2", + sha256 = "110fe68753413777944b473c25eed6368c4a0487cee23a7bac1b13cc49d3e257", + strip_prefix = "rules_closure-4af89ef1db659eb41f110df189b67d4cf14073e1", urls = [ - "http://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/0.4.2.tar.gz", # 2017-08-29 - "https://github.com/bazelbuild/rules_closure/archive/0.4.2.tar.gz", + "https://mirror.bazel.build/github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", + "https://github.com/bazelbuild/rules_closure/archive/4af89ef1db659eb41f110df189b67d4cf14073e1.tar.gz", # 2017-08-28 ], ) diff --git a/configure.py b/configure.py index df2c74d23d8ea306028c8c0406c5475d31fa884f..6279c4261044e3f33519ece5f2ac19af2acb505d 100644 --- a/configure.py +++ b/configure.py @@ -25,12 +25,15 @@ import re import subprocess import sys +# pylint: disable=g-import-not-at-top try: from shutil import which except ImportError: from distutils.spawn import find_executable as which +# pylint: enable=g-import-not-at-top -_TF_BAZELRC = '.tf_configure.bazelrc' +_TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), + '.tf_configure.bazelrc') _DEFAULT_CUDA_VERSION = '8.0' _DEFAULT_CUDNN_VERSION = '6' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' @@ -484,7 +487,10 @@ def set_cc_opt_flags(environ_cp): cc_opt_flags = get_from_env_or_user_or_default(environ_cp, 'CC_OPT_FLAGS', question, default_cc_opt_flags) for opt in cc_opt_flags.split(): - write_to_bazelrc('build:opt --cxxopt=%s --copt=%s' % (opt, opt)) + host_opt = '-march=native' # It should be safe on the same build host. + write_to_bazelrc( + 'build:opt --cxxopt=%s --copt=%s' % (opt, opt) + + ' --host_cxxopt=%s --host_copt=%s' % (host_opt, host_opt)) def set_tf_cuda_clang(environ_cp): @@ -634,7 +640,7 @@ def set_tf_cuda_version(environ_cp): write_action_env_to_bazelrc('TF_CUDA_VERSION', tf_cuda_version) -def set_tf_cunn_version(environ_cp): +def set_tf_cudnn_version(environ_cp): """Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION.""" ask_cudnn_version = ( 'Please specify the cuDNN version you want to use. ' @@ -962,6 +968,19 @@ def set_monolithic(): write_to_bazelrc('build --define framework_shared_object=true') +def create_android_bazelrc_configs(): + # Flags for --config=android + write_to_bazelrc('build:android --crosstool_top=//external:android/crosstool') + write_to_bazelrc( + 'build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain') + # Flags for --config=android_arm + write_to_bazelrc('build:android_arm --config=android') + write_to_bazelrc('build:android_arm --cpu=armeabi-v7a') + # Flags for --config=android_arm64 + write_to_bazelrc('build:android_arm64 --config=android') + write_to_bazelrc('build:android_arm64 --cpu=arm64-v8a') + + def main(): # Make a copy of os.environ to be clear when functions and getting and setting # environment variables. @@ -975,10 +994,12 @@ def main(): run_gen_git_source(environ_cp) if is_windows(): + environ_cp['TF_NEED_S3'] = '0' environ_cp['TF_NEED_GCP'] = '0' environ_cp['TF_NEED_HDFS'] = '0' environ_cp['TF_NEED_JEMALLOC'] = '0' environ_cp['TF_NEED_OPENCL'] = '0' + environ_cp['TF_NEED_S3'] = '0' environ_cp['TF_CUDA_CLANG'] = '0' if is_macos(): @@ -987,9 +1008,11 @@ def main(): set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc', 'with_jemalloc', True) set_build_var(environ_cp, 'TF_NEED_GCP', 'Google Cloud Platform', - 'with_gcp_support', False, 'gcp') + 'with_gcp_support', True, 'gcp') set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System', - 'with_hdfs_support', False, 'hdfs') + 'with_hdfs_support', True, 'hdfs') + set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System', + 'with_s3_support', True, 's3') set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', False, 'xla') set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support', @@ -1007,7 +1030,7 @@ def main(): if (environ_cp.get('TF_NEED_CUDA') == '1' and 'TF_CUDA_CONFIG_REPO' not in environ_cp): set_tf_cuda_version(environ_cp) - set_tf_cunn_version(environ_cp) + set_tf_cudnn_version(environ_cp) set_tf_cuda_compute_capabilities(environ_cp) set_tf_cuda_clang(environ_cp) @@ -1029,7 +1052,7 @@ def main(): set_cc_opt_flags(environ_cp) set_mkl() set_monolithic() - + create_android_bazelrc_configs() if __name__ == '__main__': main() diff --git a/tensorflow/BUILD b/tensorflow/BUILD index c389183191ca33fc92a01ece2ecf25e66ed6f776..9874f95ea3268dfce0158d3ddcdefea77136cad8 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -54,6 +54,15 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "raspberry_pi_armeabi", + values = { + "crosstool_top": "@local_config_arm_compiler//:toolchain", + "cpu": "armeabi", + }, + visibility = ["//visibility:public"], +) + config_setting( name = "android_arm", values = { @@ -120,6 +129,15 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "ios_x86_64", + values = { + "crosstool_top": "//tools/osx/crosstool:crosstool", + "cpu": "ios_x86_64", + }, + visibility = ["//visibility:public"], +) + config_setting( name = "linux_x86_64", values = {"cpu": "k8"}, @@ -185,6 +203,12 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "with_s3_support", + values = {"define": "with_s3_support=true"}, + visibility = ["//visibility:public"], +) + config_setting( name = "with_xla_support", values = {"define": "with_xla_support=true"}, @@ -273,11 +297,19 @@ config_setting( visibility = ["//visibility:public"], ) +# Make a dummy rule that we can chaqnge "default" in select statements to. +# to disable dependencies in copybara. +config_setting( + name = "dummy_disabled_internal", + values = {"define": "with_dummy_disabled_internal=true"}, + visibility = ["//visibility:public"], +) + package_group( name = "internal", packages = [ - "//learning/protonn/llgtm/...", "//tensorflow/...", + "//tensorflow_fold/llgtm/...", ], ) @@ -316,6 +348,7 @@ filegroup( "//tensorflow/compiler/jit/kernels:all_files", "//tensorflow/compiler/jit/legacy_flags:all_files", "//tensorflow/compiler/jit/ops:all_files", + "//tensorflow/compiler/plugin:all_files", "//tensorflow/compiler/tests:all_files", "//tensorflow/compiler/tf2xla:all_files", "//tensorflow/compiler/tf2xla/cc:all_files", @@ -333,6 +366,7 @@ filegroup( "//tensorflow/compiler/xla/service/llvm_ir:all_files", "//tensorflow/compiler/xla/tests:all_files", "//tensorflow/compiler/xla/tools:all_files", + "//tensorflow/compiler/xla/tools/parser:all_files", "//tensorflow/contrib:all_files", "//tensorflow/contrib/all_reduce:all_files", "//tensorflow/contrib/android:all_files", @@ -354,6 +388,7 @@ filegroup( "//tensorflow/contrib/crf:all_files", "//tensorflow/contrib/cudnn_rnn:all_files", "//tensorflow/contrib/data:all_files", + "//tensorflow/contrib/data/kernels:all_files", "//tensorflow/contrib/data/python/kernel_tests:all_files", "//tensorflow/contrib/data/python/ops:all_files", "//tensorflow/contrib/decision_trees/proto:all_files", @@ -378,6 +413,11 @@ filegroup( "//tensorflow/contrib/integrate:all_files", "//tensorflow/contrib/keras:all_files", "//tensorflow/contrib/kernel_methods:all_files", + "//tensorflow/contrib/kfac:all_files", + "//tensorflow/contrib/kfac/examples:all_files", + "//tensorflow/contrib/kfac/examples/tests:all_files", + "//tensorflow/contrib/kfac/python/kernel_tests:all_files", + "//tensorflow/contrib/kfac/python/ops:all_files", "//tensorflow/contrib/labeled_tensor:all_files", "//tensorflow/contrib/layers:all_files", "//tensorflow/contrib/layers/kernels:all_files", @@ -387,20 +427,22 @@ filegroup( "//tensorflow/contrib/linear_optimizer:all_files", "//tensorflow/contrib/lookup:all_files", "//tensorflow/contrib/losses:all_files", + "//tensorflow/contrib/makefile:all_files", "//tensorflow/contrib/meta_graph_transform:all_files", "//tensorflow/contrib/metrics:all_files", + "//tensorflow/contrib/model_pruning:all_files", "//tensorflow/contrib/mpi_collectives:all_files", "//tensorflow/contrib/ndlstm:all_files", "//tensorflow/contrib/nearest_neighbor:all_files", "//tensorflow/contrib/nn:all_files", "//tensorflow/contrib/opt:all_files", "//tensorflow/contrib/predictor:all_files", + "//tensorflow/contrib/quantize:all_files", "//tensorflow/contrib/receptive_field:all_files", "//tensorflow/contrib/reduce_slice_ops:all_files", "//tensorflow/contrib/remote_fused_graph/pylib:all_files", "//tensorflow/contrib/resampler:all_files", "//tensorflow/contrib/rnn:all_files", - "//tensorflow/contrib/s3:all_files", "//tensorflow/contrib/saved_model:all_files", "//tensorflow/contrib/saved_model/cc/saved_model:all_files", "//tensorflow/contrib/seq2seq:all_files", @@ -422,6 +464,7 @@ filegroup( "//tensorflow/contrib/tensor_forest/kernels/v4:all_files", "//tensorflow/contrib/tensor_forest/proto:all_files", "//tensorflow/contrib/tensorboard:all_files", + "//tensorflow/contrib/tensorboard/db:all_files", "//tensorflow/contrib/testing:all_files", "//tensorflow/contrib/text:all_files", "//tensorflow/contrib/tfprof:all_files", @@ -434,7 +477,6 @@ filegroup( "//tensorflow/contrib/training:all_files", "//tensorflow/contrib/util:all_files", "//tensorflow/contrib/verbs:all_files", - "//tensorflow/contrib/xla_tf_graph:all_files", "//tensorflow/core:all_files", "//tensorflow/core/debug:all_files", "//tensorflow/core/distributed_runtime:all_files", @@ -449,10 +491,12 @@ filegroup( "//tensorflow/core/kernels/fuzzing:all_files", "//tensorflow/core/kernels/hexagon:all_files", "//tensorflow/core/kernels/neon:all_files", + "//tensorflow/core/lib/db:all_files", "//tensorflow/core/ops/compat:all_files", "//tensorflow/core/platform/cloud:all_files", "//tensorflow/core/platform/default/build_config:all_files", "//tensorflow/core/platform/hadoop:all_files", + "//tensorflow/core/platform/s3:all_files", "//tensorflow/core/profiler:all_files", "//tensorflow/core/profiler/internal:all_files", "//tensorflow/core/profiler/internal/advisor:all_files", @@ -486,7 +530,10 @@ filegroup( "//tensorflow/python/keras:all_files", "//tensorflow/python/kernel_tests:all_files", "//tensorflow/python/kernel_tests/distributions:all_files", + "//tensorflow/python/kernel_tests/linalg:all_files", + "//tensorflow/python/kernel_tests/random:all_files", "//tensorflow/python/ops/distributions:all_files", + "//tensorflow/python/ops/linalg:all_files", "//tensorflow/python/profiler:all_files", "//tensorflow/python/profiler/internal:all_files", "//tensorflow/python/saved_model:all_files", @@ -495,6 +542,7 @@ filegroup( "//tensorflow/tools/api/golden:all_files", "//tensorflow/tools/api/lib:all_files", "//tensorflow/tools/api/tests:all_files", + "//tensorflow/tools/benchmark:all_files", "//tensorflow/tools/build_info:all_files", "//tensorflow/tools/common:all_files", "//tensorflow/tools/compatibility:all_files", diff --git a/tensorflow/c/BUILD b/tensorflow/c/BUILD index aead7154ee68226ae8128c3b2a96c077657c5e5c..ef7eb5a4d16b29aecc34f33cb41dd7cf9450c5f2 100644 --- a/tensorflow/c/BUILD +++ b/tensorflow/c/BUILD @@ -30,7 +30,10 @@ tf_cuda_library( name = "c_api_internal", srcs = ["c_api.h"], hdrs = ["c_api_internal.h"], - visibility = ["//tensorflow/c:__subpackages__"], + visibility = [ + "//tensorflow:internal", + "//tensorflow/c:__subpackages__", + ], deps = select({ "//tensorflow:android": [ "//tensorflow/core:android_tensorflow_lib_lite", @@ -72,6 +75,7 @@ tf_cuda_library( "//tensorflow/core:framework", "//tensorflow/core:protos_all_cc", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", ], }), ) @@ -168,7 +172,6 @@ tf_cc_test( srcs = ["c_api_function_test.cc"], deps = [ ":c_api", - ":c_api_internal", ":c_test_util", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index 334f867e47800507760eaa71dce91186f646f72d..6dd1b999102d0135720b6ab3a43cbe61255acbc1 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -81,11 +81,13 @@ using tensorflow::TensorBuffer; using tensorflow::TensorId; using tensorflow::TensorShape; using tensorflow::TensorShapeProto; +using tensorflow::VersionDef; using tensorflow::error::Code; using tensorflow::errors::FailedPrecondition; using tensorflow::errors::InvalidArgument; using tensorflow::gtl::ArraySlice; using tensorflow::mutex_lock; +using tensorflow::string; using tensorflow::strings::StrCat; extern "C" { @@ -366,7 +368,7 @@ namespace { // Reset helper for converting character arrays to string vectors. void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers, int ncontainers, TF_Status* status) { - std::vector container_names(ncontainers); + std::vector container_names(ncontainers); for (int i = 0; i < ncontainers; ++i) { container_names[i] = containers[i]; } @@ -482,7 +484,7 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst) { const char* limit = input + src_size; *dst = Tensor(static_cast(src->dtype), src->shape); - auto dstarray = dst->flat(); + auto dstarray = dst->flat(); for (tensorflow::int64 i = 0; i < num_elements; ++i) { tensorflow::uint64 offset = reinterpret_cast(input)[i]; @@ -556,9 +558,9 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, // Compute bytes needed for encoding. size_t size = 0; - const auto& srcarray = src.flat(); + const auto& srcarray = src.flat(); for (int i = 0; i < srcarray.size(); ++i) { - const tensorflow::string& s = srcarray(i); + const string& s = srcarray(i); // uint64 starting_offset, TF_StringEncode-d string. size += sizeof(tensorflow::uint64) + TF_StringEncodedSize(s.size()); } @@ -572,7 +574,7 @@ TF_Tensor* TF_TensorFromTensor(const tensorflow::Tensor& src, for (int i = 0; i < srcarray.size(); ++i) { *offsets = (dst - data_start); offsets++; - const tensorflow::string& s = srcarray(i); + const string& s = srcarray(i); size_t consumed = TF_StringEncode(s.data(), s.size(), dst, dst_len, status); if (!status->status.ok()) { status->status = InvalidArgument( @@ -637,10 +639,9 @@ static void TF_Run_Setup(int noutputs, TF_Tensor** c_outputs, } } -static bool TF_Run_Inputs( - TF_Tensor* const* c_inputs, - std::vector>* input_pairs, - TF_Status* status) { +static bool TF_Run_Inputs(TF_Tensor* const* c_inputs, + std::vector>* input_pairs, + TF_Status* status) { const int ninputs = input_pairs->size(); for (int i = 0; i < ninputs; ++i) { status->status = TF_TensorToTensor(c_inputs[i], &(*input_pairs)[i].second); @@ -652,13 +653,12 @@ static bool TF_Run_Inputs( static void TF_Run_Helper( Session* session, const char* handle, const TF_Buffer* run_options, // Input tensors - const std::vector>& input_pairs, + const std::vector>& input_pairs, // Output tensors - const std::vector& output_tensor_names, - TF_Tensor** c_outputs, + const std::vector& output_tensor_names, TF_Tensor** c_outputs, // Target nodes - const std::vector& target_oper_names, - TF_Buffer* run_metadata, TF_Status* status) { + const std::vector& target_oper_names, TF_Buffer* run_metadata, + TF_Status* status) { const int noutputs = output_tensor_names.size(); std::vector outputs(noutputs); Status result; @@ -718,16 +718,16 @@ void TF_Run(TF_DeprecatedSession* s, const TF_Buffer* run_options, const char** c_target_oper_names, int ntargets, TF_Buffer* run_metadata, TF_Status* status) { TF_Run_Setup(noutputs, c_outputs, status); - std::vector> input_pairs(ninputs); + std::vector> input_pairs(ninputs); if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; for (int i = 0; i < ninputs; ++i) { input_pairs[i].first = c_input_names[i]; } - std::vector output_names(noutputs); + std::vector output_names(noutputs); for (int i = 0; i < noutputs; ++i) { output_names[i] = c_output_names[i]; } - std::vector target_oper_names(ntargets); + std::vector target_oper_names(ntargets); for (int i = 0; i < ntargets; ++i) { target_oper_names[i] = c_target_oper_names[i]; } @@ -745,9 +745,9 @@ void TF_PRunSetup(TF_DeprecatedSession* s, const char** handle, TF_Status* status) { *handle = nullptr; - std::vector input_names(ninputs); - std::vector output_names(noutputs); - std::vector target_oper_names(ntargets); + std::vector input_names(ninputs); + std::vector output_names(noutputs); + std::vector target_oper_names(ntargets); for (int i = 0; i < ninputs; ++i) { input_names[i] = c_input_names[i]; } @@ -757,7 +757,7 @@ void TF_PRunSetup(TF_DeprecatedSession* s, for (int i = 0; i < ntargets; ++i) { target_oper_names[i] = c_target_oper_names[i]; } - tensorflow::string new_handle; + string new_handle; status->status = s->session->PRunSetup(input_names, output_names, target_oper_names, &new_handle); if (status->status.ok()) { @@ -776,17 +776,17 @@ void TF_PRun(TF_DeprecatedSession* s, const char* handle, const char** c_target_oper_names, int ntargets, TF_Status* status) { TF_Run_Setup(noutputs, c_outputs, status); - std::vector> input_pairs(ninputs); + std::vector> input_pairs(ninputs); if (!TF_Run_Inputs(c_inputs, &input_pairs, status)) return; for (int i = 0; i < ninputs; ++i) { input_pairs[i].first = c_input_names[i]; } - std::vector output_names(noutputs); + std::vector output_names(noutputs); for (int i = 0; i < noutputs; ++i) { output_names[i] = c_output_names[i]; } - std::vector target_oper_names(ntargets); + std::vector target_oper_names(ntargets); for (int i = 0; i < ntargets; ++i) { target_oper_names[i] = c_target_oper_names[i]; } @@ -881,7 +881,7 @@ TF_Operation* ToOperation(Node* node) { return static_cast(static_cast(node)); } -tensorflow::string OutputName(const TF_Output& output) { +string OutputName(const TF_Output& output) { return StrCat(output.oper->node.name(), ":", output.index); } @@ -1254,7 +1254,7 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, return; } desc->colocation_constraints.clear(); - for (const tensorflow::string& location : attr_value.list().s()) { + for (const string& location : attr_value.list().s()) { desc->colocation_constraints.insert(location); } } else { @@ -1276,8 +1276,8 @@ static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, if (!desc->colocation_constraints.empty()) { desc->node_builder.Attr( tensorflow::kColocationAttrName, - std::vector(desc->colocation_constraints.begin(), - desc->colocation_constraints.end())); + std::vector(desc->colocation_constraints.begin(), + desc->colocation_constraints.end())); } status->status = desc->node_builder.Finalize(&desc->graph->graph, &ret); @@ -1500,7 +1500,7 @@ TF_AttrMetadata TF_OperationGetAttrMetadata(TF_Operation* oper, for (int i = 0; i < oper->node.op_def().attr_size(); ++i) { const auto& a = oper->node.op_def().attr(i); if (a.name().compare(attr_name) != 0) continue; - const tensorflow::string& typestr = a.type(); + const string& typestr = a.type(); if (typestr == "list(string)") { metadata.type = TF_ATTR_STRING; } else if (typestr == "list(int)") { @@ -1580,7 +1580,7 @@ void TF_OperationGetAttrStringList(TF_Operation* oper, const char* attr_name, const auto len = std::min(max_values, attr->list().s_size()); char* p = static_cast(storage); for (int i = 0; i < len; ++i) { - const tensorflow::string& s = attr->list().s(i); + const string& s = attr->list().s(i); values[i] = p; lengths[i] = s.size(); if ((p + s.size()) > (static_cast(storage) + storage_size)) { @@ -1799,6 +1799,27 @@ void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def, status->status = MessageToBuffer(def, output_graph_def); } +void TF_GraphGetOpDef(TF_Graph* graph, const char* op_name, + TF_Buffer* output_op_def, TF_Status* status) { + const OpDef* op_def; + { + mutex_lock l(graph->mu); + status->status = graph->graph.op_registry()->LookUpOpDef(op_name, &op_def); + if (!status->status.ok()) return; + } + status->status = MessageToBuffer(*op_def, output_op_def); +} + +void TF_GraphVersions(TF_Graph* graph, TF_Buffer* output_version_def, + TF_Status* status) { + VersionDef versions; + { + mutex_lock l(graph->mu); + versions = graph->graph.versions(); + } + status->status = MessageToBuffer(versions, output_version_def); +} + TF_ImportGraphDefOptions* TF_NewImportGraphDefOptions() { return new TF_ImportGraphDefOptions; } @@ -1813,7 +1834,11 @@ void TF_ImportGraphDefOptionsSetPrefix(TF_ImportGraphDefOptions* opts, void TF_ImportGraphDefOptionsAddInputMapping(TF_ImportGraphDefOptions* opts, const char* src_name, int src_index, TF_Output dst) { - opts->opts.input_map[TensorId(src_name, src_index)] = ToTensorId(dst); + opts->tensor_id_data.push_back(src_name); + const string& src_name_str = opts->tensor_id_data.back(); + // We don't need to store dst's name in tensor_id_data, since `dst` must + // outlive the ImportGraphDef call. + opts->opts.input_map[TensorId(src_name_str, src_index)] = ToTensorId(dst); } void TF_ImportGraphDefOptionsRemapControlDependency( @@ -1829,7 +1854,9 @@ extern void TF_ImportGraphDefOptionsAddControlDependency( void TF_ImportGraphDefOptionsAddReturnOutput(TF_ImportGraphDefOptions* opts, const char* oper_name, int index) { - opts->opts.return_tensors.push_back({oper_name, index}); + opts->tensor_id_data.push_back(oper_name); + const string& oper_name_str = opts->tensor_id_data.back(); + opts->opts.return_tensors.emplace_back(oper_name_str, index); } int TF_ImportGraphDefOptionsNumReturnOutputs( @@ -1837,57 +1864,142 @@ int TF_ImportGraphDefOptionsNumReturnOutputs( return opts->opts.return_tensors.size(); } +void TF_ImportGraphDefOptionsAddReturnOperation(TF_ImportGraphDefOptions* opts, + const char* oper_name) { + opts->opts.return_nodes.push_back(oper_name); +} + +int TF_ImportGraphDefOptionsNumReturnOperations( + const TF_ImportGraphDefOptions* opts) { + return opts->opts.return_nodes.size(); +} + +void TF_ImportGraphDefResultsReturnOutputs(TF_ImportGraphDefResults* results, + int* num_outputs, + TF_Output** outputs) { + *num_outputs = results->return_tensors.size(); + *outputs = results->return_tensors.data(); +} + +void TF_ImportGraphDefResultsReturnOperations(TF_ImportGraphDefResults* results, + int* num_opers, + TF_Operation*** opers) { + *num_opers = results->return_nodes.size(); + *opers = results->return_nodes.data(); +} + +void TF_ImportGraphDefResultsUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_unused_input_mappings, + const char*** src_names, int** src_indexes) { + *num_unused_input_mappings = results->unused_key_names.size(); + *src_names = results->unused_key_names.data(); + *src_indexes = results->unused_key_indexes.data(); +} + +void TF_DeleteImportGraphDefResults(TF_ImportGraphDefResults* results) { + delete results; +} + static void GraphImportGraphDefLocked(TF_Graph* graph, const GraphDef& def, const TF_ImportGraphDefOptions* opts, - TF_Output* return_outputs, - int num_return_outputs, TF_Status* status) + TF_ImportGraphDefResults* tf_results, + TF_Status* status) EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { - if (num_return_outputs != opts->opts.return_tensors.size()) { - status->status = InvalidArgument("Expected 'num_return_outputs' to be ", - opts->opts.return_tensors.size(), ", got ", - num_return_outputs); - return; - } - if (num_return_outputs > 0 && return_outputs == nullptr) { - status->status = InvalidArgument( - "'return_outputs' must be preallocated to length ", num_return_outputs); - return; - } const int last_node_id = graph->graph.num_node_ids(); - std::vector> return_outputs_vec; - status->status = tensorflow::ImportGraphDef( - opts->opts, def, &graph->graph, &graph->refiner, &return_outputs_vec); + tensorflow::ImportGraphDefResults results; + status->status = tensorflow::ImportGraphDef(opts->opts, def, &graph->graph, + &graph->refiner, &results); if (!status->status.ok()) return; + + // Add new nodes to name_map for (int i = last_node_id; i < graph->graph.num_node_ids(); ++i) { auto* node = graph->graph.FindNodeId(i); if (node != nullptr) graph->name_map[node->name()] = node; } - DCHECK_EQ(return_outputs_vec.size(), num_return_outputs); - for (int i = 0; i < num_return_outputs; ++i) { - return_outputs[i].oper = ToOperation(return_outputs_vec[i].first); - return_outputs[i].index = return_outputs_vec[i].second; + + // Populate return_tensors + DCHECK(tf_results->return_tensors.empty()); + tf_results->return_tensors.resize(results.return_tensors.size()); + for (int i = 0; i < results.return_tensors.size(); ++i) { + tf_results->return_tensors[i].oper = + ToOperation(results.return_tensors[i].first); + tf_results->return_tensors[i].index = results.return_tensors[i].second; + } + + // Populate return_nodes + DCHECK(tf_results->return_nodes.empty()); + tf_results->return_nodes.resize(results.return_nodes.size()); + for (int i = 0; i < results.return_nodes.size(); ++i) { + tf_results->return_nodes[i] = ToOperation(results.return_nodes[i]); + } + + // Populate unused map keys + DCHECK(tf_results->unused_key_names.empty()); + DCHECK(tf_results->unused_key_indexes.empty()); + DCHECK(tf_results->unused_key_names_data.empty()); + tf_results->unused_key_names.resize(results.unused_input_map_keys.size()); + tf_results->unused_key_indexes.resize(results.unused_input_map_keys.size()); + for (int i = 0; i < results.unused_input_map_keys.size(); ++i) { + TensorId id = results.unused_input_map_keys[i]; + tf_results->unused_key_names_data.push_back(id.first.ToString()); + tf_results->unused_key_names[i] = + tf_results->unused_key_names_data.back().c_str(); + tf_results->unused_key_indexes[i] = id.second; + } +} + +TF_ImportGraphDefResults* TF_GraphImportGraphDefWithResults( + TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, TF_Status* status) { + GraphDef def; + if (!def.ParseFromArray(graph_def->data, graph_def->length)) { + status->status = InvalidArgument("Invalid GraphDef"); + return nullptr; + } + auto results = new TF_ImportGraphDefResults(); + mutex_lock l(graph->mu); + GraphImportGraphDefLocked(graph, def, options, results, status); + if (!status->status.ok()) { + delete results; + return nullptr; } + return results; } void TF_GraphImportGraphDefWithReturnOutputs( TF_Graph* graph, const TF_Buffer* graph_def, - const TF_ImportGraphDefOptions* opts, TF_Output* return_outputs, + const TF_ImportGraphDefOptions* options, TF_Output* return_outputs, int num_return_outputs, TF_Status* status) { + if (num_return_outputs != options->opts.return_tensors.size()) { + status->status = InvalidArgument("Expected 'num_return_outputs' to be ", + options->opts.return_tensors.size(), + ", got ", num_return_outputs); + return; + } + if (num_return_outputs > 0 && return_outputs == nullptr) { + status->status = InvalidArgument( + "'return_outputs' must be preallocated to length ", num_return_outputs); + return; + } GraphDef def; if (!def.ParseFromArray(graph_def->data, graph_def->length)) { status->status = InvalidArgument("Invalid GraphDef"); return; } + TF_ImportGraphDefResults results; mutex_lock l(graph->mu); - GraphImportGraphDefLocked(graph, def, opts, return_outputs, - num_return_outputs, status); + GraphImportGraphDefLocked(graph, def, options, &results, status); + DCHECK_EQ(results.return_tensors.size(), num_return_outputs); + memcpy(return_outputs, results.return_tensors.data(), + num_return_outputs * sizeof(TF_Output)); } void TF_GraphImportGraphDef(TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Status* status) { - TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, options, nullptr, 0, - status); + TF_ImportGraphDefResults* results = + TF_GraphImportGraphDefWithResults(graph, graph_def, options, status); + TF_DeleteImportGraphDefResults(results); } // While loop functions ------------------------------------------------------- @@ -1919,7 +2031,7 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph, tensorflow::ShapeRefiner* dst_refiner, const TF_Output* src_inputs, const std::vector& dst_inputs, - const tensorflow::string& prefix, + const string& prefix, const std::vector& control_deps, const TF_Output* nodes_to_return, int nreturn_nodes, std::vector* return_nodes) { @@ -1945,11 +2057,11 @@ Status CopyGraph(Graph* src_graph, Graph* dst_graph, } // TOOD(skyewm): change to OutputTensor - std::vector> return_tensors; + tensorflow::ImportGraphDefResults results; TF_RETURN_IF_ERROR( - ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &return_tensors)); + ImportGraphDef(opts, gdef, dst_graph, dst_refiner, &results)); - for (const auto& pair : return_tensors) { + for (const auto& pair : results.return_tensors) { return_nodes->emplace_back(pair.first, pair.second); } return Status::OK(); @@ -2246,9 +2358,9 @@ TF_Session* TF_LoadSessionFromSavedModel( return nullptr; } - std::unordered_set tag_set; + std::unordered_set tag_set; for (int i = 0; i < tags_len; i++) { - tag_set.insert(tensorflow::string(tags[i])); + tag_set.insert(string(tags[i])); } tensorflow::SavedModelBundle bundle; @@ -2264,8 +2376,9 @@ TF_Session* TF_LoadSessionFromSavedModel( // TODO(jhseu): When Session is modified to take Graphs instead of // GraphDefs, return the Graph generated in LoadSavedModel(). TF_ImportGraphDefOptions* import_opts = TF_NewImportGraphDefOptions(); + TF_ImportGraphDefResults results; GraphImportGraphDefLocked(graph, bundle.meta_graph_def.graph_def(), - import_opts, nullptr, 0, status); + import_opts, &results, status); TF_DeleteImportGraphDefOptions(import_opts); if (TF_GetCode(status) != TF_OK) return nullptr; @@ -2361,20 +2474,20 @@ void TF_SessionRun(TF_Session* session, const TF_Buffer* run_options, TF_Run_Setup(noutputs, output_values, status); // Convert from TF_Output and TF_Tensor to a string and Tensor. - std::vector> input_pairs(ninputs); + std::vector> input_pairs(ninputs); if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; for (int i = 0; i < ninputs; ++i) { input_pairs[i].first = OutputName(inputs[i]); } // Convert from TF_Output to string names. - std::vector output_names(noutputs); + std::vector output_names(noutputs); for (int i = 0; i < noutputs; ++i) { output_names[i] = OutputName(outputs[i]); } // Convert from TF_Operation* to string names. - std::vector target_names(ntargets); + std::vector target_names(ntargets); for (int i = 0; i < ntargets; ++i) { target_names[i] = target_opers[i]->node.name(); } @@ -2395,22 +2508,22 @@ void TF_SessionPRunSetup(TF_Session* session, const TF_Output* inputs, return; } - std::vector input_names(ninputs); + std::vector input_names(ninputs); for (int i = 0; i < ninputs; ++i) { input_names[i] = OutputName(inputs[i]); } - std::vector output_names(noutputs); + std::vector output_names(noutputs); for (int i = 0; i < noutputs; ++i) { output_names[i] = OutputName(outputs[i]); } - std::vector target_names(ntargets); + std::vector target_names(ntargets); for (int i = 0; i < ntargets; ++i) { target_names[i] = target_opers[i]->node.name(); } - tensorflow::string new_handle; + string new_handle; status->status = session->session->PRunSetup(input_names, output_names, target_names, &new_handle); if (status->status.ok()) { @@ -2441,20 +2554,20 @@ void TF_SessionPRun(TF_Session* session, const char* handle, TF_Run_Setup(noutputs, output_values, status); // Convert from TF_Output and TF_Tensor to a string and Tensor. - std::vector> input_pairs(ninputs); + std::vector> input_pairs(ninputs); if (!TF_Run_Inputs(input_values, &input_pairs, status)) return; for (int i = 0; i < ninputs; ++i) { input_pairs[i].first = OutputName(inputs[i]); } // Convert from TF_Output to string names. - std::vector output_names(noutputs); + std::vector output_names(noutputs); for (int i = 0; i < noutputs; ++i) { output_names[i] = OutputName(outputs[i]); } // Convert from TF_Operation* to string names. - std::vector target_names(ntargets); + std::vector target_names(ntargets); for (int i = 0; i < ntargets; ++i) { target_names[i] = target_opers[i]->node.name(); } diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h index a17c877804aa8548606261cb8c3f84dc6db45a58..bb569d67fcbcec29e9494236abd79b3e40db91cd 100644 --- a/tensorflow/c/c_api.h +++ b/tensorflow/c/c_api.h @@ -118,6 +118,8 @@ typedef enum TF_DataType { TF_HALF = 19, TF_RESOURCE = 20, TF_VARIANT = 21, + TF_UINT32 = 22, + TF_UINT64 = 23, } TF_DataType; // TF_DataTypeSize returns the sizeof() for the underlying type corresponding @@ -862,6 +864,18 @@ TF_CAPI_EXPORT extern void TF_GraphToGraphDef(TF_Graph* graph, TF_Buffer* output_graph_def, TF_Status* status); +// Returns the serialized OpDef proto with name `op_name`, or a bad status if no +// such op exists. This can return OpDefs of functions copied into the graph. +TF_CAPI_EXPORT extern void TF_GraphGetOpDef(TF_Graph* graph, + const char* op_name, + TF_Buffer* output_op_def, + TF_Status* status); + +// Returns the serialized VersionDef proto for this graph. +TF_CAPI_EXPORT extern void TF_GraphVersions(TF_Graph* graph, + TF_Buffer* output_version_def, + TF_Status* status); + // TF_ImportGraphDefOptions holds options that can be passed to // TF_GraphImportGraphDef. typedef struct TF_ImportGraphDefOptions TF_ImportGraphDefOptions; @@ -905,7 +919,62 @@ TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOutput( TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOutputs( const TF_ImportGraphDefOptions* opts); +// Add an operation in `graph_def` to be returned via the `return_opers` output +// parameter of TF_GraphImportGraphDef(). +TF_CAPI_EXPORT extern void TF_ImportGraphDefOptionsAddReturnOperation( + TF_ImportGraphDefOptions* opts, const char* oper_name); + +// Returns the number of return operations added via +// TF_ImportGraphDefOptionsAddReturnOperation(). +TF_CAPI_EXPORT extern int TF_ImportGraphDefOptionsNumReturnOperations( + const TF_ImportGraphDefOptions* opts); + +// TF_ImportGraphDefResults holds results that are generated by +// TF_GraphImportGraphDefWithResults(). +typedef struct TF_ImportGraphDefResults TF_ImportGraphDefResults; + +// Fetches the return outputs requested via +// TF_ImportGraphDefOptionsAddReturnOutput(). The number of fetched outputs is +// returned in `num_outputs`. The array of return outputs is returned in +// `outputs`. `*outputs` is owned by and has the lifetime of `results`. +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOutputs( + TF_ImportGraphDefResults* results, int* num_outputs, TF_Output** outputs); + +// Fetches the return operations requested via +// TF_ImportGraphDefOptionsAddReturnOperation(). The number of fetched +// operations is returned in `num_opers`. The array of return operations is +// returned in `opers`. `*opers` is owned by and has the lifetime of `results`. +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsReturnOperations( + TF_ImportGraphDefResults* results, int* num_opers, TF_Operation*** opers); + +// Fetches any input mappings requested via +// TF_ImportGraphDefOptionsAddInputMapping() that weren't used as input to any +// node in the imported graph def. The number of fetched mappings is returned in +// `num_unused_input_mappings`. The array of each mapping's source node name is +// returned in `src_names`, and the array of each mapping's source index is +// returned in `src_indexes`. +// +// `*src_names`, `*src_indexes`, and the memory backing each string in +// `src_names` are owned by and have the lifetime of `results`. +TF_CAPI_EXPORT extern void TF_ImportGraphDefResultsUnusedInputMappings( + TF_ImportGraphDefResults* results, int* num_unused_input_mappings, + const char*** src_names, int** src_indexes); + +// Deletes a results object returned by TF_GraphImportGraphDefWithResults(). +TF_CAPI_EXPORT extern void TF_DeleteImportGraphDefResults( + TF_ImportGraphDefResults* results); + +// Import the graph serialized in `graph_def` into `graph`. Returns nullptr and +// a bad status on error. Otherwise, returns a populated +// TF_ImportGraphDefResults instance. The returned instance must be deleted via +// TF_DeleteImportGraphDefResults(). +TF_CAPI_EXPORT extern TF_ImportGraphDefResults* +TF_GraphImportGraphDefWithResults(TF_Graph* graph, const TF_Buffer* graph_def, + const TF_ImportGraphDefOptions* options, + TF_Status* status); + // Import the graph serialized in `graph_def` into `graph`. +// Convenience function for when only return outputs are needed. // // `num_return_outputs` must be the number of return outputs added (i.e. the // result of TF_ImportGraphDefOptionsNumReturnOutputs()). If @@ -917,7 +986,7 @@ TF_CAPI_EXPORT extern void TF_GraphImportGraphDefWithReturnOutputs( int num_return_outputs, TF_Status* status); // Import the graph serialized in `graph_def` into `graph`. -// Convenience function for when no return outputs have been added. +// Convenience function for when no results are needed. TF_CAPI_EXPORT extern void TF_GraphImportGraphDef( TF_Graph* graph, const TF_Buffer* graph_def, const TF_ImportGraphDefOptions* options, TF_Status* status); @@ -1039,12 +1108,14 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, // fn_body - the graph whose operations (or subset of whose operations) will be // converted to TF_Function. // fn_name - the name of the new TF_Function. Should match the operation -// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]* and be distinct -// from other operation names (at least those registered in graphs -// where this function will be used). -// TODO(iga): Allow null in here and have C API come up with -// a unique name with high probability (similarly to -// _create_hash_str in function.py) +// name (OpDef.name) regexp [A-Z][A-Za-z0-9_.\\-/]*. +// If `append_hash_to_fn_name` is false, `fn_name` must be distinct +// from other function and operation names (at least those +// registered in graphs where this function will be used). +// append_hash_to_fn_name - Must be 0 or 1. If set to 1, the actual name +// of the function will be `fn_name` appended with +// '_'. +// If set to 0, the function's name will be `fn_name`. // num_opers - `num_opers` contains the number of elements in the `opers` array // or a special value of -1 meaning that no array is given. // The distinction between an empty array of operations and no @@ -1114,7 +1185,8 @@ TF_CAPI_EXPORT void TF_AddGradients(TF_Graph* g, TF_Output* y, int ny, // // On failure, null. TF_CAPI_EXPORT extern TF_Function* TF_GraphToFunction( - const TF_Graph* fn_body, const char* fn_name, int num_opers, + const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, int noutputs, const TF_Output* outputs, const char* const* output_names, const TF_FunctionOptions* opts, const char* description, TF_Status* status); @@ -1129,17 +1201,19 @@ TF_CAPI_EXPORT extern void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def, TF_Status* status); -// Construct and return the function serialized in `func_def`. +// Construct and return the function whose FunctionDef representation is +// serialized in `proto`. `proto_len` must equal the number of bytes +// pointed to by `proto`. // Returns: // On success, a newly created TF_Function instance. It must be deleted by // calling TF_DeleteFunction. // // On failure, null. TF_CAPI_EXPORT extern TF_Function* TF_FunctionImportFunctionDef( - const TF_Buffer* func_def, TF_Status* status); + const void* proto, size_t proto_len, TF_Status* status); // Sets function attribute named `attr_name` to value stored in `proto`. -// If this attribute is already set to another value, it is overriden. +// If this attribute is already set to another value, it is overridden. // `proto` should point to a sequence of bytes of length `proto_len` // representing a binary serialization of an AttrValue protocol // buffer. diff --git a/tensorflow/c/c_api_function.cc b/tensorflow/c/c_api_function.cc index 61484fd8ea0e2a794798ca98bcafb1352bd669f9..dcb818b88b6fca460852beb6e948d2eb6964f663 100644 --- a/tensorflow/c/c_api_function.cc +++ b/tensorflow/c/c_api_function.cc @@ -25,6 +25,7 @@ limitations under the License. #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/strings/base64.h" #include "tensorflow/core/lib/strings/strcat.h" using tensorflow::errors::InvalidArgument; @@ -232,6 +233,7 @@ Status FillFunctionBody( // Graph to FunctionDef conversion. This code is closely modeled on the Python // code in third_party/tensorflow/python/framework/function.py. Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, + bool append_hash_to_fn_name, const std::vector& body_nodes, const std::vector& inputs, const std::vector& outputs, @@ -241,7 +243,6 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, DCHECK_EQ(output_names.size(), outputs.size()); } - fdef->mutable_signature()->set_name(fn_name); if (description != nullptr) { fdef->mutable_signature()->set_description(description); } @@ -328,7 +329,6 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, // Remap return values. for (int r = 0; r < fdef->signature().output_arg_size(); ++r) { const string& ret_name = fdef->signature().output_arg(r).name(); - // We convert this flat tensor name to the nested value // (e.g. `add:z:1`) that we stored in tensor_renaming. const string& return_value = @@ -343,6 +343,24 @@ Status GraphToFunctionDef(const Graph& fn_body, const string& fn_name, (*fdef->mutable_ret())[ret_name] = iter->second; } + if (append_hash_to_fn_name) { + const uint64 hash = FunctionDefHash(*fdef); + string encoded; + TF_RETURN_IF_ERROR(Base64Encode( + StringPiece(reinterpret_cast(&hash), sizeof(hash)), + &encoded)); + // Besides letters and digits our Base64 encoding uses '_' and '-'. + // Dash is invalid in operation names and multiple underscores in random + // places look strange. Since we never need to decode the hash back, + // replace these chars with with 'a' and 'A'. Replacing with different + // letters keeps more entropy. + std::replace(encoded.begin(), encoded.end(), '-', 'a'); + std::replace(encoded.begin(), encoded.end(), '_', 'A'); + fdef->mutable_signature()->set_name(strings::StrCat(fn_name, "_", encoded)); + } else { + fdef->mutable_signature()->set_name(fn_name); + } + return Status::OK(); } @@ -451,6 +469,7 @@ using tensorflow::Node; using tensorflow::string; TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, + unsigned char append_hash_to_fn_name, int num_opers, const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, int noutputs, const TF_Output* outputs, @@ -489,9 +508,11 @@ TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, // Do the actual function creation. TF_Function* tf_function = new TF_Function(); + DCHECK(append_hash_to_fn_name <= 1); status->status = tensorflow::GraphToFunctionDef( - fn_body->graph, fn_name, body_nodes, input_tensors, output_tensors, - output_names_vec, description, &tf_function->fdef); + fn_body->graph, fn_name, append_hash_to_fn_name != 0, body_nodes, + input_tensors, output_tensors, output_names_vec, description, + &tf_function->fdef); if (!status->status.ok()) { TF_DeleteFunction(tf_function); return nullptr; @@ -527,10 +548,10 @@ void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def, status->status = MessageToBuffer(func->fdef, output_func_def); } -TF_Function* TF_FunctionImportFunctionDef(const TF_Buffer* func_def, +TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len, TF_Status* status) { TF_Function* func = new TF_Function(); - if (!func->fdef.ParseFromArray(func_def->data, func_def->length)) { + if (!func->fdef.ParseFromArray(proto, proto_len)) { status->status = InvalidArgument( "Invalid FunctionDef given to TF_FunctionImportFunctionDef"); TF_DeleteFunction(func); diff --git a/tensorflow/c/c_api_function_test.cc b/tensorflow/c/c_api_function_test.cc index 4ccff3175164730e0c4b15ff1f8e29cbdb665dc3..d5580b658992413ae6f9cb79ef88751ee28ce465 100644 --- a/tensorflow/c/c_api_function_test.cc +++ b/tensorflow/c/c_api_function_test.cc @@ -15,7 +15,6 @@ limitations under the License. #include "tensorflow/c/c_api.h" -#include "tensorflow/c/c_api_internal.h" #include "tensorflow/c/c_test_util.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op_def.pb.h" @@ -179,7 +178,7 @@ class CApiFunctionTest : public ::testing::Test { bool expect_failure = false) { ASSERT_EQ(func_, nullptr); const char** output_names_ptr = ToArray(output_names); - func_ = TF_GraphToFunction(func_graph_, func_name_, num_opers, + func_ = TF_GraphToFunction(func_graph_, func_name_, false, num_opers, num_opers == -1 ? nullptr : opers.data(), inputs.size(), inputs.data(), outputs.size(), outputs.data(), output_names_ptr, @@ -364,12 +363,10 @@ class CApiFunctionTest : public ::testing::Test { TF_DeleteFunction(func_); // fdef -> func_ - TF_Buffer* buf = TF_NewBuffer(); - Status s = MessageToBuffer(fdef, buf); - ASSERT_EQ(Status::OK(), s) << s.error_message(); - func_ = TF_FunctionImportFunctionDef(buf, s_); + string buf; + ASSERT_TRUE(fdef.SerializeToString(&buf)); + func_ = TF_FunctionImportFunctionDef(buf.data(), buf.size(), s_); ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); - TF_DeleteBuffer(buf); } void GetAttr(const char* attr_name, AttrValue* out_attr) { @@ -1097,7 +1094,7 @@ TEST_F(CApiFunctionTest, InvalidInputTensor_HighIndex) { TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); TF_Operation* add = Add(feed1, feed2, func_graph_, s_); DefineT(-1, {}, {{feed1, 0}, {feed2, 2}}, {{add, 0}}, {}, true); - EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_)); EXPECT_EQ(string("Node 'feed2' (type: 'Placeholder', num of outputs: 1) does " "not have output 2\n\tEncountered while processing " "input 1 into function 'MyFunc'"), @@ -1134,7 +1131,7 @@ TEST_F(CApiFunctionTest, InvalidOutputTensor_HighIndex) { TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2"); TF_Operation* add = Add(feed1, feed2, func_graph_, s_); DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{add, 3}}, {}, true); - EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); + EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_)); EXPECT_EQ(string("Node 'add' (type: 'AddN', num of outputs: 1) does " "not have output 3\n\tEncountered while processing " "output 0 from function 'MyFunc'"), @@ -1200,7 +1197,8 @@ TEST_F(CApiFunctionTest, OutputOpNotInBody) { } void DefineFunction(const char* name, TF_Function** func, - const char* description = nullptr) { + const char* description = nullptr, + bool append_hash = false) { std::unique_ptr func_graph( TF_NewGraph(), TF_DeleteGraph); std::unique_ptr s(TF_NewStatus(), @@ -1211,7 +1209,7 @@ void DefineFunction(const char* name, TF_Function** func, TF_Output inputs[] = {{feed, 0}}; TF_Output outputs[] = {{neg, 0}}; - *func = TF_GraphToFunction(func_graph.get(), name, -1, + *func = TF_GraphToFunction(func_graph.get(), name, append_hash, -1, /*opers=*/nullptr, 1, inputs, 1, outputs, /*output_names=*/nullptr, /*opts=*/nullptr, description, s.get()); @@ -1405,9 +1403,7 @@ TEST_F(CApiFunctionTest, ImportFunctionDef) { TEST_F(CApiFunctionTest, ImportFunctionDef_InvalidProto) { // Invalid protobuf data (protos cannot start with 4 bytes of zeros) char proto[] = {0x0, 0x0, 0x0, 0x0}; - TF_Buffer* buf = TF_NewBufferFromString(proto, 4); - func_ = TF_FunctionImportFunctionDef(buf, s_); - TF_DeleteBuffer(buf); + func_ = TF_FunctionImportFunctionDef(proto, 4, s_); EXPECT_TRUE(func_ == nullptr); EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_)); EXPECT_EQ(string("Invalid FunctionDef given to TF_FunctionImportFunctionDef"), @@ -1453,5 +1449,42 @@ TEST_F(CApiFunctionTest, Description) { ASSERT_EQ(string("Return something"), fdef.signature().description()); } +TEST_F(CApiFunctionTest, Name) { + DefineFunction("long_func_name", &func_, "Return something", + /*append_hash=*/false); + tensorflow::FunctionDef fdef; + ASSERT_TRUE(GetFunctionDef(func_, &fdef)); + ASSERT_EQ(string("long_func_name"), fdef.signature().name()); +} + +TEST_F(CApiFunctionTest, AppendHash) { + DefineFunction("func_name_base", &func_, "Return something", + /*append_hash=*/true); + tensorflow::FunctionDef fdef; + ASSERT_TRUE(GetFunctionDef(func_, &fdef)); + ASSERT_EQ(string("func_name_base_qaJ8jA8UmGY"), fdef.signature().name()); +} + +TEST_F(CApiFunctionTest, GetOpDef) { + DefineFunction(func_name_, &func_); + TF_GraphCopyFunction(host_graph_, func_, nullptr, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Test we can retrieve function OpDef from graph + TF_Buffer* buffer = TF_NewBuffer(); + TF_GraphGetOpDef(host_graph_, func_name_, buffer, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Sanity check returned OpDef + string data(static_cast(buffer->data), buffer->length); + OpDef op_def; + op_def.ParseFromString(data); + EXPECT_EQ(op_def.name(), func_name_); + EXPECT_EQ(op_def.input_arg_size(), 1); + EXPECT_EQ(op_def.output_arg_size(), 1); + + TF_DeleteBuffer(buffer); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h index 23ec1fac6f4c623464d6bc93958504a09f3f8876..bb04e01beec931a8ea66d0855eec9625d3a6a5ab 100644 --- a/tensorflow/c/c_api_internal.h +++ b/tensorflow/c/c_api_internal.h @@ -18,7 +18,9 @@ limitations under the License. #include "tensorflow/c/c_api.h" +#include #include +#include #include #include @@ -124,6 +126,20 @@ struct TF_Session { struct TF_ImportGraphDefOptions { tensorflow::ImportGraphDefOptions opts; + + // Backing memory for TensorId fields in opts. + // TODO(skyewm): it'd be better if ImportGraphDefOptions owned this. + std::list tensor_id_data; +}; + +struct TF_ImportGraphDefResults { + std::vector return_tensors; + std::vector return_nodes; + std::vector unused_key_names; + std::vector unused_key_indexes; + + // Backing memory for unused_key_names values. + std::list unused_key_names_data; }; struct TF_DeviceList { diff --git a/tensorflow/c/c_api_test.cc b/tensorflow/c/c_api_test.cc index c4420290099ee10c89792210dad2604328296515..05881e619ba232de99e78f315cfa8ab9294e5137 100644 --- a/tensorflow/c/c_api_test.cc +++ b/tensorflow/c/c_api_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/error_codes.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" @@ -50,6 +51,11 @@ Status TF_TensorToTensor(const TF_Tensor* src, Tensor* dst); namespace { +static void ExpectHasSubstr(StringPiece s, StringPiece expected) { + EXPECT_TRUE(StringPiece(s).contains(expected)) + << "'" << s << "' does not contain '" << expected << "'"; +} + TEST(CAPI, Version) { EXPECT_STRNE("", TF_Version()); } TEST(CAPI, Status) { @@ -567,7 +573,7 @@ TEST(CAPI, ImportGraphDef) { TF_GraphToGraphDef(graph, graph_def, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); - // Import it again, with a prefix, in a fresh graph. + // Import it, with a prefix, in a fresh graph. TF_DeleteGraph(graph); graph = TF_NewGraph(); TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions(); @@ -582,8 +588,8 @@ TEST(CAPI, ImportGraphDef) { ASSERT_TRUE(feed != nullptr); ASSERT_TRUE(neg != nullptr); - // Import it again, with an input mapping and return outputs, into the same - // graph. + // Import it again, with an input mapping, return outputs, and a return + // operation, into the same graph. TF_DeleteImportGraphDefOptions(opts); opts = TF_NewImportGraphDefOptions(); TF_ImportGraphDefOptionsSetPrefix(opts, "imported2"); @@ -591,9 +597,10 @@ TEST(CAPI, ImportGraphDef) { TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts)); - TF_Output return_outputs[2]; - TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts, - return_outputs, 2, s); + TF_ImportGraphDefOptionsAddReturnOperation(opts, "scalar"); + EXPECT_EQ(1, TF_ImportGraphDefOptionsNumReturnOperations(opts)); + TF_ImportGraphDefResults* results = + TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); TF_Operation* scalar2 = TF_GraphOperationByName(graph, "imported2/scalar"); @@ -609,11 +616,26 @@ TEST(CAPI, ImportGraphDef) { EXPECT_EQ(0, neg_input.index); // Check return outputs + TF_Output* return_outputs; + int num_return_outputs; + TF_ImportGraphDefResultsReturnOutputs(results, &num_return_outputs, + &return_outputs); + ASSERT_EQ(2, num_return_outputs); EXPECT_EQ(feed2, return_outputs[0].oper); EXPECT_EQ(0, return_outputs[0].index); EXPECT_EQ(scalar, return_outputs[1].oper); // remapped EXPECT_EQ(0, return_outputs[1].index); + // Check return operation + TF_Operation** return_opers; + int num_return_opers; + TF_ImportGraphDefResultsReturnOperations(results, &num_return_opers, + &return_opers); + ASSERT_EQ(1, num_return_opers); + EXPECT_EQ(scalar2, return_opers[0]); // not remapped + + TF_DeleteImportGraphDefResults(results); + // Import again, with control dependencies, into the same graph. TF_DeleteImportGraphDefOptions(opts); opts = TF_NewImportGraphDefOptions(); @@ -683,6 +705,113 @@ TEST(CAPI, ImportGraphDef) { TF_DeleteStatus(s); } +TEST(CAPI, ImportGraphDef_WithReturnOutputs) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // Create a graph with two nodes: x and 3 + Placeholder(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr); + TF_Operation* oper = ScalarConst(3, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr); + Neg(oper, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr); + + // Export to a GraphDef. + TF_Buffer* graph_def = TF_NewBuffer(); + TF_GraphToGraphDef(graph, graph_def, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Import it in a fresh graph with return outputs. + TF_DeleteGraph(graph); + graph = TF_NewGraph(); + TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions(); + TF_ImportGraphDefOptionsAddReturnOutput(opts, "feed", 0); + TF_ImportGraphDefOptionsAddReturnOutput(opts, "scalar", 0); + EXPECT_EQ(2, TF_ImportGraphDefOptionsNumReturnOutputs(opts)); + TF_Output return_outputs[2]; + TF_GraphImportGraphDefWithReturnOutputs(graph, graph_def, opts, + return_outputs, 2, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar"); + TF_Operation* feed = TF_GraphOperationByName(graph, "feed"); + TF_Operation* neg = TF_GraphOperationByName(graph, "neg"); + ASSERT_TRUE(scalar != nullptr); + ASSERT_TRUE(feed != nullptr); + ASSERT_TRUE(neg != nullptr); + + // Check return outputs + EXPECT_EQ(feed, return_outputs[0].oper); + EXPECT_EQ(0, return_outputs[0].index); + EXPECT_EQ(scalar, return_outputs[1].oper); + EXPECT_EQ(0, return_outputs[1].index); + + TF_DeleteImportGraphDefOptions(opts); + TF_DeleteBuffer(graph_def); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + +TEST(CAPI, ImportGraphDef_UnusedInputMappings) { + TF_Status* s = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + + // Create a graph with two nodes: x and 3 + Placeholder(graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "feed") != nullptr); + TF_Operation* oper = ScalarConst(3, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "scalar") != nullptr); + Neg(oper, graph, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + ASSERT_TRUE(TF_GraphOperationByName(graph, "neg") != nullptr); + + // Export to a GraphDef. + TF_Buffer* graph_def = TF_NewBuffer(); + TF_GraphToGraphDef(graph, graph_def, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Import it in a fresh graph. + TF_DeleteGraph(graph); + graph = TF_NewGraph(); + TF_ImportGraphDefOptions* opts = TF_NewImportGraphDefOptions(); + TF_GraphImportGraphDef(graph, graph_def, opts, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + TF_Operation* scalar = TF_GraphOperationByName(graph, "scalar"); + + // Import it in a fresh graph with an unused input mapping. + TF_DeleteImportGraphDefOptions(opts); + opts = TF_NewImportGraphDefOptions(); + TF_ImportGraphDefOptionsSetPrefix(opts, "imported"); + TF_ImportGraphDefOptionsAddInputMapping(opts, "scalar", 0, {scalar, 0}); + TF_ImportGraphDefOptionsAddInputMapping(opts, "fake", 0, {scalar, 0}); + TF_ImportGraphDefResults* results = + TF_GraphImportGraphDefWithResults(graph, graph_def, opts, s); + ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s); + + // Check unused input mappings + int num_unused_input_mappings; + const char** src_names; + int* src_indexes; + TF_ImportGraphDefResultsUnusedInputMappings( + results, &num_unused_input_mappings, &src_names, &src_indexes); + ASSERT_EQ(1, num_unused_input_mappings); + EXPECT_EQ(string("fake"), string(src_names[0])); + EXPECT_EQ(0, src_indexes[0]); + + TF_DeleteImportGraphDefResults(results); + TF_DeleteImportGraphDefOptions(opts); + TF_DeleteBuffer(graph_def); + TF_DeleteGraph(graph); + TF_DeleteStatus(s); +} + TEST(CAPI, Session) { TF_Status* s = TF_NewStatus(); TF_Graph* graph = TF_NewGraph(); @@ -837,6 +966,31 @@ TEST(CAPI, ShapeInferenceError) { TF_DeleteStatus(status); } +TEST(CAPI, GetOpDef) { + TF_Status* status = TF_NewStatus(); + TF_Graph* graph = TF_NewGraph(); + TF_Buffer* buffer = TF_NewBuffer(); + + TF_GraphGetOpDef(graph, "Add", buffer, status); + ASSERT_EQ(TF_OK, TF_GetCode(status)); + const OpDef* expected_op_def; + TF_ASSERT_OK(OpRegistry::Global()->LookUpOpDef("Add", &expected_op_def)); + string expected_serialized; + expected_op_def->SerializeToString(&expected_serialized); + string actual_string(reinterpret_cast(buffer->data), + buffer->length); + EXPECT_EQ(expected_serialized, actual_string); + + TF_GraphGetOpDef(graph, "MyFakeOp", buffer, status); + EXPECT_EQ(TF_NOT_FOUND, TF_GetCode(status)); + ExpectHasSubstr(TF_Message(status), + "Op type not registered 'MyFakeOp' in binary"); + + TF_DeleteBuffer(buffer); + TF_DeleteGraph(graph); + TF_DeleteStatus(status); +} + void StringVectorToArrays(const std::vector& v, std::unique_ptr* ptrs, std::unique_ptr* lens) { diff --git a/tensorflow/c/checkpoint_reader.cc b/tensorflow/c/checkpoint_reader.cc index e7b9bca5b50e4837534c315b8fa2ca161019d100..b1f7bdaa5420a56386e6983052df20aa976aa867 100644 --- a/tensorflow/c/checkpoint_reader.cc +++ b/tensorflow/c/checkpoint_reader.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/c/checkpoint_reader.h" #include +#include #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -24,43 +25,43 @@ limitations under the License. #include "tensorflow/core/util/saved_tensor_slice_util.h" namespace tensorflow { - namespace checkpoint { class TensorSliceReader; CheckpointReader::CheckpointReader(const string& filename, TF_Status* out_status) - : reader_(nullptr), v2_reader_(nullptr), var_to_shape_map_ptr_(nullptr) { + : reader_(nullptr), + v2_reader_(nullptr), + var_to_shape_map_(nullptr), + var_to_data_type_map_(nullptr) { // Depending on whether this is a V2 ckpt, initializes "reader_" or // "v2_reader_". std::vector v2_path; if (Env::Default()->GetMatchingPaths(MetaFilename(filename), &v2_path).ok() && !v2_path.empty()) { - v2_reader_ = - new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */); + v2_reader_.reset( + new BundleReader(Env::Default(), filename /* prefix to a V2 ckpt */)); if (!v2_reader_->status().ok()) { Set_TF_Status_from_Status(out_status, v2_reader_->status()); return; } - var_to_shape_map_ptr_ = BuildV2VarToShapeMap(); + auto result = BuildV2VarMaps(); + var_to_shape_map_.swap(result.first); + var_to_data_type_map_.swap(result.second); } else { - reader_ = new TensorSliceReader(filename); + reader_.reset(new TensorSliceReader(filename)); if (!reader_->status().ok()) { Set_TF_Status_from_Status(out_status, reader_->status()); return; } - var_to_shape_map_ptr_ = - new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap()); + var_to_shape_map_.reset( + new TensorSliceReader::VarToShapeMap(reader_->GetVariableToShapeMap())); + var_to_data_type_map_.reset(new TensorSliceReader::VarToDataTypeMap( + reader_->GetVariableToDataTypeMap())); } } -CheckpointReader::~CheckpointReader() { - delete var_to_shape_map_ptr_; - delete reader_; - delete v2_reader_; -} - bool CheckpointReader::HasTensor(const string& name) const { if (reader_ != nullptr) { return reader_->HasTensor(name, nullptr, nullptr); @@ -70,8 +71,14 @@ bool CheckpointReader::HasTensor(const string& name) const { const TensorSliceReader::VarToShapeMap& CheckpointReader::GetVariableToShapeMap() const { - CHECK(var_to_shape_map_ptr_); - return *var_to_shape_map_ptr_; + CHECK(var_to_shape_map_); + return *var_to_shape_map_; +} + +const TensorSliceReader::VarToDataTypeMap& +CheckpointReader::GetVariableToDataTypeMap() const { + CHECK(var_to_data_type_map_); + return *var_to_data_type_map_; } const string CheckpointReader::DebugString() const { @@ -100,7 +107,9 @@ void CheckpointReader::GetTensor( } } -TensorSliceReader::VarToShapeMap* CheckpointReader::BuildV2VarToShapeMap() { +std::pair, + std::unique_ptr> +CheckpointReader::BuildV2VarMaps() { CHECK(v2_reader_ != nullptr); CHECK(v2_reader_->status().ok()); @@ -123,18 +132,23 @@ TensorSliceReader::VarToShapeMap* CheckpointReader::BuildV2VarToShapeMap() { } // Second pass: adds the entries, ignoring the filtered keys. - TensorSliceReader::VarToShapeMap* var_to_shape_map = - new TensorSliceReader::VarToShapeMap; + std::unique_ptr var_to_shape_map( + new TensorSliceReader::VarToShapeMap); + std::unique_ptr var_to_data_type_map( + new TensorSliceReader::VarToDataTypeMap); v2_reader_->Seek(kHeaderEntryKey); for (v2_reader_->Next(); v2_reader_->Valid(); v2_reader_->Next()) { if (filtered_keys.count(v2_reader_->key().ToString()) > 0) continue; CHECK(entry.ParseFromArray(v2_reader_->value().data(), v2_reader_->value().size())) << entry.InitializationErrorString(); - (*var_to_shape_map)[v2_reader_->key().ToString()] = - TensorShape(entry.shape()); + string key = v2_reader_->key().ToString(); + (*var_to_shape_map)[key] = TensorShape(entry.shape()); + (*var_to_data_type_map)[key] = DataType(entry.dtype()); } - return var_to_shape_map; // Owned by caller. + // The returned pointers are owned by the caller. + return std::make_pair(std::move(var_to_shape_map), + std::move(var_to_data_type_map)); } } // namespace checkpoint diff --git a/tensorflow/c/checkpoint_reader.h b/tensorflow/c/checkpoint_reader.h index 1124416380df624f97b3ce2ebaadb04b3c17d341..4de1300a7f66a8b4eb8074819432fd7dd597bb15 100644 --- a/tensorflow/c/checkpoint_reader.h +++ b/tensorflow/c/checkpoint_reader.h @@ -16,6 +16,9 @@ limitations under the License. #ifndef TENSORFLOW_C_CHECKPOINT_READER_H #define TENSORFLOW_C_CHECKPOINT_READER_H +#include +#include + #include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/lib/core/status.h" @@ -24,7 +27,6 @@ limitations under the License. #include "tensorflow/core/util/tensor_slice_reader.h" namespace tensorflow { - namespace checkpoint { class TensorSliceReader; @@ -38,15 +40,18 @@ class TensorSliceReader; class CheckpointReader { public: CheckpointReader(const string& filepattern, TF_Status* out_status); - ~CheckpointReader(); bool HasTensor(const string& name) const; const string DebugString() const; - // Returns a map from variable names to its shape. Slices of a partitioned + // Returns a map from variable names to their shapes. Slices of a partitioned // tensor are combined into a single entry. const TensorSliceReader::VarToShapeMap& GetVariableToShapeMap() const; + // Returns a map from variable names to their data types. Slices of a + // partitioned tensor are combined into a single entry. + const TensorSliceReader::VarToDataTypeMap& GetVariableToDataTypeMap() const; + // Attempts to look up the tensor named "name" and stores the found result in // "out_tensor". void GetTensor(const string& name, @@ -54,14 +59,19 @@ class CheckpointReader { TF_Status* out_status) const; private: - // Uses "v2_reader_" to build a "var name -> shape" map; owned by caller. + // Uses "v2_reader_" to build "var name -> shape" and "var name -> data type" + // maps; both owned by caller. // REQUIRES: "v2_reader_ != nullptr && v2_reader_.status().ok()". - TensorSliceReader::VarToShapeMap* BuildV2VarToShapeMap(); + std::pair, + std::unique_ptr > + BuildV2VarMaps(); + + // Invariant: exactly one of "reader_" and "v2_reader_" is non-null. + std::unique_ptr reader_; + std::unique_ptr v2_reader_; - // Invariant: exactly one of "reader_" and "v2_reader_" is non-nullptr. - TensorSliceReader* reader_; // Owned. - BundleReader* v2_reader_; // Owned. - TensorSliceReader::VarToShapeMap* var_to_shape_map_ptr_; // Owned. + std::unique_ptr var_to_shape_map_; + std::unique_ptr var_to_data_type_map_; TF_DISALLOW_COPY_AND_ASSIGN(CheckpointReader); }; diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index 52945d32391ddcd9bddb7726ddac68ee1ba9ae58..c77896b80b478cd34d3502e1061a7e76204ba021 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -3,6 +3,7 @@ licenses(["notice"]) # Apache 2.0 load( "//tensorflow:tensorflow.bzl", + "tf_cuda_cc_test", "tf_cc_test", "tf_copts", "tf_cuda_library", @@ -10,13 +11,15 @@ load( tf_cuda_library( name = "c_api", - srcs = ["c_api.cc"], + srcs = [ + "c_api.cc", + "c_api_internal.h", + ], hdrs = ["c_api.h"], copts = tf_copts(), visibility = ["//visibility:public"], deps = select({ "//tensorflow:android": [ - ":c_api_internal", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ @@ -33,7 +36,22 @@ tf_cuda_library( }), ) -tf_cc_test( +tf_cuda_library( + name = "c_api_internal", + hdrs = ["c_api_internal.h"], + deps = [ + ":c_api", + ":runtime", + "//tensorflow/c:c_api", + "//tensorflow/c:c_api_internal", + "//tensorflow/core:core_cpu_lib", + "//tensorflow/core:framework_internal", + "//tensorflow/core:framework_lite", + "//tensorflow/core:lib_internal", + ], +) + +tf_cuda_cc_test( name = "c_api_test", srcs = ["c_api_test.cc"], deps = [ @@ -53,7 +71,6 @@ tf_cuda_library( visibility = ["//tensorflow:internal"], deps = select({ "//tensorflow:android": [ - ":c_api_internal", "//tensorflow/core:android_tensorflow_lib_lite", ], "//conditions:default": [ @@ -85,3 +102,14 @@ tf_cc_test( "//tensorflow/core:test_main", ], ) + +cc_library( + name = "tape", + srcs = ["tape.cc"], + hdrs = ["tape.h"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:framework", + "//tensorflow/core:lib", + ], +) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 801d7307494e6585fbb7ee0fa4e6724ebe2c6f94..8359de62b7ff690fec9f6a0e3280f947c62f8b6e 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/c/c_api.h" #include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/eager/c_api_internal.h" #include "tensorflow/c/eager/runtime.h" #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/device_mgr.h" @@ -51,69 +52,25 @@ string DeviceName(tensorflow::Device* d) { } } // namespace -struct TFE_Context { - explicit TFE_Context(TF_Session* s) : session(s) {} - - // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph. - TF_Session* session; - tensorflow::Rendezvous* rendezvous; - - tensorflow::mutex functions_mu; - tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){ - tensorflow::OpRegistry::Global(), {}}; - - // One FunctionLibraryRuntime per device. - // func_libs[i] is the FunctionLibraryRuntime corresponding to - // session->devices[i]. - std::unique_ptr pflr; +extern "C" { - std::unordered_map - kernel_cache; +TFE_ContextOptions* TFE_NewContextOptions() { return new TFE_ContextOptions; } - tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) { - return pflr->GetFLR(d->name()); - } +void TFE_ContextOptionsSetConfig(TFE_ContextOptions* options, const void* proto, + size_t proto_len, TF_Status* status) { + TF_SetConfig(&options->session_options, proto, proto_len, status); +} - const std::vector& devices() { return session->devices; } -}; - -struct TFE_TensorHandle { - TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d) - : t(t), d(d) {} - - tensorflow::Tensor t; - // TODO(ashankar): d == nullptr iff local CPU - // This was expedient, but perhaps worth revisiting ('d' should always be a - // valid pointer?) - // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are - // provided with the appropriate TFE_Context. - // - // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a - // TFE_TensorHandle does not outlive the TFE_Context from which it came? - tensorflow::Device* d; -}; - -struct TFE_Op { - TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) - : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} - - bool const is_function() const { return attr_types == nullptr; } - - TFE_Context* ctx; // Must outlive the TFE_Op. - const string name; - tensorflow::AttrBuilder attrs; - const tensorflow::AttrTypeMap* attr_types; - std::vector inputs; - std::vector input_devices; - tensorflow::Device* device; -}; +void TFE_ContextOptionsSetDevicePlacementPolicy( + TFE_ContextOptions* options, TFE_ContextDevicePlacementPolicy policy) { + options->policy = policy; +} -extern "C" { +void TFE_DeleteContextOptions(TFE_ContextOptions* options) { delete options; } -TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) { +TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) { TF_Graph* graph = TF_NewGraph(); - TF_Session* session = TF_NewSession(graph, opts, status); + TF_Session* session = TF_NewSession(graph, &opts->session_options, status); if (status->status.ok()) { if (session->device_mgr == nullptr || session->devices.empty()) { status->status = tensorflow::errors::InvalidArgument( @@ -128,9 +85,10 @@ TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, TF_Status* status) { } TFE_Context* ret = new TFE_Context(session); + ret->policy = opts->policy; ret->pflr.reset(new tensorflow::ProcessFunctionLibraryRuntime( - ret->session->device_mgr, opts->options.env, TF_GRAPH_DEF_VERSION, - &ret->func_lib_def, {})); + ret->session->device_mgr, opts->session_options.options.env, + TF_GRAPH_DEF_VERSION, &ret->func_lib_def, {})); ret->rendezvous = new tensorflow::IntraProcessRendezvous(ret->session->device_mgr); @@ -330,6 +288,20 @@ TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, return ret; } +TF_AttrType TFE_OpNameGetAttrType(TFE_Context* ctx, + const char* op_or_function_name, + const char* attr_name, unsigned char* is_list, + TF_Status* status) { + TF_AttrType ret; + TFE_Op* op = TFE_NewOp(ctx, op_or_function_name, status); + if (!status->status.ok()) { + return TF_ATTR_INT; // Same dummy return as TFE_OpGetAttrType. + } + ret = TFE_OpGetAttrType(op, attr_name, is_list, status); + TFE_DeleteOp(op); + return ret; +} + void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value) { op->attrs.Set(attr_name, value); } @@ -451,8 +423,10 @@ void TFE_OpSetAttrShapeList(TFE_Op* op, const char* attr_name, namespace { tensorflow::Status ValidateInputTypeAndPlacement( - tensorflow::Device* host_device, tensorflow::Device* op_device, TFE_Op* op, - const tensorflow::OpKernel* kernel) { + TFE_Context* ctx, tensorflow::Device* host_device, + tensorflow::Device* op_device, TFE_Op* op, + const tensorflow::OpKernel* kernel, + std::vector* copied_tensors) { const tensorflow::MemoryTypeVector& memtypes = kernel->input_memory_types(); if (memtypes.size() != op->inputs.size()) { return tensorflow::errors::InvalidArgument( @@ -464,11 +438,50 @@ tensorflow::Status ValidateInputTypeAndPlacement( const tensorflow::Device* actual_device = op->input_devices[i] == nullptr ? host_device : op->input_devices[i]; if (expected_device != actual_device) { - return tensorflow::errors::InvalidArgument( - "cannot compute ", op->name, " as input #", i, - " was expected to be on ", expected_device->name(), - " but is actually on ", actual_device->name(), - " (operation running on ", op_device->name(), ")"); + switch (ctx->policy) { + case TFE_DEVICE_PLACEMENT_EXPLICIT: + // TODO(xpan): See if we could bubble python related error up + // to python level. + return tensorflow::errors::InvalidArgument( + "Tensors on conflicting devices:" + " cannot compute ", + op->name, " as input #", i, " was expected to be on ", + expected_device->name(), " but is actually on ", + actual_device->name(), " (operation running on ", + op_device->name(), ")", + " Tensors can be copied explicitly using .gpu() or .cpu()," + " or transparently copied by using tfe.enable_eager_execution(" + "tfe.DEVICE_PLACEMENT_SILENT). Copying tensors between devices" + " may slow down your model"); + case TFE_DEVICE_PLACEMENT_WARN: + LOG(WARNING) << "before computing " << op->name << " input #" << i + << " was expected to be on " << expected_device->name() + << " but is actually on " << actual_device->name() + << " (operation running on " << op_device->name() + << "). This triggers a copy which can be a performance " + "bottleneck."; + break; + case TFE_DEVICE_PLACEMENT_SILENT: // Do nothing. + break; + } + // We are only here if the policy is warn or silent copies, so we should + // trigger a copy. + TFE_TensorHandle original{op->inputs[i], op->input_devices[i]}; + TF_Status* s = TF_NewStatus(); + TFE_TensorHandle* copied_tensor = TFE_TensorHandleCopyToDevice( + &original, ctx, expected_device->name().c_str(), s); + if (!s->status.ok()) { + tensorflow::Status status = s->status; + delete s; + return tensorflow::errors::Internal( + "Failed copying input tensor from ", actual_device->name(), " to ", + expected_device->name(), " in order to run ", op->name, ": ", + status.error_message()); + } + op->inputs[i] = copied_tensor->t; + copied_tensors->push_back(copied_tensor); + op->input_devices[i] = copied_tensor->d; + delete s; } if (op->inputs[i].dtype() != kernel->input_type(i)) { return tensorflow::errors::InvalidArgument( @@ -511,10 +524,14 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, } tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel); } - status->status = ValidateInputTypeAndPlacement(ctx->devices()[0], device, op, - kernel->kernel()); + std::vector copied_tensors; + status->status = ValidateInputTypeAndPlacement( + ctx, ctx->devices()[0], device, op, kernel->kernel(), &copied_tensors); output_memory_types = &kernel->kernel()->output_memory_types(); if (!status->status.ok()) { + for (auto* t : copied_tensors) { + TFE_DeleteTensorHandle(t); + } return; } // WARNING: kernel->Run utilizes the FunctionLibraryRuntime @@ -526,6 +543,9 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals, // sense for FunctionLibraryRuntime to ensure thread-safe access to // FunctionLibraryDefinition?). status->status = kernel->Run(&op->inputs, &outputs); + for (auto* t : copied_tensors) { + TFE_DeleteTensorHandle(t); + } if (!status->status.ok()) return; *num_retvals = std::min(*num_retvals, outputs.size()); for (int i = 0; i < *num_retvals; ++i) { diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h index a4f7d308fbb4008d00bd97abf40c9ead5fdb1986..865580c5f3a823d9cf49fe460bd007e3b3b88767 100644 --- a/tensorflow/c/eager/c_api.h +++ b/tensorflow/c/eager/c_api.h @@ -43,14 +43,46 @@ limitations under the License. extern "C" { #endif +typedef struct TFE_ContextOptions TFE_ContextOptions; + +// Return a new options object. +TF_CAPI_EXPORT extern TFE_ContextOptions* TFE_NewContextOptions(); + +// Set the config in TF_ContextOptions.options. +// config should be a serialized tensorflow.ConfigProto proto. +// If config was not parsed successfully as a ConfigProto, record the +// error information in *status. +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetConfig( + TFE_ContextOptions* options, const void* proto, size_t proto_len, + TF_Status* status); + +// Controls how to act when we try to run an operation on a given device but +// some input tensors are not on that device. +typedef enum TFE_ContextDevicePlacementPolicy { + // The default: running operations with input tensors on the wrong device will + // fail. + TFE_DEVICE_PLACEMENT_EXPLICIT = 0, + // Copy the tensor to the right device but log a warning. + TFE_DEVICE_PLACEMENT_WARN = 1, + // Silently copy the tensor, which has a performance cost since the + // operation will be blocked till the copy completes. + TFE_DEVICE_PLACEMENT_SILENT = 2, +} TFE_ContextDevicePlacementPolicy; + +TF_CAPI_EXPORT extern void TFE_ContextOptionsSetDevicePlacementPolicy( + TFE_ContextOptions*, TFE_ContextDevicePlacementPolicy); + +// Destroy an options object. +TF_CAPI_EXPORT extern void TFE_DeleteContextOptions(TFE_ContextOptions*); + // "Context" under which operations/functions are executed. It encapsulates // things like the available devices, resource manager etc. // // TODO(ashankar): Merge with TF_Session? typedef struct TFE_Context TFE_Context; -TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext(const TF_SessionOptions* opts, - TF_Status* status); +TF_CAPI_EXPORT extern TFE_Context* TFE_NewContext( + const TFE_ContextOptions* opts, TF_Status* status); TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status); TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status); @@ -107,6 +139,12 @@ TF_CAPI_EXPORT extern void TFE_OpAddInput(TFE_Op* op, TFE_TensorHandle* h, TF_St TF_CAPI_EXPORT extern TF_AttrType TFE_OpGetAttrType(TFE_Op* op, const char* attr_name, unsigned char* is_list, TF_Status* status); +// Get an attribute type given an op name; a fusion of TFE_NewOp and +// TFE_OpGetAttrType for use from Python without the overhead of the individual +// calls and memory management of TFE_Op. +TF_CAPI_EXPORT extern TF_AttrType TFE_OpNameGetAttrType( + TFE_Context* ctx, const char* op_or_function_name, const char* attr_name, + unsigned char* is_list, TF_Status* status); TF_CAPI_EXPORT extern void TFE_OpSetAttrString(TFE_Op* op, const char* attr_name, const char* value); diff --git a/tensorflow/c/eager/c_api_internal.h b/tensorflow/c/eager/c_api_internal.h new file mode 100644 index 0000000000000000000000000000000000000000..0971e2ab2fe98cc8bf6f631f41d5adce90ee7051 --- /dev/null +++ b/tensorflow/c/eager/c_api_internal.h @@ -0,0 +1,103 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ +#define TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ + +#include "tensorflow/c/eager/c_api.h" + +#include +#include +#include +#include +#include + +#include "tensorflow/c/c_api.h" +#include "tensorflow/c/c_api_internal.h" +#include "tensorflow/c/eager/runtime.h" +#include "tensorflow/core/common_runtime/device_factory.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/rendezvous_mgr.h" +#include "tensorflow/core/framework/rendezvous.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/gtl/stl_util.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/thread_annotations.h" + +struct TFE_ContextOptions { + TF_SessionOptions session_options; + TFE_ContextDevicePlacementPolicy policy{TFE_DEVICE_PLACEMENT_EXPLICIT}; +}; + +struct TFE_Context { + explicit TFE_Context(TF_Session* s) : session(s) {} + + TFE_ContextDevicePlacementPolicy policy; + + // TFE_Context is an extension of TF_Session. And TF_Session needs a TF_Graph. + TF_Session* session; + tensorflow::Rendezvous* rendezvous; + + tensorflow::mutex functions_mu; + tensorflow::FunctionLibraryDefinition func_lib_def GUARDED_BY(functions_mu){ + tensorflow::OpRegistry::Global(), {}}; + + // One FunctionLibraryRuntime per device. + // func_libs[i] is the FunctionLibraryRuntime corresponding to + // session->devices[i]. + std::unique_ptr pflr; + + std::unordered_map + kernel_cache; + + tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) { + return pflr->GetFLR(d->name()); + } + + const std::vector& devices() { return session->devices; } +}; + +struct TFE_TensorHandle { + TFE_TensorHandle(const tensorflow::Tensor& t, tensorflow::Device* d) + : t(t), d(d) {} + + tensorflow::Tensor t; + // TODO(ashankar): d == nullptr iff local CPU + // This was expedient, but perhaps worth revisiting ('d' should always be a + // valid pointer?) + // This can be done if TFE_NewOp() and the TFE_TensorHandle constructors are + // provided with the appropriate TFE_Context. + // + // TODO(ashankar): Reference count TFE_Context to ensure that 'd' of a + // TFE_TensorHandle does not outlive the TFE_Context from which it came? + tensorflow::Device* d; +}; + +struct TFE_Op { + TFE_Op(TFE_Context* ctx, const char* op, const tensorflow::AttrTypeMap* t) + : ctx(ctx), name(op), attrs(op), attr_types(t), device(nullptr) {} + + bool const is_function() const { return attr_types == nullptr; } + + TFE_Context* ctx; // Must outlive the TFE_Op. + const tensorflow::string name; + tensorflow::AttrBuilder attrs; + const tensorflow::AttrTypeMap* attr_types; + std::vector inputs; + std::vector input_devices; + tensorflow::Device* device; +}; + +#endif // TENSORFLOW_C_EAGER_C_API_INTERNAL_H_ diff --git a/tensorflow/c/eager/c_api_test.cc b/tensorflow/c/eager/c_api_test.cc index 72e0fe8a1565a9a717c01aed83044cab2dd2dfbc..4af91b8853d0e85570bad136752a9d0a04b87da5 100644 --- a/tensorflow/c/eager/c_api_test.cc +++ b/tensorflow/c/eager/c_api_test.cc @@ -62,10 +62,10 @@ TFE_Op* MatMulOp(TFE_Context* ctx, TFE_TensorHandle* a, TFE_TensorHandle* b) { void BM_InitOp(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TFE_TensorHandle* m = TestMatrixTensorHandle(); tensorflow::testing::StartTiming(); @@ -84,10 +84,10 @@ BENCHMARK(BM_InitOp); void BM_Execute(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); @@ -109,9 +109,9 @@ BENCHMARK(BM_Execute); TEST(CAPI, Context) { TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TF_DeviceList* devices = TFE_ContextListDevices(ctx, status); EXPECT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -150,9 +150,9 @@ TEST(CAPI, TensorHandle) { TEST(CAPI, TensorHandleCopyBetweenDevices) { std::unique_ptr status( TF_NewStatus(), TF_DeleteStatus); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status.get()); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); @@ -216,12 +216,58 @@ TEST(CAPI, TensorHandleCopyBetweenDevices) { EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); } +TEST(CAPI, TensorHandleSilentCopy) { + std::unique_ptr status( + TF_NewStatus(), TF_DeleteStatus); + TFE_ContextOptions* opts = TFE_NewContextOptions(); + TFE_ContextOptionsSetDevicePlacementPolicy(opts, TFE_DEVICE_PLACEMENT_SILENT); + TFE_Context* ctx = TFE_NewContext(opts, status.get()); + TFE_DeleteContextOptions(opts); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TFE_TensorHandle* hcpu = TestMatrixTensorHandle(); + TF_Tensor* t = TFE_TensorHandleResolve(hcpu, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + + TF_DeviceList* devices = TFE_ContextListDevices(ctx, status.get()); + ASSERT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); + const int num_devices = TF_DeviceListCount(devices); + + // Disable the test if no GPU is present. + if (num_devices > 1) { + const int device_to_use = 1; + const string name(TF_DeviceListName(devices, device_to_use, status.get())); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_TensorHandle* hgpu = + TFE_TensorHandleCopyToDevice(hcpu, ctx, name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + + TFE_Op* matmul = MatMulOp(ctx, hcpu, hgpu); + TFE_OpSetDevice(matmul, name.c_str(), status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_TensorHandle* retvals[1]; + int num_retvals = 1; + TFE_Execute(matmul, &retvals[0], &num_retvals, status.get()); + ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get()); + TFE_DeleteOp(matmul); + TFE_DeleteTensorHandle(retvals[0]); + TFE_DeleteTensorHandle(hgpu); + } + + TF_DeleteDeviceList(devices); + TF_DeleteTensor(t); + TFE_DeleteTensorHandle(hcpu); + TFE_DeleteContext(ctx, status.get()); + EXPECT_EQ(TF_OK, TF_GetCode(status.get())) << TF_Message(status.get()); +} + TEST(CAPI, Execute) { TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TFE_TensorHandle* m = TestMatrixTensorHandle(); TFE_Op* matmul = MatMulOp(ctx, m, m); @@ -285,10 +331,10 @@ string MatMulFunction() { TEST(CAPI, FunctionDefAndExecute) { TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); string function_def = MatMulFunction(); TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), @@ -326,10 +372,10 @@ TEST(CAPI, FunctionDefAndExecute) { void BM_ExecuteFunction(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); string function_def = MatMulFunction(); TFE_ContextAddFunctionDef(ctx, function_def.data(), function_def.size(), @@ -406,10 +452,10 @@ TEST(CAPI, Variables) { // Variables use resource handles, so this is really a test for resource // tensor handling. TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TFE_TensorHandle* var_handle = CreateVariable(ctx, 12.0, status); ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); @@ -446,10 +492,10 @@ TEST(CAPI, Variables) { void BM_ReadVariable(int iters) { tensorflow::testing::StopTiming(); TF_Status* status = TF_NewStatus(); - TF_SessionOptions* opts = TF_NewSessionOptions(); + TFE_ContextOptions* opts = TFE_NewContextOptions(); TFE_Context* ctx = TFE_NewContext(opts, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); - TF_DeleteSessionOptions(opts); + TFE_DeleteContextOptions(opts); TFE_TensorHandle* var_handle = CreateVariable(ctx, 5.0, status); CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status); diff --git a/tensorflow/c/eager/tape.cc b/tensorflow/c/eager/tape.cc new file mode 100644 index 0000000000000000000000000000000000000000..464612a81ebda428f5582b6927f3a3b00a5aa6f5 --- /dev/null +++ b/tensorflow/c/eager/tape.cc @@ -0,0 +1,102 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/eager/tape.h" + +namespace tensorflow { +namespace eager { + +bool GradientTape::ShouldRecord(gtl::ArraySlice tensor_ids) { + for (int64 i : tensor_ids) { + if (tensor_tape_.find(i) != tensor_tape_.end()) { + return true; + } + } + return false; +} + +void GradientTape::Watch(int64 tensor_id) { + tensor_tape_.emplace(tensor_id, -1); +} + +void GradientTape::RecordOperation( + const string& op_type, gtl::ArraySlice output_tensors, + gtl::ArraySlice input_tensor_id, void* backward_function, + const std::function& backward_function_deleter) { + if (!ShouldRecord(input_tensor_id)) { + backward_function_deleter(); + return; + } + std::vector ids; + ids.reserve(input_tensor_id.size()); + for (int64 i : input_tensor_id) { + tensor_usage_[i]++; + ids.push_back(i); + } + const int64 op_id = next_op_id_++; + std::vector tensors; + tensors.reserve(output_tensors.size()); + for (const TapeTensor& o : output_tensors) { + // Note: the tensor can have already been watched and hence be in the tape, + // so we cannot check that we're inserting it here. + tensor_tape_[o.id] = op_id; + tensor_usage_[o.id] = 1; + tensors.push_back(o); + } + op_tape_[op_id] = OpTapeEntry{op_type, tensors, ids, backward_function, + backward_function_deleter}; +} + +void GradientTape::DeleteTrace(int64 tensor_id) { + auto it = tensor_usage_.find(tensor_id); + if (it == tensor_usage_.end()) { + return; + } + it->second--; + if (it->second != 0) { + return; + } + tensor_usage_.erase(it); + auto tensor_op_it = tensor_tape_.find(tensor_id); + if (tensor_op_it == tensor_tape_.end()) { + return; + } + const int64 op_id = tensor_op_it->second; + if (op_id == -1) { + // Do not delete watched tensors. + return; + } + tensor_tape_.erase(tensor_op_it); + auto op_it = op_tape_.find(op_id); + CHECK(op_it != op_tape_.end()); + for (const auto& output : op_it->second.output_tensor_info) { + if (tensor_usage_.find(output.id) != tensor_usage_.end()) { + // Found a usage for an output, so cannot delete the op. + return; + } + } + for (int64 id : op_it->second.input_tensor_id) { + DeleteTrace(id); + } + op_it->second.backward_function_deleter(); + op_tape_.erase(op_it); +} + +std::pair GradientTape::Export() { + return {std::move(tensor_tape_), std::move(op_tape_)}; +} + +} // namespace eager +} // namespace tensorflow diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h new file mode 100644 index 0000000000000000000000000000000000000000..df51f300eb61d54cb1e06d5a58a9b10e834f73c4 --- /dev/null +++ b/tensorflow/c/eager/tape.h @@ -0,0 +1,96 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_C_EAGER_TAPE_H_ +#define TENSORFLOW_C_EAGER_TAPE_H_ + +// Language-agnostic gradient tape. Does not perform backpropagation, just +// maintains the data structures required to do so. + +#include +#include +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/types.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace eager { + +// Information about a tensor. +struct TapeTensor { + int64 id; // Expected to be unique in the lifetime of this process. + DataType dtype; + TensorShape shape; +}; + +// Represents an entry in the tape. +struct OpTapeEntry { + string op_type; + std::vector output_tensor_info; + std::vector input_tensor_id; + + // TODO(apassos) consider narrowing down this interface. + void* backward_function; + + // Should be called before deleting the backward function. TODO(apassos) use + // unique_ptrs to ensure this happens. + std::function backward_function_deleter; +}; + +// Map from tensor_id to internally-defined operation-id of the operation which +// produced this tensor. A value of -1 means that the tensor was directly +// watched and not the result of any operation in the tape. +using TensorTape = std::unordered_map; + +// Map from operation-id to tape entry. +using OpTape = std::unordered_map; + +// Traces the execution of operations, doing eager garbage collection, and +// exporting a full trace so other code can do backpropagation. Not thread-safe. +class GradientTape { + public: + GradientTape() {} + + bool ShouldRecord(gtl::ArraySlice tensor_ids); + + void Watch(int64 tensor_id); + + void RecordOperation(const string& op_type, + gtl::ArraySlice output_tensors, + gtl::ArraySlice input_tensor_id, + void* backward_function, + const std::function& backward_function_deleter); + + void DeleteTrace(int64 tensor_id); + + // Note: it is only valid to call Export once per tape, and after calling + // export the tape is no longer valid (i.e. calls to ShouldRecord, Watch, + // Record, and Delete have undefined behavior). + std::pair Export(); + + private: + TensorTape tensor_tape_; + OpTape op_tape_; + int64 next_op_id_{0}; + + // Map from tensor id to number of remaining usages (i.e. how many entries in + // the tape refer to it); to aid in tape garbage collection. + std::unordered_map tensor_usage_; +}; + +} // namespace eager +} // namespace tensorflow + +#endif // TENSORFLOW_C_EAGER_TAPE_H_ diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc index b8d36b894722304e2b5e97332cabd5bab3c6dbd4..c67007dca0a2d3e97d367ef0eae2335e5683d087 100644 --- a/tensorflow/c/python_api.cc +++ b/tensorflow/c/python_api.cc @@ -24,9 +24,30 @@ void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) { graph->graph.AddControlEdge(&input->node, &op->node); } +void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, + TF_Buffer* attr_value_proto, TF_Status* status) { + AttrValue attr_val; + if (!attr_val.ParseFromArray(attr_value_proto->data, + attr_value_proto->length)) { + status->status = + tensorflow::errors::InvalidArgument("Invalid AttrValue proto"); + return; + } + + mutex_lock l(graph->mu); + op->node.AddAttr(attr_name, attr_val); +} + void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) { mutex_lock l(graph->mu); op->node.set_requested_device(device); } +void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, + TF_Status* status) { + mutex_lock l(graph->mu); + status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index, + &dst.oper->node, dst.index); +} + } // namespace tensorflow diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h index e1a55d7755a76c778bf6a8120a8cf81adb6941dc..f54585b0a1034ff108202272a11416e34985959e 100644 --- a/tensorflow/c/python_api.h +++ b/tensorflow/c/python_api.h @@ -25,8 +25,16 @@ namespace tensorflow { void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input); +// Changes an attr value in the node_def Protocol Buffer and sets a status upon +// completion. +void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name, + TF_Buffer* attr_value_proto, TF_Status* status); + void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device); +void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst, + TF_Status* status); + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_C_PYTHON_API_H_ diff --git a/tensorflow/c/while_loop_test.cc b/tensorflow/c/while_loop_test.cc index 27be5d787f6e14b61ca90ddb81fb42640b9db721..d2d887f32c44af5980b50785f282187d0f6fcff4 100644 --- a/tensorflow/c/while_loop_test.cc +++ b/tensorflow/c/while_loop_test.cc @@ -73,6 +73,11 @@ class CApiWhileLoopTest : public ::testing::Test { } void Run(std::initializer_list input_values) { + Run(outputs_, input_values); + } + + void Run(const std::vector& run_outputs, + std::initializer_list input_values) { DCHECK_EQ(inputs_.size(), input_values.size()); std::vector> inputs(inputs_.size()); int i = 0; @@ -80,9 +85,10 @@ class CApiWhileLoopTest : public ::testing::Test { inputs[i] = {inputs_[i].oper, Int32Tensor(v)}; ++i; } + // TODO(skyewm): use std::make_unique or absl::make_unique when possible. csession_.reset(new CSession(graph_, s_)); csession_->SetInputs(inputs); - csession_->SetOutputs(outputs_); + csession_->SetOutputs(run_outputs); csession_->Run(s_); ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); } @@ -312,7 +318,7 @@ TEST_F(CApiWhileLoopTest, InvalidCondOutputNode) { // TODO(skyewm): this error message could be more informative. Add explicit // checks for this case in the while loop implementation? ExpectError(TF_INVALID_ARGUMENT, - "Requested return node 'p0' not found in graph def"); + "Requested return tensor 'p0:0' not found in graph def"); } TEST_F(CApiWhileLoopTest, InvalidCondOutputIndex) { @@ -352,7 +358,7 @@ TEST_F(CApiWhileLoopTest, InvalidBodyOutputNode) { // TODO(skyewm): this error message could be more informative. Add explicit // checks for this case in the while loop implementation? ExpectError(TF_INVALID_ARGUMENT, - "Requested return node 'p0' not found in graph def"); + "Requested return tensor 'p0:0' not found in graph def"); } // TODO(skyewm): enable this when it works (currently segfaults!) @@ -383,7 +389,7 @@ TEST_F(CApiWhileLoopTest, WrongGraph) { params_->body_outputs[0] = inputs_[0]; // TODO(skyewm): improve error message ExpectError(TF_INVALID_ARGUMENT, - "Requested return node 'p0' not found in graph def"); + "Requested return tensor 'p0:0' not found in graph def"); } TEST_F(CApiWhileLoopTest, BadTypes) { @@ -402,4 +408,36 @@ TEST_F(CApiWhileLoopTest, BadTypes) { TF_AbortWhile(params_.get()); } +// This is a basic test to make sure the C++ gradient code can handle while +// loops created by the C API (which calls the C++ API under the hood). There +// are more while loop gradient tests in cc/framework/while_gradients_test.cc. +TEST_F(CApiWhileLoopTest, Gradients) { + Init(1); + + // Create loop: while (i < 10) i += 1 + TF_Operation* ten = ScalarConst(10, params_->cond_graph, s_); + TF_Operation* less_than = + LessThan(params_->cond_inputs[0], {ten, 0}, params_->cond_graph, s_); + DCHECK_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->cond_output = {less_than, 0}; + + TF_Operation* one = ScalarConst(1, params_->body_graph, s_); + TF_Operation* add = + Add(params_->body_inputs[0], {one, 0}, params_->body_graph, s_); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + params_->body_outputs[0] = {add, 0}; + + ExpectOK(); + + // Create backprop graph + TF_Output grad_output; + TF_AddGradients(graph_, outputs_.data(), outputs_.size(), inputs_.data(), 1, + nullptr, s_, &grad_output); + ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_); + + // Run gradient + Run({grad_output}, {0}); + ExpectOutputValue(0, 1); +} + } // namespace diff --git a/tensorflow/cc/BUILD b/tensorflow/cc/BUILD index b0c8cc3d0a60ebf5a86b82cb8abc4327e771212e..80112f9b44b1d5fd65a7d47788b072dc47a2b29a 100644 --- a/tensorflow/cc/BUILD +++ b/tensorflow/cc/BUILD @@ -19,13 +19,20 @@ load( cc_library( name = "gradients", - srcs = ["framework/gradients.cc"], + srcs = [ + "framework/gradients.cc", + "framework/while_gradients.cc", + "framework/while_gradients.h", + ], hdrs = ["framework/gradients.h"], deps = [ ":cc_ops", + ":cc_ops_internal", ":grad_op_registry", ":ops", ":scope", + ":scope_internal", + ":while_loop", "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -38,10 +45,33 @@ tf_cc_test( srcs = ["framework/gradients_test.cc"], deps = [ ":cc_ops", + ":client_session", + ":grad_op_registry", + ":grad_ops", + ":gradients", + ":testutil", + "//tensorflow/core:all_kernels", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +tf_cc_test( + name = "framework_while_gradients_test", + size = "small", + srcs = ["framework/while_gradients_test.cc"], + deps = [ + ":cc_ops", + ":client_session", ":grad_op_registry", ":grad_ops", ":gradients", ":testutil", + ":while_loop", "//tensorflow/core:all_kernels", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", diff --git a/tensorflow/cc/framework/gradients.cc b/tensorflow/cc/framework/gradients.cc index b665ce744d77cba8e71a047b33060b420f6343c2..affd90b1bcc7cb4a8b3ffed6aeeb4bd480f5e314 100644 --- a/tensorflow/cc/framework/gradients.cc +++ b/tensorflow/cc/framework/gradients.cc @@ -16,8 +16,9 @@ limitations under the License. #include #include -#include "tensorflow/cc/framework/gradients.h" #include "tensorflow/cc/framework/grad_op_registry.h" +#include "tensorflow/cc/framework/gradients.h" +#include "tensorflow/cc/framework/while_gradients.h" #include "tensorflow/cc/ops/standard_ops.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" @@ -25,6 +26,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/graph/algorithm.h" #include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/while_context.h" #include "tensorflow/core/lib/gtl/map_util.h" #include "tensorflow/core/platform/macros.h" @@ -82,6 +84,20 @@ class SymbolicGradientBuilder { // from outputs_. Keyed by node id. std::vector GetReachableNodes(); + // Creates the gradient subgraph for a while loop (or just stores + // `summed_grads` if not all incoming gradients are available yet). All exit + // nodes (which are the first nodes of a loop encountered in the backwards + // pass) are passed to this function rather than processed normally. + // `summed_grads` is the sum of `exit_node`s gradients. + Status ProcessWhileLoop(Node* exit_node, const Output& summed_grads); + + // Gets the set of node ids at which to stop backprop. These are all elements + // of `outputs_` that do not get transitively consumed by other `outputs_`. + // Used to identify nodes at which to stop backprop. + std::unordered_set GetStopBackpropNodes( + const std::vector& reachable_nodes, + std::unordered_set output_nodes); + const Scope& scope_; const ops::GradOpRegistry* registry_; const std::vector& outputs_; @@ -89,14 +105,13 @@ class SymbolicGradientBuilder { const std::vector& grad_inputs_; std::vector* grad_outputs_; - // A vector of output endpoints which represents backpropagated - // gradients - typedef std::vector BackpropedGradients; + // A vector of output endpoints which represents backpropagated gradients. + typedef std::vector BackproppedGradients; // backprops_ is a map from a node output to its accumulated // gradients. When a node output has accumulated all its // gradients, we add a node which sums them up. - std::unordered_map + std::unordered_map backprops_; // pending[i] is count-down counter for i-th node's expected @@ -109,14 +124,16 @@ class SymbolicGradientBuilder { // gradients from `grad_inputs_`. std::deque ready_; - // The set of node ids in `outputs_`. Used to identify nodes at which to stop - // backprop. - std::unordered_set output_nodes_; - // The set of node ids in `inputs_`. Used to identify nodes at backprop // frontier. Maps from Output -> index into `grad_outputs_`. std::unordered_map input_nodes_; + // For each while loop in the graph, collects the summed gradients for each of + // the loop's exit nodes. Note that unlike backprops_, this map contains the + // output of SumGradients(), not the input (i.e. each exit node may have + // multiple incoming gradients, but we only store the combined Output here). + std::map> while_backprops_; + TF_DISALLOW_COPY_AND_ASSIGN(SymbolicGradientBuilder); }; @@ -150,6 +167,7 @@ Status SymbolicGradientBuilder::BackpropAlongEdge(const Output& dst_grad, std::vector SymbolicGradientBuilder::GetReachableNodes() { std::vector reachable_nodes(scope_.graph()->num_node_ids(), false); std::deque queue; + std::vector visited(scope_.graph()->num_node_ids(), false); for (const Output& out : outputs_) { if (!reachable_nodes[out.node()->id()]) { queue.push_back(out.node()); @@ -162,13 +180,72 @@ std::vector SymbolicGradientBuilder::GetReachableNodes() { queue.pop_front(); for (const Edge* e : n->in_edges()) { if (e->IsControlEdge()) continue; + if (visited[e->src()->id()]) continue; queue.push_back(e->src()); reachable_nodes[e->src()->id()] = true; + visited[e->src()->id()] = true; } } return reachable_nodes; } +std::unordered_set SymbolicGradientBuilder::GetStopBackpropNodes( + const std::vector& reachable_nodes, + std::unordered_set output_nodes) { + // Output nodes that get transitively consumed by other `outputs_` are stored + // in `internal_outputs`. + std::unordered_set internal_outputs; + std::unordered_set visited; + // Initialize `queue` for BFS traversal. Nodes in `queue` hold upcoming nodes + // along with the last Node in `output_` encountered along that path. If no + // `output_` node was encountered, pair.second will be nullptr. + std::deque> queue; + for (const Output& nout : inputs_) { + if (visited.find(nout.node()) == visited.end()) { + queue.push_back(std::make_pair(nout.node(), static_cast(nullptr))); + visited.insert(nout.node()); + } + } + // BFS from nodes in 'inputs_' along out edges for the entire graph. Internal + // output nodes are recorded during the traversal. All nodes that are output + // nodes but not internal output nodes are considered the frontier of the + // output nodes, and thus our stop backprop nodes. + while (!queue.empty()) { + std::pair p = queue.front(); + Node* n = p.first; + queue.pop_front(); + for (const Edge* e : n->out_edges()) { + // If a node is not reachable from outputs_, we can stop. + if (e->IsControlEdge() || !reachable_nodes[e->dst()->id()]) continue; + if (visited.find(e->dst()) != visited.end()) continue; + + int node_id = e->dst()->id(); + Node* last_output_node = p.second; + if (output_nodes.find(node_id) != output_nodes.end()) { + // We reached an output node. + if (last_output_node != nullptr) { + // If we had already found an output node on this path so we mark + // it as an internal output. + internal_outputs.insert(last_output_node->id()); + } + // Mark this newly found output node to insert in the queue. + last_output_node = e->dst(); + } + queue.push_back(std::make_pair(e->dst(), last_output_node)); + visited.insert(e->dst()); + } + } + // Finally, we set stop_backprop_nodes to all output_nodes that aren't also + // internal_outputs. + std::unordered_set stop_backprop_nodes; + for (int output_node : output_nodes) { + if (internal_outputs.find(output_node) == internal_outputs.end()) { + stop_backprop_nodes.insert(output_node); + } + } + return stop_backprop_nodes; +} + Status SymbolicGradientBuilder::Initialize() { if (outputs_.size() != grad_inputs_.size()) { return errors::InvalidArgument( @@ -185,11 +262,16 @@ Status SymbolicGradientBuilder::Initialize() { } grad_outputs_->clear(); grad_outputs_->resize(inputs_.size()); - // Populate `output_nodes_` from node ids in `outputs_`. - output_nodes_.reserve(outputs_.size()); + + std::unordered_set output_nodes; + output_nodes.reserve(outputs_.size()); for (size_t i = 0; i < outputs_.size(); ++i) { - output_nodes_.insert(outputs_[i].node()->id()); + output_nodes.insert(outputs_[i].node()->id()); } + + std::unordered_set stop_backprop_nodes = + GetStopBackpropNodes(reachable_nodes, output_nodes); + // Populate `input_nodes_` from Outputs in `inputs_`. input_nodes_.reserve(inputs_.size()); for (size_t i = 0; i < inputs_.size(); ++i) { @@ -220,7 +302,7 @@ Status SymbolicGradientBuilder::Initialize() { backprops_[{n, i}].clear(); } int num_expected_backprops = 0; - if (output_nodes_.find(n->id()) == output_nodes_.end()) { + if (stop_backprop_nodes.find(n->id()) == stop_backprop_nodes.end()) { // Internal node: continue BFS along connected outputs. for (const Edge* e : n->out_edges()) { // If a node is not reachable from outputs_, @@ -233,9 +315,10 @@ Status SymbolicGradientBuilder::Initialize() { } ++num_expected_backprops; } - } else { - // Output node: stop BFS and update `num_expected_backprops` for - // each Output in `outputs_` that references `n`. + } + if (output_nodes.find(n->id()) != output_nodes.end()) { + // Output node: update `num_expected_backprops` for each Output in + // `outputs_` that references `n`. for (const Output& output : outputs_) { if (output.node() == n) { ++num_expected_backprops; @@ -304,6 +387,53 @@ Status SymbolicGradientBuilder::CallGradFunction( return Status::OK(); } +Status SymbolicGradientBuilder::ProcessWhileLoop(Node* exit_node, + const Output& summed_grads) { + // TODO(skyewm): detect second-order gradient and return bad status + // TODO(skyewm): handle (or at least detect) nested while loops + + // TODO(skyewm): handle NoGradient in while loop + if (summed_grads == NoGradient()) { + return errors::Unimplemented( + "Missing gradient into while loop not yet implemented"); + } + + DCHECK(exit_node->IsExit()); + WhileContext* while_ctx = exit_node->while_ctx(); + DCHECK(while_ctx != nullptr); + + // Record 'summed_grads' as the backprop input associated with 'exit_node' + std::map& backprops = while_backprops_[while_ctx]; + DCHECK(backprops.find(exit_node) == backprops.end()); + backprops[exit_node] = summed_grads; + + // Wait until we have all exit nodes' backprops collected before processing + // the while loop. + // TODO(skyewm): what if not all the exit nodes are reachable? + if (backprops.size() < while_ctx->exit_nodes().size()) return Status::OK(); + + // We've seen all the exit nodes for this loop and have collected all the + // backprops. Create the gradient graph for the while loop. + Scope while_scope = + scope_.NewSubScope(strings::StrCat(while_ctx->frame_name(), "_grad")); + std::vector dy; + for (Node* n : while_ctx->exit_nodes()) dy.push_back(backprops[n]); + std::vector dx; + TF_RETURN_IF_ERROR(AddWhileLoopGradient(while_ctx, while_scope, dy, &dx)); + + // Backprop along the in edges to the while loop (i.e. the inputs to the enter + // nodes) + DCHECK_EQ(dx.size(), while_ctx->enter_nodes().size()); + for (int i = 0; i < dx.size(); ++i) { + Node* enter_node = while_ctx->enter_nodes()[i]; + for (const Edge* e : enter_node->in_edges()) { + if (e->IsControlEdge()) continue; + TF_RETURN_IF_ERROR(BackpropAlongEdge(dx[i], {e->src(), e->src_output()})); + } + } + return Status::OK(); +} + Status SymbolicGradientBuilder::AddGradients() { // Initialize backprops. TF_RETURN_IF_ERROR(Initialize()); @@ -346,6 +476,18 @@ Status SymbolicGradientBuilder::AddGradients() { continue; } + // Special case: if we find an exit node, process the associated while loop. + // Note that ProcessWhileLoop() calls BackpropAlongEdge() if necessary + // (which updates ready_), and we skip all the regular processing below + // after calling it. + if (n->IsExit()) { + DCHECK_EQ(dy.size(), 1); + TF_RETURN_IF_ERROR(ProcessWhileLoop(n, dy[0])); + continue; + } + // All loop-specific control flow ops should have been handled above + DCHECK(!n->IsEnter() && !n->IsNextIteration()) << n->DebugString(); + const size_t num_no_grad = no_grad_dy_indices.size(); if (IsPrimitiveOpWithNoGrad(n->type_string()) || num_no_grad == num_y) { // No grad defined for this op, or all outputs returned 'NoGradient': diff --git a/tensorflow/cc/framework/gradients_test.cc b/tensorflow/cc/framework/gradients_test.cc index dcaf10c340c61142c6f436f74285ea29a83630a9..07a062e704ed6ffc6389b5897309957a1bfcd1c2 100644 --- a/tensorflow/cc/framework/gradients_test.cc +++ b/tensorflow/cc/framework/gradients_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/cc/framework/gradients.h" +#include "tensorflow/cc/client/client_session.h" #include "tensorflow/cc/framework/grad_op_registry.h" #include "tensorflow/cc/framework/testutil.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -453,6 +454,45 @@ TEST_F(GradientsTest, UnreachableInput) { " for node 'z' as it's unreachable from the output node(s)."); } +TEST_F(GradientsTest, DependentOutputs) { + auto x = Placeholder(scope_test_, DT_FLOAT); + auto y0 = Square(scope_test_, x); + auto y1 = Square(scope_test_, y0); + auto y2 = Square(scope_test_, y1); + // Requesting the gradients for y0 and y2 should return the sum of their + // individual gradients. + std::vector grad_outputs; + TF_EXPECT_OK(AddSymbolicGradients(scope_test_, {y0, y2}, {x}, &grad_outputs)); + ClientSession session(scope_test_); + std::vector grad_result; + TF_EXPECT_OK(session.Run({{x, {3.0f}}}, grad_outputs, &grad_result)); + EXPECT_EQ(grad_result.size(), 1); + EXPECT_EQ(grad_result[0].NumElements(), 1); + EXPECT_EQ(grad_result[0].flat()(0), 17502.0f); +} + +TEST_F(GradientsTest, MultiOutputNodeDependentOutputs) { + auto x = Placeholder(scope_test_, DT_FLOAT); + auto y0 = Square(scope_test_, x); + // y1, y2, and y3 all use y0. This means the backwards pass will need to wait + // for the gradient for all three. + auto y1 = Square(scope_test_, y0); + auto y2 = Square(scope_test_, y0); + auto y3 = Square(scope_test_, y2); + std::vector grad_outputs; + // By requesting y0, y1, and y3 we test that the computation correctly waits + // for all the points in backprop where gradients need to be summed from + // multiple branches. + TF_EXPECT_OK( + AddSymbolicGradients(scope_test_, {y0, y1, y3}, {x}, &grad_outputs)); + ClientSession session(scope_test_); + std::vector grad_result; + TF_EXPECT_OK(session.Run({{x, {3.0f}}}, grad_outputs, &grad_result)); + EXPECT_EQ(grad_result.size(), 1); + EXPECT_EQ(grad_result[0].NumElements(), 1); + EXPECT_EQ(grad_result[0].flat()(0), 17610.0f); +} + // StopGradientSingleOutputMultiEdgeTest tests combinations of valid and // 'NoGradient' (induced by StopGradient op) returned along multiple edges from // a single nodes output. diff --git a/tensorflow/cc/framework/while_gradients.cc b/tensorflow/cc/framework/while_gradients.cc new file mode 100644 index 0000000000000000000000000000000000000000..0734075fc6144d7c9f4fdb48c5e097faa58b8355 --- /dev/null +++ b/tensorflow/cc/framework/while_gradients.cc @@ -0,0 +1,198 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/framework/while_gradients.h" + +#include "tensorflow/cc/framework/gradients.h" +#include "tensorflow/cc/framework/scope_internal.h" +#include "tensorflow/cc/ops/control_flow_ops_internal.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/cc/ops/while_loop.h" + +namespace tensorflow { +namespace { + +using ops::BodyGraphBuilderFn; +using ops::BuildWhileLoop; +using ops::CondGraphBuilderFn; + +Output ToOutput(OutputTensor output_tensor) { + return Output(const_cast(output_tensor.node), output_tensor.index); +} + +std::vector ToOutputVector( + const std::vector& output_tensors) { + size_t n = output_tensors.size(); + std::vector result; + result.reserve(n); + for (int i = 0; i < n; ++i) result.push_back(ToOutput(output_tensors[i])); + return result; +} + +// The backprop loop counter and main backprop loop run in their own execution +// frame (conceptually, the main forward loop and forward loop counter run +// together in a frame, then the backprop loop counter and backprop loop run +// together in a different frame). This returns the frame name to use for the +// backprop while loops. +// TODO(skyewm): make sure this is unique among existing frame names +string BackPropFrameName(const string& forward_frame_name) { + return strings::StrCat(forward_frame_name, "_backprop"); +} + +// Creates a loop that counts the number of iterations performed by the +// while loop associated with `while_ctx`. The returned output yields the +// iteration count. +Status AddForwardLoopCounter(WhileContext* while_ctx, const Scope& scope, + Output* count) { + // Create while loop: + // i = 0 + // while forward loop predicate is true: + // ++i + + Output zero = ops::Const(scope, 0, {}); + + // Condition function that returns condition output from original while loop. + CondGraphBuilderFn cond_fn = [while_ctx](const Scope& scope, + const std::vector& inputs, + Output* output) { + *output = ToOutput(while_ctx->cond_output()); + return Status::OK(); + }; + + // Body function that adds one to input. + BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope, + const std::vector& inputs, + std::vector* outputs) { + DCHECK_EQ(inputs.size(), 1); + outputs->emplace_back(ops::Add(scope, inputs[0], 1)); + return scope.status(); + }; + + // Note that this loop runs in the same execution frame as the forward loop. + std::vector outputs; + TF_RETURN_IF_ERROR(BuildWhileLoop(scope, {zero}, cond_fn, body_fn, + while_ctx->frame_name(), &outputs, + /* create_while_ctx */ false)); + *count = outputs[0]; + return Status::OK(); +} + +// Creates a loop that executes `loop_count` times. The returned output is the +// boolean predicate indicating if the loop is still executing. This is used to +// drive the gradient computation for the while loop associated with +// `while_ctx`. +Status AddBackPropLoopCounter(WhileContext* while_ctx, const Output& loop_count, + const Scope& scope, + Output* backprop_execution_pred) { + // Create while loop: + // n = loop_count + // while n > 0: + // --n + + // Condition function that returns input > 0. + CondGraphBuilderFn cond_fn = [](const Scope& scope, + const std::vector& inputs, + Output* output) { + DCHECK_EQ(inputs.size(), 1); + *output = ops::Greater(scope, inputs[0], 0); + return scope.status(); + }; + + // Body function that subtracts one from input. + BodyGraphBuilderFn body_fn = [](const Scope& scope, + const std::vector& inputs, + std::vector* outputs) { + DCHECK_EQ(inputs.size(), 1); + outputs->emplace_back(ops::Subtract(scope, inputs[0], 1)); + return scope.status(); + }; + + string frame_name = BackPropFrameName(while_ctx->frame_name()); + std::vector outputs; + TF_RETURN_IF_ERROR(BuildWhileLoop( + scope, {loop_count}, cond_fn, body_fn, frame_name, &outputs, + /* create_while_ctx */ false, backprop_execution_pred)); + return Status::OK(); +} + +// Creates the main backprop loop that computes the gradient of the loop +// associated with `while_ctx`. `grad_inputs` are the partial derivatives +// w.r.t. the loop outputs, i.e. the exit nodes. `backprop_execution_pred` is +// the predicate to use for the backprop loop (see AddBackPropLoopCounter()). +// The partial derivatives w.r.t. the loop inputs, i.e. the input loop vars, are +// returned in `grad_outputs`. +Status AddWhileGradientLoop(WhileContext* while_ctx, + const std::vector& grad_inputs, + const Output& backprop_execution_pred, + const Scope& parent_scope, + std::vector* grad_outputs) { + DCHECK_EQ(grad_inputs.size(), while_ctx->body_outputs().size()); + DCHECK_EQ(while_ctx->body_inputs().size(), while_ctx->body_outputs().size()); + + Scope scope = parent_scope.NewSubScope("while"); + + // Create while loop: + // while backprop_execution_pred: + // forward loop body gradient + + // Condition function that returns 'backprop_execution_pred'. + CondGraphBuilderFn cond_fn = [backprop_execution_pred]( + const Scope& scope, + const std::vector& inputs, + Output* output) { + *output = backprop_execution_pred; + return Status::OK(); + }; + + // Body function that builds while body gradient subgraph. + BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope, + const std::vector& inputs, + std::vector* outputs) { + std::vector body_outputs = + ToOutputVector(while_ctx->body_outputs()); + std::vector body_inputs = ToOutputVector(while_ctx->body_inputs()); + return AddSymbolicGradients(scope, body_outputs, body_inputs, inputs, + outputs); + }; + + string frame_name = BackPropFrameName(while_ctx->frame_name()); + TF_RETURN_IF_ERROR(BuildWhileLoop(scope, grad_inputs, cond_fn, body_fn, + frame_name, grad_outputs, + /* create_while_ctx */ false)); + return Status::OK(); +} + +} // namespace + +Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + Output forward_loop_count; + TF_RETURN_IF_ERROR(AddForwardLoopCounter( + while_ctx, scope.NewSubScope("ForwardLoopCounter"), &forward_loop_count)); + + // TODO(skyewm): can we combine the backprop loop counter and main gradient + // loop into a single loop? The original Python code doesn't combine the + // loops, but I'm not sure why. + Output backprop_counter_cond; + TF_RETURN_IF_ERROR(AddBackPropLoopCounter( + while_ctx, forward_loop_count, scope.NewSubScope("BackPropLoopCounter"), + &backprop_counter_cond)); + + return AddWhileGradientLoop(while_ctx, grad_inputs, backprop_counter_cond, + scope, grad_outputs); +} + +} // namespace tensorflow diff --git a/tensorflow/cc/framework/while_gradients.h b/tensorflow/cc/framework/while_gradients.h new file mode 100644 index 0000000000000000000000000000000000000000..8f592accc93573cb8953a5ab25c04881ca0c2333 --- /dev/null +++ b/tensorflow/cc/framework/while_gradients.h @@ -0,0 +1,40 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ +#define THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ + +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/framework/scope.h" +#include "tensorflow/core/graph/while_context.h" + +// Utility functions for constructing while loop gradients + +namespace tensorflow { + +// Adds the gradient computation for the while loop associated with +// `while_ctx`. `grad_inputs` are the partial derivatives w.r.t. the loop +// outputs, i.e. the exit nodes. The partial derivatives w.r.t. the loop +// inputs, i.e. the input loop vars, are returned in `grad_outputs`. +// `grad_inputs` and `grad_outputs` are both in loop-variable order, as defined +// by the original inputs to BuildWhileLoop(). +// TODO(skyewm): maybe comment on NoGradient once it's supported +Status AddWhileLoopGradient(WhileContext* while_ctx, const Scope& scope, + const std::vector& grad_inputs, + std::vector* grad_outputs); + +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CC_FRAMEWORK_WHILE_GRADIENTS_H_ diff --git a/tensorflow/cc/framework/while_gradients_test.cc b/tensorflow/cc/framework/while_gradients_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..39fa7477c5c59e2f23b8473800f708ae65c139da --- /dev/null +++ b/tensorflow/cc/framework/while_gradients_test.cc @@ -0,0 +1,233 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/client/client_session.h" +#include "tensorflow/cc/framework/gradients.h" +#include "tensorflow/cc/framework/testutil.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/cc/ops/while_loop.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +namespace { + +class WhileGradientsTest : public ::testing::Test { + protected: + WhileGradientsTest() : scope_(Scope::NewRootScope()) {} + + void Init(int num_inputs, DataType dtype = DT_INT32) { + for (int i = 0; i < num_inputs; ++i) { + inputs_.push_back(ops::Placeholder(scope_, dtype)); + } + } + + void CreateLoop(const ops::CondGraphBuilderFn& cond, + const ops::BodyGraphBuilderFn& body, + const std::vector* inputs = nullptr) { + if (inputs == nullptr) inputs = &inputs_; + TF_ASSERT_OK(ops::BuildWhileLoop(scope_, *inputs, cond, body, "test_loop", + &outputs_)); + } + + void CreateBackprop() { + TF_ASSERT_OK( + AddSymbolicGradients(scope_, outputs_, inputs_, &grad_outputs_)); + ASSERT_EQ(grad_outputs_.size(), inputs_.size()); + } + + template + void Run(const std::vector& input_values, + const std::vector& expected_grad_values) { + Run(ClientSession(scope_), input_values, expected_grad_values); + } + + template + void Run(const ClientSession& session, + const std::vector& input_values, + const std::vector& expected_grad_values, + const RunOptions& run_options = RunOptions(), + RunMetadata* run_metadata = nullptr) { + DCHECK_EQ(input_values.size(), inputs_.size()); + ClientSession::FeedType feeds; + for (int i = 0; i < inputs_.size(); ++i) { + feeds.emplace(inputs_[i], input_values[i]); + } + + std::vector run_outputs; + std::vector out_tensors; + TF_ASSERT_OK(session.Run(run_options, feeds, grad_outputs_, run_outputs, + &out_tensors, run_metadata)); + ASSERT_EQ(out_tensors.size(), grad_outputs_.size()); + + DCHECK_EQ(expected_grad_values.size(), out_tensors.size()); + for (int i = 0; i < out_tensors.size(); ++i) { + test::ExpectTensorEqual( + out_tensors[i], test::AsTensor({expected_grad_values[i]}, {})); + } + } + + Scope scope_; + std::vector inputs_; + std::vector outputs_; + std::vector grad_outputs_; +}; + +TEST_F(WhileGradientsTest, Basic) { + // Create loop: while (i < 10) i += 1 + Init(1); + CreateLoop( + [](const Scope& s, const std::vector& inputs, Output* output) { + *output = ops::Less(s, inputs[0], 10); + return s.status(); + }, + [](const Scope& s, const std::vector& inputs, + std::vector* outputs) { + // Use AddN, rather than Add, because the gradient function doesn't + // depend on the input shapes, and thus we do not need to store + // intermediate values in a stack. + outputs->push_back(ops::AddN(s, {inputs[0], 1})); + return s.status(); + }); + CreateBackprop(); + + Run({1}, {1}); + Run({11}, {1}); +} + +TEST_F(WhileGradientsTest, MultipleLoopVars) { + // Create loop: while (i < 10) i += j; j += 1; k = k + Init(3); + CreateLoop( + [](const Scope& s, const std::vector& inputs, Output* output) { + *output = ops::Less(s, inputs[0], 10); + return s.status(); + }, + [](const Scope& s, const std::vector& inputs, + std::vector* outputs) { + outputs->push_back(ops::AddN(s, {inputs[0], inputs[1]})); + outputs->push_back(ops::AddN(s, {inputs[1], 1})); + outputs->push_back(inputs[2]); + return s.status(); + }); + CreateBackprop(); + + // The following execution traces illustrate why we expect dF/dj to be 5: + // + // i j k + // --------- + // 0 1 2 <-- initial values + // 1 2 2 + // 3 3 2 + // 6 4 2 + // 10 5 2 <-- while output values + // outputs sum = 17 + // + // i j k + // --------- + // 0 2 2 <-- initial values (add 1 to j) + // 2 3 2 + // 5 4 2 + // 9 5 2 + // 14 6 2 <-- while output values + // outputs sum = 22 + // + // Calculate the "slope" between j=1 and j=2: + // 22 - 17 = 5 => dF/dj = 5 + Run({0, 1, 2}, {1, 5, 1}); + + Run({1, 1, 0}, {1, 5, 1}); + Run({0, 0, 0}, {1, 6, 1}); +} + +TEST_F(WhileGradientsTest, Chaining) { + Init(2, DT_DOUBLE); + + // Multiply each input by 2 before passing to while loop to make sure chaining + // works properly + std::vector loop_inputs = {ops::Multiply(scope_, inputs_[0], 2.0), + ops::Multiply(scope_, inputs_[1], 2.0)}; + + // Create loop: while (i > 0 && j > 0) i -= 1 + CreateLoop( + [](const Scope& s, const std::vector& inputs, Output* output) { + *output = ops::LogicalAnd(s, ops::Greater(s, inputs[0], 0.0), + ops::Greater(s, inputs[1], 0.0)); + return s.status(); + }, + [](const Scope& s, const std::vector& inputs, + std::vector* outputs) { + outputs->push_back(ops::AddN(s, {inputs[0], -1.0})); + outputs->push_back(inputs[1]); + return s.status(); + }, + &loop_inputs); + + // Take negative of first output to make sure chaining works properly + outputs_[0] = ops::Neg(scope_, outputs_[0]); + + CreateBackprop(); + + Run({1.0, 1.0}, {-2.0, 2.0}); + Run({0.0, 0.0}, {-2.0, 2.0}); +} + +TEST_F(WhileGradientsTest, MultipleDevices) { + // Make sure loop is created on cpu0 + scope_ = scope_.WithDevice("/cpu:0"); + + // Create loop: while (i < 10) i += j + Init(2); + CreateLoop( + [](const Scope& s, const std::vector& inputs, Output* output) { + *output = ops::Less(s, inputs[0], 10); + return s.status(); + }, + [](const Scope& s, const std::vector& inputs, + std::vector* outputs) { + // Place body on cpu1 + Scope cpu1_scope = s.WithDevice("/cpu:1"); + outputs->push_back(ops::AddN(cpu1_scope, {inputs[0], inputs[1]})); + outputs->push_back(inputs[1]); + return cpu1_scope.status(); + }); + + // Build gradient graph on cpu1 + Scope cpu1_scope = scope_.WithDevice("/cpu:1"); + TF_ASSERT_OK( + AddSymbolicGradients(cpu1_scope, outputs_, inputs_, &grad_outputs_)); + ASSERT_EQ(grad_outputs_.size(), inputs_.size()); + + // Run with two CPU devices and output partition graphs + SessionOptions session_options; + (*session_options.config.mutable_device_count())["CPU"] = 2; + RunOptions run_options; + run_options.set_output_partition_graphs(true); + RunMetadata run_metadata; + Run(ClientSession(scope_, session_options), {0, 1}, {1, 11}, run_options, + &run_metadata); + + // Check that at least one node ran on each device + ASSERT_EQ(run_metadata.partition_graphs().size(), 2); + for (const GraphDef& partition_graph : run_metadata.partition_graphs()) { + EXPECT_GE(partition_graph.node().size(), 1); + } +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/cc/gradients/math_grad.cc b/tensorflow/cc/gradients/math_grad.cc index ac288b1d834d267f5bab887f45de8173e31f88ea..d7446b9560fd7dc8377ea3710641906b274313a9 100644 --- a/tensorflow/cc/gradients/math_grad.cc +++ b/tensorflow/cc/gradients/math_grad.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#define _USE_MATH_DEFINES +#include + #include "tensorflow/cc/ops/array_ops_internal.h" #include "tensorflow/cc/ops/math_ops_internal.h" #include "tensorflow/cc/ops/standard_ops.h" @@ -200,8 +203,8 @@ Status TanhGrad(const Scope& scope, const Operation& op, // evaluated. Scope grad_scope = scope.WithControlDependencies(grad); auto y = ConjugateHelper(grad_scope, op.output(0)); - grad_outputs->push_back(internal::TanhGrad(scope, y, grad)); - return scope.status(); + grad_outputs->push_back(internal::TanhGrad(grad_scope, y, grad)); + return grad_scope.status(); } REGISTER_GRADIENT_OP("Tanh", TanhGrad); @@ -256,8 +259,8 @@ Status SigmoidGrad(const Scope& scope, const Operation& op, // evaluated. Scope grad_scope = scope.WithControlDependencies(grad); auto y = ConjugateHelper(grad_scope, op.output(0)); - grad_outputs->push_back(internal::SigmoidGrad(scope, y, grad)); - return scope.status(); + grad_outputs->push_back(internal::SigmoidGrad(grad_scope, y, grad)); + return grad_scope.status(); } REGISTER_GRADIENT_OP("Sigmoid", SigmoidGrad); @@ -484,7 +487,7 @@ Status MaximumMinimumGradCommon(const Scope& scope, const Operation& op, auto grad = grad_inputs[0]; auto zeros = ZerosLike(scope, grad); auto gx_1 = Where3(scope, comparator, grad, zeros); - auto gx_2 = Where3(scope, LogicalNot(scope, comparator), grad, zeros); + auto gx_2 = Where3(scope, comparator, zeros, grad); return BinaryGradCommon(scope, op, grad_outputs, gx_1, gx_2); } @@ -696,15 +699,32 @@ Status MeanGrad(const Scope& scope, const Operation& op, } REGISTER_GRADIENT_OP("Mean", MeanGrad); +Status ErfGrad(const Scope& scope, const Operation& op, + const std::vector& grad_inputs, + std::vector* grad_outputs) { + auto grad = grad_inputs[0]; + auto two_over_root_pi = Cast(scope, Const(scope, 2 / std::sqrt(M_PI)), + grad.type()); + Scope grad_scope = scope.WithControlDependencies(grad); + auto x = ConjugateHelper(grad_scope, op.input(0)); + // grad * 2/sqrt(pi) * exp(-x**2) + auto dx = Mul(grad_scope, + Mul(grad_scope, grad, two_over_root_pi), + Exp(grad_scope, Neg(grad_scope, Square(grad_scope, x)))); + grad_outputs->push_back(dx); + return grad_scope.status(); +} +REGISTER_GRADIENT_OP("Erf", ErfGrad); + Status LgammaGrad(const Scope& scope, const Operation& op, const std::vector& grad_inputs, std::vector* grad_outputs) { auto grad = grad_inputs[0]; Scope grad_scope = scope.WithControlDependencies(grad); auto x = ConjugateHelper(grad_scope, op.input(0)); - auto dx = Mul(scope, grad, Digamma(scope, x)); + auto dx = Mul(grad_scope, grad, Digamma(grad_scope, x)); grad_outputs->push_back(dx); - return scope.status(); + return grad_scope.status(); } REGISTER_GRADIENT_OP("Lgamma", LgammaGrad); diff --git a/tensorflow/cc/gradients/math_grad_test.cc b/tensorflow/cc/gradients/math_grad_test.cc index a174f223ad59b0a111b3d13cb59fb2b13a0095b0..6313f41da5e5f9cf88be4c8a84408a8df77f0e25 100644 --- a/tensorflow/cc/gradients/math_grad_test.cc +++ b/tensorflow/cc/gradients/math_grad_test.cc @@ -64,7 +64,9 @@ class CWiseUnaryGradTest : public ::testing::Test { IMAG, CONJ, COMPLEX, - ANGLE + ANGLE, + LGAMMA, + ERF }; template @@ -168,6 +170,12 @@ class CWiseUnaryGradTest : public ::testing::Test { case ANGLE: y = Angle(scope_, x); break; + case LGAMMA: + y = Lgamma(scope_, x); + break; + case ERF: + y = Erf(scope_, x); + break; } float max_error; @@ -503,6 +511,42 @@ TEST_F(CWiseUnaryGradTest, Angle) { TestCWiseGrad(ANGLE, x_fn); } +TEST_F(CWiseUnaryGradTest, Lgamma) { + auto x_fn = [this](const int i) { + return RV({-3.5, -2.5, -1.5, 1.0, 2.0, 3.5}); + }; + TestCWiseGrad(LGAMMA, x_fn); +} + +TEST_F(CWiseUnaryGradTest, Lgamma_Complex) { + auto x_fn = [this](const int i) { + return CRV({{-3.5, 0.5}, {-1.5, -0.5}, {1.5, -1.0}, {3.5, 1.0}}); + }; + // TODO(kbsriram) + // Add test when the lgamma kernel supports complex numbers + if (false) { + TestCWiseGrad(LGAMMA, x_fn); + } +} + +TEST_F(CWiseUnaryGradTest, Erf) { + auto x_fn = [this](const int i) { + return RV({-1.2, -1.0, -0.5, 0.3, 0.5, 1.3}); + }; + TestCWiseGrad(ERF, x_fn); +} + +TEST_F(CWiseUnaryGradTest, Erf_Complex) { + auto x_fn = [this](const int i) { + return CRV({{-1.2, 0.5}, {-0.5, -0.5}, {0.5, 0.5}, {1.2, -0.5}}); + }; + // TODO(kbsriram) + // Add test when the erf kernel supports complex numbers + if (false) { + TestCWiseGrad(ERF, x_fn); + } +} + class MathGradTest : public ::testing::Test { protected: MathGradTest() : root_(Scope::NewRootScope().WithDevice("/cpu:0")) {} @@ -821,17 +865,5 @@ TEST_F(NaryGradTest, Minimum) { RunTest(x, x_init_value, y, shape); } -TEST_F(NaryGradTest, Lgamma) { - TensorShape shape({3, 2}); - auto x = Placeholder(scope_, DT_FLOAT, Placeholder::Shape(shape)); - auto y = Lgamma(scope_, x); - // Select values to avoid instability when computing finite differences. - // Ref: https://en.wikipedia.org/wiki/File:Gamma_plot.svg - Tensor x_init_value = - test::AsTensor({-3.5f, -2.5f, -1.5f, 1.0f, 2.0f, 3.5f}, {3, 2}); - RunTest(x, x_init_value, y, shape); - // TODO(suharshs): add test case for complex values -} - } // namespace } // namespace tensorflow diff --git a/tensorflow/cc/ops/const_op.cc b/tensorflow/cc/ops/const_op.cc index 0030c2b2a7b69afe2151e88ef5d6d0755f72bfa7..a04f37067dd95fc452b8e565c6bc73128c0423b2 100644 --- a/tensorflow/cc/ops/const_op.cc +++ b/tensorflow/cc/ops/const_op.cc @@ -19,19 +19,17 @@ limitations under the License. namespace tensorflow { namespace ops { -Output Const(const Scope& scope, const Input::Initializer& val) { +namespace { +template +Output ConstHelper(const Scope& scope, const T& value, DataType dtype) { if (!scope.ok()) return Output(); - if (!val.status.ok()) { - scope.UpdateStatus(val.status); - return Output(); - } Node* ret; Graph* graph = scope.graph(); const string unique_name = scope.GetUniqueNameForOp("Const"); auto builder = NodeBuilder(unique_name, "Const") - .Attr("value", val.tensor) - .Attr("dtype", val.tensor.dtype()); + .Attr("value", value) + .Attr("dtype", dtype); scope.UpdateBuilder(&builder); scope.UpdateStatus(builder.Finalize(graph, &ret)); if (!scope.ok()) return Output(); @@ -41,6 +39,19 @@ Output Const(const Scope& scope, const Input::Initializer& val) { return Output(ret); } +} // namespace + +Output Const(const Scope& scope, const Input::Initializer& val) { + if (!val.status.ok()) { + scope.UpdateStatus(val.status); + return Output(); + } + return ConstHelper(scope, val.tensor, val.tensor.dtype()); +} + +Output ConstFromProto(const Scope& scope, const TensorProto& proto) { + return ConstHelper(scope, proto, proto.dtype()); +} NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp) { if (!inp.status().ok()) { diff --git a/tensorflow/cc/ops/const_op.h b/tensorflow/cc/ops/const_op.h index 516800920f282be0590ef72b26a7fdd8b92a38f9..d11fda475b3db58bf83cdb94079c8fde8d1170f7 100644 --- a/tensorflow/cc/ops/const_op.h +++ b/tensorflow/cc/ops/const_op.h @@ -28,6 +28,8 @@ namespace ops { Output Const(const Scope& scope, const Input::Initializer& val); +Output ConstFromProto(const Scope& scope, const TensorProto& proto); + NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp); template diff --git a/tensorflow/cc/ops/const_op_test.cc b/tensorflow/cc/ops/const_op_test.cc index 3184edeb3307cafcbfbc41c6477fd092ab613b46..69b5d7fd47cae9b54d3e0ae42b0d3936e3c7c696 100644 --- a/tensorflow/cc/ops/const_op_test.cc +++ b/tensorflow/cc/ops/const_op_test.cc @@ -100,6 +100,20 @@ TEST(ConstOpTest, WithExplicitShape) { ExpectNodeEqual(d.node(), {"1", "2", "3", "4", "5", "6"}, {2, 3}); } +TEST(ConstOpTest, FromProto) { + Scope root = Scope::NewRootScope(); + TensorProto proto; + proto.set_dtype(DT_DOUBLE); + TensorShape({2, 2}).AsProto(proto.mutable_tensor_shape()); + for (int i = 0; i < 4; ++i) { + proto.add_double_val(static_cast(i)); + } + auto c = ops::ConstFromProto(root, proto); + TF_CHECK_OK(root.status()); + EXPECT_EQ(c.op().output_type(0), DT_DOUBLE); + ExpectNodeEqual(c.node(), {0.0, 1.0, 2.0, 3.0}, {2, 2}); +} + TEST(ConstOpTest, InvalidInitializer) { Scope root = Scope::NewRootScope(); ops::Const(root, {{2.0}, {"df"}}); diff --git a/tensorflow/cc/ops/op_gen_overrides.pbtxt b/tensorflow/cc/ops/op_gen_overrides.pbtxt index 0184c82c5afc99990530b902efdf670a2bdbc4bc..4aac990e748b0a79cbc3b353b4121a582b0883b0 100644 --- a/tensorflow/cc/ops/op_gen_overrides.pbtxt +++ b/tensorflow/cc/ops/op_gen_overrides.pbtxt @@ -11,7 +11,7 @@ op { name: "Reverse" skip: true } op { name: "ReverseV2" rename_to: "Reverse" } op { name: "Split" input_rename: { from: "split_dim" to: "axis" } } op { name: "SplitV" input_rename: { from: "split_dim" to: "axis" } } -op { name: "Squeeze" input_rename: { from: "squeeze_dims" to: "axis" } } +op { name: "Squeeze" attr_rename: { from: "squeeze_dims" to: "axis" } } op { name: "Pack" rename_to: "Stack" } op { name: "Unpack" rename_to: "Unstack" } op { name: "Select" rename_to: "Where3" input_rename: { from: "t" to: "x" } input_rename: { from: "e" to: "y" } } diff --git a/tensorflow/cc/ops/while_loop.h b/tensorflow/cc/ops/while_loop.h index 82181516d6d354fbc63c84e7846184ac87756975..a04476056a058ff0951a6347e8ffc05bc5ff5023 100644 --- a/tensorflow/cc/ops/while_loop.h +++ b/tensorflow/cc/ops/while_loop.h @@ -49,7 +49,12 @@ typedef std::function& inputs, // * outputs: output param that returns final loop variable outputs in non-error // case. Must be non-null and empty. // * create_while_ctx: if true, a WhileContext is created and populated for this -// loop. See core/graph/while_context.h for more details. +// loop. See core/graph/while_context.h for more details on +// WhileContexts. This is set to false for loops used as part of gradient +// computations, since they're part of the gradient for a loop in the +// forward-pass. +// TODO(skyewm): revisit this. Should we create WhileContexts for all loops, +// even if we don't need them? // * cond_output: if non-null, the output of the predicate is returned. This // will always be a LoopCond node. // diff --git a/tensorflow/cc/ops/while_loop_test.cc b/tensorflow/cc/ops/while_loop_test.cc index e3f6523c1905b307f3bc3e23533e9a3bae9f270b..18b8be3794f1368edcd4b0fa62432690fe4ffe24 100644 --- a/tensorflow/cc/ops/while_loop_test.cc +++ b/tensorflow/cc/ops/while_loop_test.cc @@ -146,7 +146,7 @@ TEST_F(WhileLoopTest, InvalidCondOutputIndex) { *output = {less.node(), 100}; return s.status(); }, - AddOneBody, error::INVALID_ARGUMENT, + AddOneBody, error::OUT_OF_RANGE, "Node 'cond/Less' (type: 'Less', num of outputs: 1) does not have output " "100"); } @@ -182,7 +182,7 @@ TEST_F(WhileLoopTest, InvalidBodyOutputIndex) { outputs->emplace_back(add.node(), 100); return s.status(); }, - error::INVALID_ARGUMENT, + error::OUT_OF_RANGE, "Node 'body/Add' (type: 'Add', num of outputs: 1) does not have " "output 100"); } diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 1cc7cf3f2021ede8269368aa46007b5ceaace606..d29ad3ebcbe29087d5572b51c7713e0c98d0d840 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -56,6 +56,7 @@ cc_library( ":constants", ] + if_not_mobile([ "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index fc5c6ce58d95b4ad06e48d40369740102efe0a66..ae22f7edc423247b34895411d19d7a3c21f86d4f 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -164,10 +164,6 @@ string RewriteWithName(const string& name, string code, // Generate methods for args (inputs). Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, const CompileResult& compile_result, string* methods) { - *methods += R"( - void** args() { return args_; } - const void *const *args() const { return args_; } -)"; size_t num_args = ps.parameters_size(); if (compile_result.has_context_arg) { // If the compiled function needs a XlaLocalRuntimeContext* arg, it's @@ -184,21 +180,21 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, TF_RETURN_IF_ERROR(AddRewritesForShape(i, ps.parameters(i), &rewrites)); const string code = R"( void set_arg{{NAME}}_data(void* data) { - args_[{{I}}] = data; + set_arg_data({{I}}, data); } {{TYPE}}* arg{{NAME}}_data() { - return static_cast<{{TYPE}}*>(args_[{{I}}]); + return static_cast<{{TYPE}}*>(arg_data({{I}})); } {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) { return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( - args_[{{I}}])){{INDICES}}; + arg_data({{I}}))){{INDICES}}; } const {{TYPE}}* arg{{NAME}}_data() const { - return static_cast(args_[{{I}}]); + return static_cast(arg_data({{I}})); } const {{TYPE}}& arg{{NAME}}({{DIM_VARS}}) const { return (*static_cast( - args_[{{I}}])){{INDICES}}; + arg_data({{I}}))){{INDICES}}; } )"; *methods += RewriteWithName(strings::StrCat(i), code, rewrites); @@ -213,74 +209,33 @@ Status GenArgMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, Status GenResultMethods(const tf2xla::Config& config, const xla::ProgramShape& ps, string* methods) { if (ps.result().element_type() != xla::TUPLE) { - // Non-tuple (i.e. single-result) case. - if (config.fetch_size() != 1) { - return errors::InvalidArgument( - "non-tuple result implies 1 fetch, but got ", config.fetch_size(), - " fetches"); - } - *methods += R"( - void** results() { return temps_ + kResultIndex; } - const void *const *results() const { return temps_ + kResultIndex; } -)"; - std::vector> rewrites; - TF_RETURN_IF_ERROR(AddRewritesForShape(0, ps.result(), &rewrites)); - const string code = R"( - {{TYPE}}* result{{NAME}}_data() { - return static_cast<{{TYPE}}*>(temps_[kResultIndex]); - } - {{TYPE}}& result{{NAME}}({{DIM_VARS}}) { - return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( - temps_[kResultIndex])){{INDICES}}; - } - const {{TYPE}}* result{{NAME}}_data() const { - return static_cast(temps_[kResultIndex]); - } - const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const { - return (*static_cast( - temps_[kResultIndex])){{INDICES}}; + // The XlaCompiler we use to build the xla computation always generates a + // tuple result, and we rely on this to simplify code generation. + return errors::Internal("codegen requires the XLA result to be a tuple"); } -)"; - *methods += RewriteWithName("0", code, rewrites); - if (!config.fetch(0).name().empty()) { - *methods += RewriteWithName("_" + config.fetch(0).name(), code, rewrites); - } - return Status::OK(); - } - // Tuple (i.e. multi-result) case. if (config.fetch_size() != ps.result().tuple_shapes_size()) { return errors::InvalidArgument("mismatch between fetch_size(", config.feed_size(), ") and tuple_size(", ps.result().tuple_shapes_size(), ")"); } - *methods += R"( - void** results() { - return static_cast(temps_[kResultIndex]); - } - const void *const *results() const { - return static_cast(temps_[kResultIndex]); - } -)"; for (int i = 0; i < ps.result().tuple_shapes_size(); ++i) { std::vector> rewrites; TF_RETURN_IF_ERROR( AddRewritesForShape(i, ps.result().tuple_shapes(i), &rewrites)); string code = R"( {{TYPE}}* result{{NAME}}_data() { - return static_cast<{{TYPE}}*>( - static_cast(temps_[kResultIndex])[{{I}}]); + return static_cast<{{TYPE}}*>(result_data({{I}})); } {{TYPE}}& result{{NAME}}({{DIM_VARS}}) { return (*static_cast<{{TYPE}}(*){{DIM_SIZES}}>( - static_cast(temps_[kResultIndex])[{{I}}])){{INDICES}}; + result_data({{I}}))){{INDICES}}; } const {{TYPE}}* result{{NAME}}_data() const { - return static_cast<{{TYPE}}*>( - static_cast(temps_[kResultIndex])[{{I}}]); + return static_cast(result_data({{I}})); } const {{TYPE}}& result{{NAME}}({{DIM_VARS}}) const { return (*static_cast( - static_cast(temps_[kResultIndex])[{{I}}])){{INDICES}}; + result_data({{I}}))){{INDICES}}; } )"; *methods += RewriteWithName(strings::StrCat(i), code, rewrites); @@ -291,6 +246,84 @@ Status GenResultMethods(const tf2xla::Config& config, return Status::OK(); } +// Generates code implementing {Arg,Result}Names(), where T is one of +// tf2xla::{Feed,Fetch}. Each feed or fetch name results in a C-style string +// literal in the array, with nullptr terminating the array. +template +string GenNameToIndexCode(const T& entries, bool generate) { + // No need for a static array if we're not supposed to generate the data. + if (!generate) { + return "{\n return nullptr;\n }"; + } + // Determine when to stop. We stop emitting string literals after the last + // non-empty name. + int end = entries.size(); + for (int i = entries.size() - 1; i >= 0; --i) { + if (!entries[i].name().empty()) { + break; + } + end = i; + } + // Emit string literals up to the last non-empty name. + string code = "{\n static const char* kNames[] = {"; + for (int i = 0; i < end; ++i) { + if (i > 0) { + code += ", "; + } + code += "\""; + code += entries[i].name(); + code += "\""; + } + if (end > 0) { + code += ", "; + } + code += "nullptr};\n return kNames;\n }"; + return code; +} + +// Converts the given `str` into a comma-separated list of per-character values. +string StringToCharList(const string& str) { + string list; + for (const char c : str) { + if (!list.empty()) { + list += ","; + } + list += strings::StrCat(static_cast(c)); + } + return list; +} + +string GenProgramShapeCode(xla::ProgramShape program_shape, bool generate) { + // No need for any static magic if we're not supposed to generate the data. + if (!generate) { + return "{\n return nullptr;\n }"; + } + // The parameter names are currently meaningless, and redundant with the rest + // of our metadata, so clear them out to avoid confusion and save space. + program_shape.clear_parameter_names(); + const string proto_str = program_shape.SerializeAsString(); + // Embed the program shape as a serialized protobuf in the header file. + // + // TODO(toddw): This strategy will likely fail for larger protobufs, depending + // on the C++ compiler that is used. Figure out another solution if necessary. + string code = R"({ + static const xla::ProgramShape* kShape = []() { + static const char kProto[] = {{{PROTO_LIST}}}; + static constexpr int kProtoSize = {{PROTO_SIZE}}; + xla::ProgramShape* shape = new xla::ProgramShape; + shape->ParseFromArray(kProto, kProtoSize); + return shape; + }(); + return kShape; + })"; + str_util::ReplaceAllPairs( + &code, { + {"{{PROTO_LIST}}", StringToCharList(proto_str)}, + {"{{PROTO_SIZE}}", strings::StrCat(proto_str.size())}, + }); + return code; +} + Status ValidateFeedFetchCppNames(const tf2xla::Config& config) { for (const tf2xla::Feed& feed : config.feed()) { if (!feed.name().empty()) { @@ -336,24 +369,6 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, const size_t temp_bytes_total = total_buffer_bytes(itemp.data(), itemp.size()); - // Create rewrite strings for the optional context arg. - string context_include; - string context_set_arg, context_set_thread_pool, context_member_var; - string run_result = "true"; - string error_msg = "tensorflow::string()"; - if (compile_result.has_context_arg) { - // NOTE: Extra spaces and newlines are used to ensure nice formatting. - context_include = - "#include " - "\"tensorflow/compiler/tf2xla/" - "xla_local_runtime_context.h\"\n"; - context_set_arg = " args_[kNumArgs-1] = &context_;\n"; - context_set_thread_pool = " context_.thread_pool = pool;\n"; - context_member_var = " tensorflow::XlaLocalRuntimeContext context_;\n"; - run_result = "!context_.error"; - error_msg = "context_.error_msg"; - } - // Create rewrite strings for namespace start and end. string ns_start; for (const string& n : opts.namespaces) { @@ -366,6 +381,19 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, ns_end += strings::StrCat("} // end namespace ", n, "\n"); } + // Generate metadata. + const string arg_names_code = + GenNameToIndexCode(config.feed(), opts.gen_name_to_index); + const string result_names_code = + GenNameToIndexCode(config.fetch(), opts.gen_name_to_index); + const string include_xla_data_proto = + opts.gen_program_shape + ? + R"(#include "tensorflow/compiler/xla/xla_data.pb.h")" + : ""; + const string program_shape_code = + GenProgramShapeCode(ps, opts.gen_program_shape); + // Use a poor-man's text templating mechanism; first populate the full header // with placeholder tokens, and then rewrite the tokens with real values. *header = @@ -380,22 +408,23 @@ Status GenerateHeader(const HeaderOpts& opts, const tf2xla::Config& config, #ifndef TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard) #define TFCOMPILE_GENERATED_{{ENTRY}}_H_ // NOLINT(build/header_guard) -{{CONTEXT_INCLUDE}} -#include "tensorflow/compiler/aot/runtime.h" -#include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/core/platform/macros.h" +{{INCLUDE_XLA_DATA_PROTO}} +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/core/platform/types.h" namespace Eigen { struct ThreadPoolDevice; } +namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void {{ENTRY}}( - void* result, xla::ExecutableRunOptions* run_options, - void** args, void** temps); + void* result, const xla::ExecutableRunOptions* run_options, + const void** args, void** temps); {{NS_START}} // {{CLASS}} represents a computation previously specified in a -// TensorFlow graph, now compiled into executable code. Usage example: +// TensorFlow graph, now compiled into executable code. This extends the generic +// XlaCompiledCpuFunction class with statically type-safe arg and result +// methods. Usage example: // // {{CLASS}} computation; // // ...set args using computation.argN methods @@ -411,9 +440,9 @@ extern "C" void {{ENTRY}}( // buffer allocation strategy. // // Under the default allocation strategy, this class is thread-compatible: -// o Calls to non-const methods require exclusive access to the object. -// o Concurrent calls to const methods are OK, if those calls are made while -// it is guaranteed that no thread may call a non-const method. +// o Calls to non-const methods require exclusive access to the object. +// o Concurrent calls to const methods are OK, if those calls are made while it +// is guaranteed that no thread may call a non-const method. // // The logical function signature is: // {{PROGRAM_SHAPE}} @@ -423,7 +452,7 @@ extern "C" void {{ENTRY}}( // arg bytes aligned: {{ARG_BYTES_ALIGNED}} // temp bytes total: {{TEMP_BYTES_TOTAL}} // temp bytes aligned: {{TEMP_BYTES_ALIGNED}} -class {{CLASS}} { +class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. static constexpr size_t kNumArgs = {{ARG_NUM}}; @@ -434,47 +463,31 @@ class {{CLASS}} { return kArgSizes; } - // AllocMode controls the buffer allocation mode. - enum class AllocMode { - // Allocate all buffers - args, results and temps. - ARGS_RESULTS_AND_TEMPS, - - // Only allocate result and temp buffers. - // Use set_argN_data to set argument buffers before Run is called. - RESULTS_AND_TEMPS_ONLY, - }; - - {{CLASS}}(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) { - if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { - alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - ArgSizes(), kNumArgs, args_, false /* annotate_initialized */); - } -{{CONTEXT_SET_ARG}} - alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - TempSizes(), kNumTemps, temps_, true /* annotate_initialized */); - } - - ~{{CLASS}}() { - tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); - tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); - } - - // Sets the thread pool to use during the Run call. - {{CLASS}}& set_thread_pool(const Eigen::ThreadPoolDevice* pool) { - run_options_.set_intra_op_thread_pool(pool); -{{CONTEXT_SET_THREAD_POOL}} - return *this; - } - - // Runs the computation, with inputs read from arg buffers, and outputs - // written to result buffers. Returns true on success and false on failure. - bool Run() { - {{ENTRY}}(temps_[kResultIndex], &run_options_, args_, temps_); - return {{RUN_RESULT}}; - } - - // Returns the error message from the previous failed Run call. - tensorflow::string error_msg() const { return {{ERROR_MSG}}; } + // Returns static data used to create an XlaCompiledCpuFunction. + static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() { + static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ + XlaCompiledCpuFunction::StaticData* data = + new XlaCompiledCpuFunction::StaticData; + data->raw_function = {{ENTRY}}; + data->arg_sizes = ArgSizes(); + data->num_args = kNumArgs; + data->temp_sizes = TempSizes(); + data->num_temps = kNumTemps; + data->result_index = kResultIndex; + data->requires_runtime_context = {{HAS_CONTEXT_ARG}}; + data->arg_names = StaticArgNames(); + data->result_names = StaticResultNames(); + data->program_shape = StaticProgramShape(); + return data; + }(); + return *kStaticData; + } + + {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} + + {{CLASS}}(const {{CLASS}}&) = delete; + {{CLASS}}& operator=(const {{CLASS}}&) = delete; // Arg methods for managing input buffers. Buffers are in row-major order. // There is a set of methods for each positional argument, with the following @@ -493,10 +506,6 @@ class {{CLASS}} { // Returns a reference to the value of type T for positional argument N, // with dim indices specifying which value. No bounds checking is performed // on dim indices. - // - // void** args() - // Returns an array of argument buffers, where args()[N] is the buffer for - // positional argument N. {{METHODS_ARG}} // Result methods for managing output buffers. Buffers are in row-major order. @@ -511,10 +520,6 @@ class {{CLASS}} { // with dim indices specifying which value. No bounds checking is performed // on dim indices. // - // void** results() - // Returns an array of result buffers, where results()[N] is the buffer for - // positional result N. - // // Unlike the arg methods, there is no set_resultN_data method. The result // buffers are managed internally, and may change after each call to Run. {{METHODS_RESULT}} @@ -522,7 +527,7 @@ class {{CLASS}} { private: // Number of result and temporary buffers for the compiled computation. static constexpr size_t kNumTemps = {{TEMP_NUM}}; - // The 0-based index of the result in the temporary buffers. + // The 0-based index of the result tuple in the temporary buffers. static constexpr size_t kResultIndex = {{RESULT_INDEX}}; // Byte size of each result / temporary buffer. There are kNumTemps entries. @@ -531,14 +536,14 @@ class {{CLASS}} { return kTempSizes; } - void* args_[kNumArgs]; - void* temps_[kNumTemps]; - void* alloc_args_ = nullptr; - void* alloc_temps_ = nullptr; - xla::ExecutableRunOptions run_options_; -{{CONTEXT_MEMBER_VAR}} + // Array of names of each positional argument, terminated by nullptr. + static const char** StaticArgNames() {{ARG_NAMES_CODE}} + + // Array of names of each positional result, terminated by nullptr. + static const char** StaticResultNames() {{RESULT_NAMES_CODE}} - TF_DISALLOW_COPY_AND_ASSIGN({{CLASS}}); + // Shape of the args and results. + static const xla::ProgramShape* StaticProgramShape() {{PROGRAM_SHAPE_CODE}} }; {{NS_END}} @@ -550,22 +555,22 @@ class {{CLASS}} { const std::vector> rewrites = { {"{{ARG_BYTES_ALIGNED}}", strings::StrCat(arg_bytes_aligned)}, {"{{ARG_BYTES_TOTAL}}", strings::StrCat(arg_bytes_total)}, + {"{{ARG_NAMES_CODE}}", arg_names_code}, {"{{ARG_NUM}}", strings::StrCat(arg_sizes.size())}, {"{{ARG_SIZES}}", str_util::Join(arg_sizes, ", ")}, {"{{CLASS}}", opts.class_name}, - {"{{CONTEXT_INCLUDE}}\n", context_include}, - {"{{CONTEXT_MEMBER_VAR}}\n", context_member_var}, - {"{{CONTEXT_SET_ARG}}\n", context_set_arg}, - {"{{CONTEXT_SET_THREAD_POOL}}\n", context_set_thread_pool}, {"{{ENTRY}}", compile_result.entry_point}, - {"{{ERROR_MSG}}", error_msg}, + {"{{HAS_CONTEXT_ARG}}", + compile_result.has_context_arg ? "true" : "false"}, + {"{{INCLUDE_XLA_DATA_PROTO}}", include_xla_data_proto}, {"{{METHODS_ARG}}\n", methods_arg}, {"{{METHODS_RESULT}}\n", methods_result}, {"{{NS_END}}\n", ns_end}, {"{{NS_START}}\n", ns_start}, {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(ps)}, + {"{{PROGRAM_SHAPE_CODE}}", program_shape_code}, {"{{RESULT_INDEX}}", strings::StrCat(result_index)}, - {"{{RUN_RESULT}}", run_result}, + {"{{RESULT_NAMES_CODE}}", result_names_code}, {"{{TEMP_BYTES_ALIGNED}}", strings::StrCat(temp_bytes_aligned)}, {"{{TEMP_BYTES_TOTAL}}", strings::StrCat(temp_bytes_total)}, {"{{TEMP_NUM}}", strings::StrCat(temp_sizes.size())}, diff --git a/tensorflow/compiler/aot/codegen.h b/tensorflow/compiler/aot/codegen.h index 740edd1e83410ad0d3b854adbec20fb1cab88440..76dd0cc3cf9470a1beb2a4725724f640aecfec7f 100644 --- a/tensorflow/compiler/aot/codegen.h +++ b/tensorflow/compiler/aot/codegen.h @@ -34,6 +34,12 @@ struct HeaderOpts { // Namespaces specifies a list of C++ namespaces to add to the generated // header. If empty, all symbols will be in the global namespace. std::vector namespaces; + + // If true, generate name-to-index data for Lookup{Arg,Result}Index methods. + bool gen_name_to_index = false; + + // If true, generate program shape data for the ProgramShape method. + bool gen_program_shape = false; }; // GenerateHeader uses the meta-information from compile_result to generate a diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 98cbd67e53432e7c131c2daa27e86e3a613161a1..0f6114666fcc89c631434527d2ae8c92c039ffea 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -127,6 +127,8 @@ TEST(GenerateHeader, Golden) { HeaderOpts opts; opts.class_name = "MyClass"; opts.namespaces = {"foo", "bar"}; + opts.gen_name_to_index = true; + opts.gen_program_shape = true; tf2xla::Config config; tf2xla::Feed* feed = config.add_feed(); feed->mutable_id()->set_node_name("feed0"); @@ -145,7 +147,8 @@ TEST(GenerateHeader, Golden) { xla::ShapeUtil::MakeShape(xla::S64, {3, 4}), xla::ShapeUtil::MakeOpaqueShape(), }, - xla::ShapeUtil::MakeShape(xla::U32, {5, 6})); + xla::ShapeUtil::MakeTupleShape( + {xla::ShapeUtil::MakeShape(xla::U32, {5, 6})})); compile_result.has_context_arg = true; compile_result.entry_point = "entry_point"; compile_result.pointer_size = 8; diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden index 01963c6df4682ec8c23a93201d7fbbab63558060..65f342ce27ef09092f252f791973f245a8cdd6f3 100644 --- a/tensorflow/compiler/aot/codegen_test_h.golden +++ b/tensorflow/compiler/aot/codegen_test_h.golden @@ -9,24 +9,25 @@ #ifndef TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard) #define TFCOMPILE_GENERATED_entry_point_H_ // NOLINT(build/header_guard) -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" -#include "tensorflow/compiler/aot/runtime.h" -#include "tensorflow/compiler/xla/executable_run_options.h" -#include "tensorflow/core/platform/macros.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" #include "tensorflow/core/platform/types.h" namespace Eigen { struct ThreadPoolDevice; } +namespace xla { class ExecutableRunOptions; } // (Implementation detail) Entry point to the function in the object file. extern "C" void entry_point( - void* result, xla::ExecutableRunOptions* run_options, - void** args, void** temps); + void* result, const xla::ExecutableRunOptions* run_options, + const void** args, void** temps); namespace foo { namespace bar { // MyClass represents a computation previously specified in a -// TensorFlow graph, now compiled into executable code. Usage example: +// TensorFlow graph, now compiled into executable code. This extends the generic +// XlaCompiledCpuFunction class with statically type-safe arg and result +// methods. Usage example: // // MyClass computation; // // ...set args using computation.argN methods @@ -42,19 +43,19 @@ namespace bar { // buffer allocation strategy. // // Under the default allocation strategy, this class is thread-compatible: -// o Calls to non-const methods require exclusive access to the object. -// o Concurrent calls to const methods are OK, if those calls are made while -// it is guaranteed that no thread may call a non-const method. +// o Calls to non-const methods require exclusive access to the object. +// o Concurrent calls to const methods are OK, if those calls are made while it +// is guaranteed that no thread may call a non-const method. // // The logical function signature is: -// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> u32[5,6] +// ((unknown): f32[1,2], (unknown): s64[3,4], (unknown): opaque[]) -> (u32[5,6]) // // Memory stats: // arg bytes total: 104 // arg bytes aligned: 128 // temp bytes total: 126 // temp bytes aligned: 224 -class MyClass { +class MyClass : public tensorflow::XlaCompiledCpuFunction { public: // Number of input arguments for the compiled computation. static constexpr size_t kNumArgs = 3; @@ -65,47 +66,31 @@ class MyClass { return kArgSizes; } - // AllocMode controls the buffer allocation mode. - enum class AllocMode { - // Allocate all buffers - args, results and temps. - ARGS_RESULTS_AND_TEMPS, - - // Only allocate result and temp buffers. - // Use set_argN_data to set argument buffers before Run is called. - RESULTS_AND_TEMPS_ONLY, - }; - - MyClass(AllocMode mode = AllocMode::ARGS_RESULTS_AND_TEMPS) { - if (mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { - alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - ArgSizes(), kNumArgs, args_, false /* annotate_initialized */); - } - args_[kNumArgs-1] = &context_; - alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( - TempSizes(), kNumTemps, temps_, true /* annotate_initialized */); - } - - ~MyClass() { - tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); - tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); - } - - // Sets the thread pool to use during the Run call. - MyClass& set_thread_pool(const Eigen::ThreadPoolDevice* pool) { - run_options_.set_intra_op_thread_pool(pool); - context_.thread_pool = pool; - return *this; - } - - // Runs the computation, with inputs read from arg buffers, and outputs - // written to result buffers. Returns true on success and false on failure. - bool Run() { - entry_point(temps_[kResultIndex], &run_options_, args_, temps_); - return !context_.error; - } - - // Returns the error message from the previous failed Run call. - tensorflow::string error_msg() const { return context_.error_msg; } + // Returns static data used to create an XlaCompiledCpuFunction. + static const tensorflow::XlaCompiledCpuFunction::StaticData& StaticData() { + static XlaCompiledCpuFunction::StaticData* kStaticData = [](){ + XlaCompiledCpuFunction::StaticData* data = + new XlaCompiledCpuFunction::StaticData; + data->raw_function = entry_point; + data->arg_sizes = ArgSizes(); + data->num_args = kNumArgs; + data->temp_sizes = TempSizes(); + data->num_temps = kNumTemps; + data->result_index = kResultIndex; + data->requires_runtime_context = true; + data->arg_names = StaticArgNames(); + data->result_names = StaticResultNames(); + data->program_shape = StaticProgramShape(); + return data; + }(); + return *kStaticData; + } + + MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS) + : XlaCompiledCpuFunction(StaticData(), alloc_mode) {} + + MyClass(const MyClass&) = delete; + MyClass& operator=(const MyClass&) = delete; // Arg methods for managing input buffers. Buffers are in row-major order. // There is a set of methods for each positional argument, with the following @@ -124,66 +109,59 @@ class MyClass { // Returns a reference to the value of type T for positional argument N, // with dim indices specifying which value. No bounds checking is performed // on dim indices. - // - // void** args() - // Returns an array of argument buffers, where args()[N] is the buffer for - // positional argument N. - - void** args() { return args_; } - const void *const *args() const { return args_; } void set_arg0_data(void* data) { - args_[0] = data; + set_arg_data(0, data); } float* arg0_data() { - return static_cast(args_[0]); + return static_cast(arg_data(0)); } float& arg0(size_t dim0, size_t dim1) { return (*static_cast( - args_[0]))[dim0][dim1]; + arg_data(0)))[dim0][dim1]; } const float* arg0_data() const { - return static_cast(args_[0]); + return static_cast(arg_data(0)); } const float& arg0(size_t dim0, size_t dim1) const { return (*static_cast( - args_[0]))[dim0][dim1]; + arg_data(0)))[dim0][dim1]; } void set_arg_myfeed_data(void* data) { - args_[0] = data; + set_arg_data(0, data); } float* arg_myfeed_data() { - return static_cast(args_[0]); + return static_cast(arg_data(0)); } float& arg_myfeed(size_t dim0, size_t dim1) { return (*static_cast( - args_[0]))[dim0][dim1]; + arg_data(0)))[dim0][dim1]; } const float* arg_myfeed_data() const { - return static_cast(args_[0]); + return static_cast(arg_data(0)); } const float& arg_myfeed(size_t dim0, size_t dim1) const { return (*static_cast( - args_[0]))[dim0][dim1]; + arg_data(0)))[dim0][dim1]; } void set_arg1_data(void* data) { - args_[1] = data; + set_arg_data(1, data); } tensorflow::int64* arg1_data() { - return static_cast(args_[1]); + return static_cast(arg_data(1)); } tensorflow::int64& arg1(size_t dim0, size_t dim1) { return (*static_cast( - args_[1]))[dim0][dim1]; + arg_data(1)))[dim0][dim1]; } const tensorflow::int64* arg1_data() const { - return static_cast(args_[1]); + return static_cast(arg_data(1)); } const tensorflow::int64& arg1(size_t dim0, size_t dim1) const { return (*static_cast( - args_[1]))[dim0][dim1]; + arg_data(1)))[dim0][dim1]; } // Result methods for managing output buffers. Buffers are in row-major order. @@ -198,50 +176,43 @@ class MyClass { // with dim indices specifying which value. No bounds checking is performed // on dim indices. // - // void** results() - // Returns an array of result buffers, where results()[N] is the buffer for - // positional result N. - // // Unlike the arg methods, there is no set_resultN_data method. The result // buffers are managed internally, and may change after each call to Run. - void** results() { return temps_ + kResultIndex; } - const void *const *results() const { return temps_ + kResultIndex; } - tensorflow::uint32* result0_data() { - return static_cast(temps_[kResultIndex]); + return static_cast(result_data(0)); } tensorflow::uint32& result0(size_t dim0, size_t dim1) { return (*static_cast( - temps_[kResultIndex]))[dim0][dim1]; + result_data(0)))[dim0][dim1]; } const tensorflow::uint32* result0_data() const { - return static_cast(temps_[kResultIndex]); + return static_cast(result_data(0)); } const tensorflow::uint32& result0(size_t dim0, size_t dim1) const { return (*static_cast( - temps_[kResultIndex]))[dim0][dim1]; + result_data(0)))[dim0][dim1]; } tensorflow::uint32* result_myfetch_data() { - return static_cast(temps_[kResultIndex]); + return static_cast(result_data(0)); } tensorflow::uint32& result_myfetch(size_t dim0, size_t dim1) { return (*static_cast( - temps_[kResultIndex]))[dim0][dim1]; + result_data(0)))[dim0][dim1]; } const tensorflow::uint32* result_myfetch_data() const { - return static_cast(temps_[kResultIndex]); + return static_cast(result_data(0)); } const tensorflow::uint32& result_myfetch(size_t dim0, size_t dim1) const { return (*static_cast( - temps_[kResultIndex]))[dim0][dim1]; + result_data(0)))[dim0][dim1]; } private: // Number of result and temporary buffers for the compiled computation. static constexpr size_t kNumTemps = 6; - // The 0-based index of the result in the temporary buffers. + // The 0-based index of the result tuple in the temporary buffers. static constexpr size_t kResultIndex = 5; // Byte size of each result / temporary buffer. There are kNumTemps entries. @@ -250,14 +221,29 @@ class MyClass { return kTempSizes; } - void* args_[kNumArgs]; - void* temps_[kNumTemps]; - void* alloc_args_ = nullptr; - void* alloc_temps_ = nullptr; - xla::ExecutableRunOptions run_options_; - tensorflow::XlaLocalRuntimeContext context_; + // Array of names of each positional argument, terminated by nullptr. + static const char** StaticArgNames() { + static const char* kNames[] = {"myfeed", nullptr}; + return kNames; + } + + // Array of names of each positional result, terminated by nullptr. + static const char** StaticResultNames() { + static const char* kNames[] = {"myfetch", nullptr}; + return kNames; + } - TF_DISALLOW_COPY_AND_ASSIGN(MyClass); + // Shape of the args and results. + static const xla::ProgramShape* StaticProgramShape() { + static const xla::ProgramShape* kShape = []() { + static const char kProto[] = {10,12,16,11,26,2,1,2,42,4,10,2,1,0,10,12,16,5,26,2,3,4,42,4,10,2,1,0,10,2,16,14,18,16,16,13,34,12,16,8,26,2,5,6,42,4,10,2,1,0}; + static constexpr int kProtoSize = 50; + xla::ProgramShape* shape = new xla::ProgramShape; + shape->ParseFromArray(kProto, kProtoSize); + return shape; + }(); + return kShape; + } }; } // end namespace bar diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index eac8da0ab1b05e7d5cc8d27a1e1ffecc85515cdb..2b8cc6024cb85e4f6269313927ff66d1d9a1cf79 100644 --- a/tensorflow/compiler/aot/compile.cc +++ b/tensorflow/compiler/aot/compile.cc @@ -97,11 +97,15 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config, TF_RETURN_IF_ERROR(ConvertGraphDefToXla(graph_def, config, client, &computation, &compile_result->has_context_arg)); - if (!flags.debug_dir.empty()) { + if (!flags.out_session_module.empty()) { TF_ASSIGN_OR_RETURN(std::unique_ptr module, computation.Snapshot()); - string file = io::JoinPath(flags.debug_dir, "tfcompile_xla_module.pb"); - TF_RETURN_IF_ERROR(WriteBinaryProto(Env::Default(), file, *module)); + // Serialize the SessionModule deterministically so that all the outputs of + // a tf_library genrule are deterministic. + string proto; + TF_RET_CHECK(SerializeToStringDeterministic(*module, &proto)); + TF_RETURN_IF_ERROR( + WriteStringToFile(Env::Default(), flags.out_session_module, proto)); } xla::cpu::CpuAotCompilationOptions aot_opts( flags.target_triple, flags.target_cpu, flags.target_features, diff --git a/tensorflow/compiler/aot/flags.cc b/tensorflow/compiler/aot/flags.cc index 4e3998b68293aa47f028c745cea36a8c533d237d..7c2f27e550d44c2487f91acf1029c962ac3f5d01 100644 --- a/tensorflow/compiler/aot/flags.cc +++ b/tensorflow/compiler/aot/flags.cc @@ -33,9 +33,6 @@ void AppendMainFlags(std::vector* flag_list, MainFlags* flags) { "fetch nodes will be dumped to stdout in a comma-separated list. " "Typically used to format arguments for other tools, e.g. " "freeze_graph."}, - {"debug_dir", &flags->debug_dir, - "Specifies a directory to dump debugging information, including " - "rewritten graphs and the XLA HLO module."}, // Flags controlling the XLA ahead-of-time compilation, that correspond to // the fields of xla::cpu::CpuAotCompilationOptions. // @@ -64,6 +61,12 @@ void AppendMainFlags(std::vector* flag_list, MainFlags* flags) { "namespaces are given, within the global namespace."}, {"out_object", &flags->out_object, "Output object file name."}, {"out_header", &flags->out_header, "Output header file name."}, + {"out_session_module", &flags->out_session_module, + "Output session module proto."}, + {"gen_name_to_index", &flags->gen_name_to_index, + "Generate name-to-index data for Lookup{Arg,Result}Index methods."}, + {"gen_program_shape", &flags->gen_program_shape, + "Generate program shape data for the ProgramShape method."}, }; flag_list->insert(flag_list->end(), tmp.begin(), tmp.end()); } diff --git a/tensorflow/compiler/aot/flags.h b/tensorflow/compiler/aot/flags.h index e11a0173fa0035237915be80cf66b2bfca0f9b12..3519659e3af7cd345f30080a07ce91fb858623fb 100644 --- a/tensorflow/compiler/aot/flags.h +++ b/tensorflow/compiler/aot/flags.h @@ -29,7 +29,6 @@ struct MainFlags { string graph; string config; bool dump_fetch_nodes = false; - string debug_dir; string target_triple; string target_cpu; string target_features; @@ -37,6 +36,11 @@ struct MainFlags { string cpp_class; string out_object; string out_header; + string out_session_module; + + // C++ codegen options + bool gen_name_to_index = false; + bool gen_program_shape = false; }; // Appends to flag_list a tensorflow::Flag for each field in MainFlags. diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index b0b1213a847c586259e3b8f1d175f089c3961dfd..7dfd49cc3b92f83fd64ca62bd2230938ce2d0a65 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -132,6 +132,7 @@ tf_library( cpp_class = "MatMulAndAddComp", graph = "test_graph_tfmatmulandadd.pb", tags = ["manual"], + tfcompile_flags = "--gen_name_to_index --gen_program_shape", ) tf_library( @@ -156,6 +157,8 @@ tf_cc_test( ":test_graph_tfmatmul", ":test_graph_tfmatmulandadd", ":test_graph_tfsplits", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:test", "//tensorflow/core:test_main", "//third_party/eigen3", diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 07562e59c8dac942f41af69c289c9f29a9767a6a..6b037f276ad1d6771b904bb970f45f32ae9531b8 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -25,6 +25,8 @@ limitations under the License. #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/platform/test.h" namespace tensorflow { @@ -178,16 +180,6 @@ TEST(TFCompileTest, Gather) { } EXPECT_EQ(gather_const.result0_data(), gather.results()[0]); } - - // Bad indices returns an error. - { - const float params[4] = {1, 2, 3, 4}; - std::copy(params + 0, params + 4, gather.arg0_data()); - const int32 indices[2] = {1, 4}; - std::copy(indices + 0, indices + 2, gather.arg1_data()); - EXPECT_FALSE(gather.Run()); - EXPECT_EQ(gather.error_msg(), "Invalid index for gather"); - } } TEST(TFCompileTest, MatMul2) { @@ -421,6 +413,59 @@ TEST(TFCompileTest, Splits) { EXPECT_NEAR(expected[3], fn.result0(1, 1), 1e4); } +TEST(TFCompileTest, LookupNameIndex) { + // add doesn't have any names defined in its config. + AddComp add; + EXPECT_FALSE(add.HasNameIndices()); + + // muladd has names defined for all feeds and fetches. + MatMulAndAddComp muladd; + EXPECT_TRUE(muladd.HasNameIndices()); + + EXPECT_EQ(muladd.LookupArgIndex("x"), 0); + EXPECT_EQ(muladd.LookupArgIndex("y"), 1); + EXPECT_EQ(muladd.LookupArgIndex(""), -1); + EXPECT_EQ(muladd.LookupArgIndex("x_hold"), -1); + EXPECT_EQ(muladd.LookupArgIndex("y_hold"), -1); + EXPECT_EQ(muladd.LookupArgIndex("x_y_prod"), -1); + EXPECT_EQ(muladd.LookupArgIndex("x_y_sum"), -1); + + EXPECT_EQ(muladd.LookupResultIndex("x_y_prod"), 0); + EXPECT_EQ(muladd.LookupResultIndex("x_y_sum"), 1); + EXPECT_EQ(muladd.LookupResultIndex(""), -1); + EXPECT_EQ(muladd.LookupResultIndex("x"), -1); + EXPECT_EQ(muladd.LookupResultIndex("y"), -1); + EXPECT_EQ(muladd.LookupResultIndex("x_hold"), -1); + EXPECT_EQ(muladd.LookupResultIndex("y_hold"), -1); +} + +TEST(TFCompileTest, ProgramShape) { + using xla::ShapeUtil; + const xla::Shape f32_2x2 = ShapeUtil::MakeShape(xla::F32, {2, 2}); + + // add doesn't have the program shape defined. + AddComp add; + ASSERT_TRUE(add.ProgramShape() == nullptr); + + // muladd has the program shape defined. + MatMulAndAddComp muladd; + const xla::ProgramShape* muladd_shape = muladd.ProgramShape(); + ASSERT_TRUE(muladd_shape != nullptr); + ASSERT_EQ(muladd_shape->parameters_size(), 2); + EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(0), f32_2x2)); + EXPECT_TRUE(ShapeUtil::Compatible(muladd_shape->parameters(1), f32_2x2)); + + const xla::Shape& muladd_result = muladd_shape->result(); + ASSERT_EQ(muladd_result.element_type(), xla::TUPLE); + ASSERT_EQ(ShapeUtil::TupleElementCount(muladd_result), 2); + const xla::Shape& muladd_result0 = + ShapeUtil::GetTupleElementShape(muladd_result, 0); + EXPECT_TRUE(ShapeUtil::Compatible(muladd_result0, f32_2x2)); + const xla::Shape& muladd_result1 = + ShapeUtil::GetTupleElementShape(muladd_result, 1); + EXPECT_TRUE(ShapeUtil::Compatible(muladd_result1, f32_2x2)); +} + } // namespace } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index 608d461a4cebba92944b8c56fd295394ba6e59b0..363d6925a14dfab8b79617449a73727ab55c4527 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -4,7 +4,7 @@ To use from your BUILD file, add the following line to load the macro: -load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") +load("@org_tensorflow//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") Then call the macro like this: @@ -16,14 +16,14 @@ tf_library( ) """ -load("//tensorflow:tensorflow.bzl", "if_android", "tf_copts") +load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_android", "tf_copts") def tf_library(name, graph, config, freeze_checkpoint=None, freeze_saver=None, cpp_class=None, gen_test=True, gen_benchmark=True, visibility=None, testonly=None, tfcompile_flags=None, - tfcompile_tool="//tensorflow/compiler/aot:tfcompile", + tfcompile_tool="@org_tensorflow//tensorflow/compiler/aot:tfcompile", include_standard_runtime_deps=True, deps=None, tags=None): """Runs tfcompile to compile a TensorFlow graph into executable code. @@ -119,9 +119,9 @@ def tf_library(name, graph, config, out_nodes_file, ] + freeze_saver_srcs, outs=[freeze_file], - cmd=("$(location //tensorflow/python/tools:freeze_graph)" + + cmd=("$(location @org_tensorflow//tensorflow/python/tools:freeze_graph)" + freeze_args), - tools=["//tensorflow/python/tools:freeze_graph"], + tools=["@org_tensorflow//tensorflow/python/tools:freeze_graph"], tags=tags, ) tfcompile_graph = freeze_file @@ -165,8 +165,38 @@ def tf_library(name, graph, config, tags=tags, ) + # Rule that runs tfcompile to produce the SessionModule proto, useful for + # debugging. TODO(b/64813587): Once the SessionModule proto is + # deterministic, move this into the main rule above. + session_module_pb = name + "_session_module.pb" + native.genrule( + name=(name + "_session_module"), + srcs=[ + tfcompile_graph, + config, + ], + outs=[ + session_module_pb, + ], + cmd=("$(location " + tfcompile_tool + ")" + + " --graph=$(location " + tfcompile_graph + ")" + + " --config=$(location " + config + ")" + + " --entry_point=" + ep + + " --cpp_class=" + cpp_class + + " --target_triple=" + target_llvm_triple() + + " --out_session_module=$(@D)/" + session_module_pb + + " " + (tfcompile_flags or "")), + tools=[tfcompile_tool], + visibility=visibility, + testonly=testonly, + local=1, + tags=tags, + ) + # The cc_library rule packaging up the header and object file, and needed # kernel implementations. + need_xla_data_proto = (tfcompile_flags and + tfcompile_flags.find("--gen_program_shape") != -1) native.cc_library( name=name, srcs=[object_file], @@ -177,23 +207,22 @@ def tf_library(name, graph, config, # These deps are required by all tf_library targets even if # include_standard_runtime_deps is False. Without them, the # generated code will fail to compile. - "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", - "//tensorflow/compiler/xla:executable_run_options", - "//tensorflow/core:framework_lite", - ] + (include_standard_runtime_deps and [ + "@org_tensorflow//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", + "@org_tensorflow//tensorflow/core:framework_lite", + ] + (need_xla_data_proto and [ + # If we're generating the program shape, we must depend on the proto. + "@org_tensorflow//tensorflow/compiler/xla:xla_data_proto", + ] or []) + (include_standard_runtime_deps and [ # TODO(cwhipkey): only depend on kernel code that the model actually needed. - "//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int32", - "//tensorflow/compiler/tf2xla/kernels:gather_op_kernel_float_int64", - "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", - "//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", - "//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx", - "//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon", - "//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1", - "//tensorflow/compiler/xla/service/cpu:runtime_conv2d", - "//tensorflow/compiler/xla/service/cpu:runtime_matmul", - "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", - "//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", + "@org_tensorflow//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_1d", + "@org_tensorflow//tensorflow/compiler/tf2xla/kernels:index_ops_kernel_argmax_float_2d", + "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_avx", + "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_neon", + "@org_tensorflow//tensorflow/compiler/xla/service/cpu:cpu_runtime_sse4_1", + "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_conv2d", + "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_matmul", + "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_conv2d", + "@org_tensorflow//tensorflow/compiler/xla/service/cpu:runtime_single_threaded_matmul", "//third_party/eigen3", ] or []) + (deps or []), tags=tags, @@ -219,12 +248,12 @@ def tf_library(name, graph, config, name=("gen_" + test_name), testonly=1, srcs=[ - "//tensorflow/compiler/aot:test.cc", + "@org_tensorflow//tensorflow/compiler/aot:test.cc", header_file, ], outs=[test_file], cmd=("sed " + sed_replace + - " $(location //tensorflow/compiler/aot:test.cc) " + + " $(location @org_tensorflow//tensorflow/compiler/aot:test.cc) " + "> $(OUTS)"), tags=tags, ) @@ -235,13 +264,13 @@ def tf_library(name, graph, config, srcs=[test_file], deps=[ ":" + name, - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", - "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/aot:tf_library_test_main", - "//tensorflow/compiler/xla:executable_run_options", + "@org_tensorflow//tensorflow/compiler/tf2xla:xla_local_runtime_context", + "@org_tensorflow//tensorflow/compiler/aot:runtime", + "@org_tensorflow//tensorflow/compiler/aot:tf_library_test_main", + "@org_tensorflow//tensorflow/compiler/xla:executable_run_options", "//third_party/eigen3", - "//tensorflow/core:lib", - "//tensorflow/core:test", + "@org_tensorflow//tensorflow/core:lib", + "@org_tensorflow//tensorflow/core:test", ], tags=tags, ) @@ -249,7 +278,7 @@ def tf_library(name, graph, config, if gen_benchmark: benchmark_name = name + "_benchmark" benchmark_file = benchmark_name + ".cc" - benchmark_main = ("//tensorflow/compiler/aot:" + + benchmark_main = ("@org_tensorflow//tensorflow/compiler/aot:" + "benchmark_main.template") # Rule to rewrite benchmark.cc to produce the benchmark_file. @@ -281,28 +310,27 @@ def tf_library(name, graph, config, linkopts = if_android(["-pie", "-s"]), deps=[ ":" + name, - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", - "//tensorflow/compiler/aot:benchmark", - "//tensorflow/compiler/aot:runtime", - "//tensorflow/compiler/xla:executable_run_options", + "@org_tensorflow//tensorflow/compiler/tf2xla:xla_local_runtime_context", + "@org_tensorflow//tensorflow/compiler/aot:benchmark", + "@org_tensorflow//tensorflow/compiler/aot:runtime", + "@org_tensorflow//tensorflow/compiler/xla:executable_run_options", "//third_party/eigen3", ] + if_android([ - "//tensorflow/compiler/aot:benchmark_extra_android", + "@org_tensorflow//tensorflow/compiler/aot:benchmark_extra_android", ]), tags=tags, ) - def target_llvm_triple(): """Returns the target LLVM triple to be used for compiling the target.""" # TODO(toddw): Add target_triple for other targets. For details see: # http://llvm.org/docs/doxygen/html/Triple_8h_source.html return select({ - "//tensorflow:android_armeabi": "armv5-none-android", - "//tensorflow:android_arm": "armv7-none-android", - "//tensorflow:android_arm64": "aarch64-none-android", - "//tensorflow:android_x86": "i686-none-android", - "//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", - "//tensorflow:darwin": "x86_64-none-darwin", + "@org_tensorflow//tensorflow:android_armeabi": "armv5-none-android", + "@org_tensorflow//tensorflow:android_arm": "armv7-none-android", + "@org_tensorflow//tensorflow:android_arm64": "aarch64-none-android", + "@org_tensorflow//tensorflow:android_x86": "i686-none-android", + "@org_tensorflow//tensorflow:linux_ppc64le": "ppc64le-ibm-linux-gnu", + "@org_tensorflow//tensorflow:darwin": "x86_64-none-darwin", "//conditions:default": "x86_64-pc-linux", }) diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index cc499c3284689182638665e0884f6377d8d9f3ee..6ab3d474187c7df2131f94c9f42f0d0f2f9d99d7 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -94,6 +94,8 @@ Status Main(const MainFlags& flags) { TF_RETURN_IF_ERROR(WriteStringToFile(env, flags.out_object, StringPiece(obj.data(), obj.size()))); HeaderOpts header_opts; + header_opts.gen_name_to_index = flags.gen_name_to_index; + header_opts.gen_program_shape = flags.gen_program_shape; if (flags.cpp_class.empty()) { return errors::InvalidArgument("Must specify --cpp_class"); } diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index e366db248a5d57d5d2666b82e36c3e28f6df42c0..bf7d9cf14d10f41aa48ea594a8d63db97b9973e1 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -33,6 +33,7 @@ cc_library( deps = [ ":xla_cpu_device", ":xla_cpu_jit", + "//tensorflow/compiler/plugin", ] + if_cuda_is_configured([ ":xla_gpu_device", ":xla_gpu_jit", @@ -153,7 +154,7 @@ cc_library( "//tensorflow/core:lib_internal", "//tensorflow/core:protos_all_cc", "//tensorflow/core:stream_executor_no_cuda", - "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core/kernels:cast_op", "//tensorflow/core/kernels:constant_op", "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:identity_op", diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD index b61b3b9845e8c389e8df4d9b8f43df1f8857026a..459a582e157f5ddc63997ca93e7c0294293517d3 100644 --- a/tensorflow/compiler/jit/kernels/BUILD +++ b/tensorflow/compiler/jit/kernels/BUILD @@ -24,7 +24,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", - "//tensorflow/core:tensorflow_opensource", "//tensorflow/core/kernels:variable_ops", ], alwayslink = 1, diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc index ebfc4c07271a434a240aaccc8b809bd3c520c4af..27c5da08c112664d361b5f969d100eed7b9df65c 100644 --- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc +++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc @@ -46,7 +46,7 @@ namespace tensorflow { // see comment on `AllowsAsynchronousDeallocation()`. class XlaAllocator : public xla::DeviceMemoryAllocator { public: - XlaAllocator(const gpu::Platform* platform, OpKernelContext* op_context); + XlaAllocator(gpu::Platform* platform, OpKernelContext* op_context); ~XlaAllocator() override; xla::StatusOr Allocate(int device_ordinal, uint64 size, bool retry_on_failure) override; @@ -80,8 +80,7 @@ class XlaAllocator : public xla::DeviceMemoryAllocator { std::unordered_map tensors_; }; -XlaAllocator::XlaAllocator(const gpu::Platform* platform, - OpKernelContext* op_context) +XlaAllocator::XlaAllocator(gpu::Platform* platform, OpKernelContext* op_context) : xla::DeviceMemoryAllocator(platform), op_context_(op_context) {} XlaAllocator::~XlaAllocator() = default; @@ -111,7 +110,7 @@ xla::StatusOr XlaAllocator::Allocate( Status XlaAllocator::RegisterArgument(const Tensor* t) { void* data = - reinterpret_cast(const_cast(t->tensor_data().data())); + reinterpret_cast(const_cast(t->tensor_data().data())); TF_RET_CHECK(data != nullptr); tensors_[data] = *t; return Status::OK(); @@ -331,7 +330,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) { return; } - output = std::move(run_result.ValueOrDie()); + output = run_result.ConsumeValueOrDie()->release(); auto elapsed = env->NowMicros() - start_time; VLOG(2) << "Elapsed time: " << elapsed << "us"; diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 0dd42f251af6e2ddfc4f162528990c2975ae5ee3..78d0aa86a8fae9a0c6035bdc579ef800337df917 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -232,10 +232,17 @@ string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src, return ""; } + auto node_name = [&cycles, &graph](int node_id) { + auto* node = graph.FindNodeId(node_id); + if (node == nullptr) { + return string("(null)"); + } + return node->name(); + }; + string description; - strings::StrAppend(&description, "Edge from ", graph.FindNodeId(src)->name(), - " to ", graph.FindNodeId(dst)->name(), - " would create a cycle.\n"); + strings::StrAppend(&description, "Edge from ", node_name(src), " to ", + node_name(dst), " would create a cycle.\n"); path.resize(path_size); for (int32 node_id : path) { string ascii_art; @@ -246,8 +253,7 @@ string DescribeCycle(const GraphCycles& cycles, const Graph& graph, int src, } else { ascii_art = "+-- "; } - strings::StrAppend(&description, ascii_art, - graph.FindNodeId(node_id)->name(), "\n"); + strings::StrAppend(&description, ascii_art, node_name(node_id), "\n"); } return description; } @@ -554,6 +560,7 @@ Status MarkForCompilationPass::RunImpl( name = strings::StrCat("cluster_", cluster_sequence_num++); } n->AddAttr(kXlaClusterAttr, name); + VLOG(3) << "Assigning node " << n->name() << " to cluster " << name; } } diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 579ce415c5c3c4951be1596a37d47b7930bcf4fb..b3d258aea177fbefa4bae51d8156da2ff86c9032 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -144,8 +144,8 @@ TEST(XlaCompilationTest, UnsupportedTypes) { Node* a = ops::SourceOp( "Const", builder.opts() .WithName("A") - .WithAttr("dtype", DT_COMPLEX64) - .WithAttr("value", Tensor(DT_COMPLEX64, TensorShape()))); + .WithAttr("dtype", DT_COMPLEX128) + .WithAttr("value", Tensor(DT_COMPLEX128, TensorShape()))); Node* b = ops::UnaryOp("Neg", a, builder.opts().WithName("B")); ops::BinaryOp("MatMul", a, b, builder.opts().WithName("C")); TF_EXPECT_OK(builder.ToGraph(graph.get())); diff --git a/tensorflow/compiler/jit/xla_compilation_cache.cc b/tensorflow/compiler/jit/xla_compilation_cache.cc index b39199e16306d14f385da4730a9c4b53e163623a..23368b6c76a363882956577a20c1bd041211d234 100644 --- a/tensorflow/compiler/jit/xla_compilation_cache.cc +++ b/tensorflow/compiler/jit/xla_compilation_cache.cc @@ -312,7 +312,6 @@ Status XlaCompilationCache::Compile( *compilation_result = &entry->compilation_result; if (entry->compilation_status.ok() && executable) { if (entry->executable == nullptr) { - XlaCompiler compiler(options); entry->compilation_status = BuildExecutable( options, entry->compilation_result, &entry->executable); } diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 57b9d6b56bca23e94dc172dce2412ed151643318..e238252751e677eb947f6df03e3b2f2e948ffe19 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -39,9 +39,9 @@ Status XlaCpuDeviceFactory::CreateDevices(const SessionOptions& options, (void)registrations; std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create("Host", DEVICE_XLA_CPU, 0, - DEVICE_CPU_XLA_JIT, options, name_prefix, - &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create( + "Host", DEVICE_XLA_CPU, 0, DEVICE_CPU_XLA_JIT, options, name_prefix, + /*register_device_for_compilation=*/true, &device)); devices->push_back(device.release()); return Status::OK(); } @@ -50,8 +50,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaCpuTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; +constexpr std::array kAllXlaCpuTypes = { + {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes); diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc index 02cc6654c870be980b902c869062c60244b97025..d4d8fe1c1d575b4e35d624621cc709e3a16569d5 100644 --- a/tensorflow/compiler/jit/xla_device.cc +++ b/tensorflow/compiler/jit/xla_device.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/dma_helper.h" #include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/renamed_device.h" #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/device_base.h" #include "tensorflow/core/framework/function.h" @@ -107,18 +108,21 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator( /* static */ Status XlaDevice::Create( const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, - const string& name_prefix, std::unique_ptr* device) { + const string& name_prefix, bool register_device_for_compilation, + std::unique_ptr* device) { VLOG(1) << "XlaDevice::Create " << platform_name << " " << device_name << ":" << device_ordinal; - // These are no-ops if they have already been done previously for - // this device_name/compilation_device_name pair. - XlaOpRegistry::DeviceRegistration registration; - registration.compilation_device_name = jit_device_name; - registration.requires_compilation = true; - registration.enable_jit_by_default = false; - registration.compile_resource_ops = true; - XlaOpRegistry::RegisterCompilationDevice(device_name, registration); + if (register_device_for_compilation) { + // These are no-ops if they have already been done previously for + // this device_name/compilation_device_name pair. + XlaOpRegistry::DeviceRegistration registration; + registration.compilation_device_name = jit_device_name; + registration.requires_compilation = true; + registration.enable_jit_by_default = false; + registration.compile_resource_ops = true; + XlaOpRegistry::RegisterCompilationDevice(device_name, registration); + } auto platform = se::MultiPlatformManager::PlatformWithName(platform_name); if (!platform.ok()) { @@ -158,7 +162,8 @@ const DeviceType& XlaDevice::Metadata::jit_device_type() const { /* static */ Status XlaDevice::GetMetadata(OpKernelContext* ctx, const Metadata** metadata) { - XlaDevice* xla_device = dynamic_cast(ctx->device()); + XlaDevice* xla_device = + dynamic_cast(ctx->device()->UnderlyingDevice()); if (xla_device == nullptr) { return errors::Internal( "Cannot get XLA metadata from non-XLA device \"", ctx->device()->name(), @@ -236,7 +241,8 @@ void XlaDevice::Compute(OpKernel* op_kernel, OpKernelContext* context) { // When TraceMe profiling is off (which is the default), the // following TraceMe constructor is simply a conditional test of // false value. Measurements show that its overhead is negligible. - port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string()); + port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(), + op_kernel->IsExpensive()); op_kernel->Compute(context); } @@ -244,7 +250,8 @@ void XlaDevice::ComputeAsync(AsyncOpKernel* op_kernel, OpKernelContext* context, AsyncOpKernel::DoneCallback done) { VLOG(1) << "XlaDevice::ComputeAsync " << op_kernel->name() << ":" << op_kernel->type_string(); - port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string()); + port::Tracing::TraceMe trace_me(op_kernel->name(), op_kernel->type_string(), + op_kernel->IsExpensive()); op_kernel->ComputeAsync(context, done); } @@ -286,7 +293,9 @@ XlaDeviceOpRegistrations* RegisterXlaDeviceKernels(const char* device, auto dummy_factory = [](OpKernelConstruction* context) -> OpKernel* { return new XlaDeviceDummyOp(context); }; - for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels(jit_device)) { + for (const KernelDef* jit_def : XlaOpRegistry::DeviceKernels( + jit_device, + /*include_compilation_only_kernels=*/false)) { KernelDef* def = new KernelDef(*jit_def); def->set_device_type(device); registrations->op_kernel_registrars.emplace_back( diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h index 0d90b8b692896d8addf5ffead3980a5bf640c85c..d2ec38293c429f04f088bf3726ba97eb4e4b0dba 100644 --- a/tensorflow/compiler/jit/xla_device.h +++ b/tensorflow/compiler/jit/xla_device.h @@ -74,6 +74,7 @@ class XlaDevice : public LocalDevice { static Status Create(const string& platform_name, const string& device_name, int device_ordinal, const string& jit_device_name, const SessionOptions& options, const string& name_prefix, + bool register_device_for_compilation, std::unique_ptr* device); XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs, diff --git a/tensorflow/compiler/jit/xla_device_ops.h b/tensorflow/compiler/jit/xla_device_ops.h index 8699006ebc5aacafd46046a7c3f093356f687280..498d25cf566a91f68e5eb1ac312e17900471aeca 100644 --- a/tensorflow/compiler/jit/xla_device_ops.h +++ b/tensorflow/compiler/jit/xla_device_ops.h @@ -20,6 +20,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/constant_op.h" #include "tensorflow/core/kernels/control_flow_ops.h" #include "tensorflow/core/kernels/identity_op.h" @@ -53,6 +54,9 @@ class XlaDeviceDummyOp : public OpKernel { Name("_HostSend").Device(DEVICE).HostMemory("tensor"), SendOp); \ REGISTER_KERNEL_BUILDER( \ Name("_HostRecv").Device(DEVICE).HostMemory("tensor"), RecvOp); \ + REGISTER_KERNEL_BUILDER( \ + Name("_HostCast").Device(DEVICE).HostMemory("x").HostMemory("y"), \ + CpuCastOp); \ REGISTER_KERNEL_BUILDER(Name("NoOp").Device(DEVICE), NoOp); \ REGISTER_KERNEL_BUILDER( \ Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES), \ diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index 4474d8f4eb06afa78ea36332a8cc58f9d240c1b0..2326070358d67c0cf30ef17fab5c93862cd8932c 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -39,9 +39,9 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options, (void)registrations; std::unique_ptr device; - Status status = - XlaDevice::Create("CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, - name_prefix, &device); + Status status = XlaDevice::Create( + "CUDA", DEVICE_XLA_GPU, 0, DEVICE_GPU_XLA_JIT, options, name_prefix, + /*register_device_for_compilation=*/true, &device); if (!status.ok()) { // Treat failures as non-fatal; there might not be a GPU in the machine. VLOG(1) << "Failed to create XLA_GPU device: " << status; @@ -55,8 +55,8 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); // Kernel registrations -constexpr std::array kAllXlaGpuTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; +constexpr std::array kAllXlaGpuTypes = { + {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes); diff --git a/tensorflow/compiler/jit/xla_interpreter_device.cc b/tensorflow/compiler/jit/xla_interpreter_device.cc index 4e4cbe200a21b6584f0fefb4cf43874fb213e244..2614deefd8823dcb8f38e9e22ae4e78145d0d96a 100644 --- a/tensorflow/compiler/jit/xla_interpreter_device.cc +++ b/tensorflow/compiler/jit/xla_interpreter_device.cc @@ -42,9 +42,9 @@ Status XlaInterpreterDeviceFactory::CreateDevices( (void)registrations; std::unique_ptr device; - TF_RETURN_IF_ERROR(XlaDevice::Create("Interpreter", DEVICE_XLA_INTERPRETER, 0, - DEVICE_INTERPRETER_XLA_JIT, options, - name_prefix, &device)); + TF_RETURN_IF_ERROR(XlaDevice::Create( + "Interpreter", DEVICE_XLA_INTERPRETER, 0, DEVICE_INTERPRETER_XLA_JIT, + options, name_prefix, /*register_device_for_compilation=*/true, &device)); devices->push_back(device.release()); return Status::OK(); } diff --git a/tensorflow/compiler/plugin/BUILD b/tensorflow/compiler/plugin/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..c1edf2448c54ffddd7b70dcdfb1609080ca81b65 --- /dev/null +++ b/tensorflow/compiler/plugin/BUILD @@ -0,0 +1,56 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Configuration file for an XLA plugin. + + please don't check in changes to this file. to prevent changes appearing + in git status, use: + + git update-index --assume-unchanged tensorflow/compiler/plugin/BUILD + + To add additional devices to the XLA subsystem, add targets to the + dependency list in the 'plugin' target. For instance: + + deps = ["//tensorflow/compiler/plugin/example:plugin_lib"], + + ** Please don't remove this file - it is supporting some 3rd party plugins ** +""" + +licenses(["notice"]) + +package( + default_visibility = ["//visibility:public"], +) + +cc_library( + name = "plugin", + deps = [ + #"//tensorflow/compiler/plugin/example:example_lib", + ], +) + +#----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/plugin/README.md b/tensorflow/compiler/plugin/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9dd0d2bdab5e2c990fd547cef4b657253c545715 --- /dev/null +++ b/tensorflow/compiler/plugin/README.md @@ -0,0 +1,16 @@ +3rd party XLA devices +--------------------- + +This directory is intended as a place for 3rd party XLA devices which are _not_ +integrated into the public repository. + +By adding entries to the BUILD target in this directory, a third party device +can be included as a dependency of the JIT subsystem. + +For integration into the unit test system, see the files: + +- tensorflow/compiler/tests/plugin.bzl +- tensorflow/compiler/xla/tests/plugin.bzl + + +- diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD index a54d1f54f9533509534505228e66315f78f1bbfa..21b88239445d3169572abecada62fa9c5ceba4c7 100644 --- a/tensorflow/compiler/tests/BUILD +++ b/tensorflow/compiler/tests/BUILD @@ -23,6 +23,10 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test") load("//tensorflow/compiler/aot:tfcompile.bzl", "tf_library") load("//tensorflow/compiler/tests:build_defs.bzl", "tf_xla_py_test") load("//tensorflow/compiler/tests:build_defs.bzl", "generate_backend_suites") +load( + "//tensorflow/core:platform/default/build_config_root.bzl", + "tf_cuda_tests_tags", +) generate_backend_suites() @@ -75,14 +79,35 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "argminmax_test", + size = "small", + srcs = ["argminmax_test.py"], + # ArgMax needs CustomCall on CPU, which is not available in normal + # (not precompiled) TensorFlow. The flag below excludes the CPU + # backend. + disabled_backends = "cpu", + deps = [ + ":xla_test", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + tf_xla_py_test( name = "binary_ops_test", size = "small", srcs = ["binary_ops_test.py"], shard_count = 5, + tags = [ + "optonly", # Times out frequently in fastbuild mode. + ], deps = [ ":xla_test", "//tensorflow/python:array_ops", + "//tensorflow/python:bitwise_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:math_ops_gen", @@ -92,6 +117,18 @@ tf_xla_py_test( ], ) +tf_xla_py_test( + name = "categorical_op_test", + size = "small", + srcs = ["categorical_op_test.py"], + deps = [ + ":xla_test", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + ], +) + tf_xla_py_test( name = "clustering_test", size = "small", @@ -162,6 +199,7 @@ tf_xla_py_test( "noasan", "nomsan", "notsan", + "optonly", # Times out frequently in fastbuild mode. ], deps = [ ":xla_test", @@ -191,11 +229,6 @@ tf_xla_py_test( name = "slice_ops_test", size = "small", srcs = ["slice_ops_test.py"], - # TODO(b/62962492): Test fails with assertion error. - tags = [ - "manual", - "notap", - ], deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -443,7 +476,7 @@ tf_xla_py_test( tf_xla_py_test( name = "unary_ops_test", - size = "small", + size = "medium", srcs = ["unary_ops_test.py"], deps = [ ":xla_test", @@ -492,12 +525,8 @@ tf_xla_py_test( tf_xla_py_test( name = "gather_test", - size = "small", + size = "medium", srcs = ["gather_test.py"], - # Gather needs CustomCall on CPU, which is not available in normal - # (not precompiled) TensorFlow. The flag below excludes the CPU - # backend. - disabled_backends = "cpu", deps = [ ":xla_test", "//tensorflow/python:array_ops", @@ -559,6 +588,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow_opensource", "//tensorflow/core:test", "//tensorflow/core:testlib", "//tensorflow/core/kernels:ops_util", @@ -567,11 +597,12 @@ cc_library( tf_cuda_cc_test( name = "randomized_tests", + size = "large", # This test is randomized, so only run it if explicitly requested. tags = [ "manual", "notap", - ], + ] + tf_cuda_tests_tags(), deps = [":randomized_tests_library"], ) diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ec547e16cd9c91a1e25bc963b9a3cafddf7326cd --- /dev/null +++ b/tensorflow/compiler/tests/argminmax_test.py @@ -0,0 +1,80 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functional tests for ArgMin and ArgMax Ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.compiler.tests import xla_test +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class ArgMinMaxTest(xla_test.XLATestCase): + + def _assertOpOutputMatchesExpected(self, op, inp, expected): + """Verifies that 'op' produces 'expected' when fed input 'inp' . + + Args: + op: operator to test + inp: numpy input array to use as input to 'op'. + expected: numpy array representing the expected output of 'op'. + """ + with self.test_session() as session: + with self.test_scope(): + pinp = array_ops.placeholder( + dtypes.as_dtype(inp.dtype), inp.shape, name="a") + output = op(pinp) + result = session.run(output, {pinp: inp}) + self.assertAllEqual(result, expected) + + def testArgMinMax(self): + # Complex numbers do not support argmin/argmax. + minmax_types = set(self.numeric_types) - set(self.complex_types) + for dtype in minmax_types: + self._assertOpOutputMatchesExpected( + lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32), + np.array([1, 10, 27, 3, 3, 4], dtype=dtype), + expected=np.int32(2)) + self._assertOpOutputMatchesExpected( + lambda x: math_ops.argmax(x, axis=0, output_type=dtypes.int32), + np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), + expected=np.array([0, 1, 0], dtype=np.int32)) + self._assertOpOutputMatchesExpected( + lambda x: math_ops.argmax(x, axis=1, output_type=dtypes.int32), + np.array([[4, 1], [3, 2]], dtype=dtype), + expected=np.array([0, 0], dtype=np.int32)) + + self._assertOpOutputMatchesExpected( + lambda x: math_ops.argmin(x, axis=0, output_type=dtypes.int32), + np.array([3, 10, 27, 3, 2, 4], dtype=dtype), + expected=np.int32(4)) + self._assertOpOutputMatchesExpected( + lambda x: math_ops.argmin(x, axis=0, output_type=dtypes.int32), + np.array([[4, 1, 7], [3, 2, 4]], dtype=dtype), + expected=np.array([1, 0, 1], dtype=np.int32)) + self._assertOpOutputMatchesExpected( + lambda x: math_ops.argmin(x, axis=1, output_type=dtypes.int32), + np.array([[4, 1], [3, 2]], dtype=dtype), + expected=np.array([1, 1], dtype=np.int32)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index e6862f0d9dd7ec05b4e0c4ba26ab5f16a7aa9ad7..d412c572ae16b84c2434819aa0a2d881defef5f9 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -22,7 +22,9 @@ import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops @@ -44,6 +46,10 @@ class BinaryOpsTest(XLATestCase): equality_test = self.assertAllClose equality_test(result, expected, rtol=1e-3) + def _testSymmetricBinary(self, op, a, b, expected, equality_test=None): + self._testBinary(op, a, b, expected, equality_test) + self._testBinary(op, b, a, expected, equality_test) + def ListsAreClose(self, result, expected, rtol): """Tests closeness of two lists of floats.""" self.assertEqual(len(result), len(expected)) @@ -88,6 +94,15 @@ class BinaryOpsTest(XLATestCase): dtype(4), expected=np.array([[16], [81]], dtype=dtype)) + atan2_supported = self.device == "XLA_GPU" + if atan2_supported: + self._testBinary( + math_ops.atan2, + np.array([0, np.sqrt(2), 1, np.sqrt(2), 0], dtype), + np.array([1, np.sqrt(2), 0, -np.sqrt(2), -1], dtype), + expected=np.array( + [0, np.pi / 4, np.pi / 2, np.pi * 3 / 4, np.pi], dtype=dtype)) + self._testBinary( gen_math_ops._reciprocal_grad, np.array([4, -3, -2, 1], dtype=dtype), @@ -192,6 +207,32 @@ class BinaryOpsTest(XLATestCase): np.array([3, 3, -1, -9, -8], dtype=dtype), np.array([2, -2, 7, 2, -4], dtype=dtype), expected=np.array([1, -1, 0, -4, 2], dtype=dtype)) + self._testSymmetricBinary( + bitwise_ops.bitwise_and, + np.array([0b1, 0b101, 0b1000], dtype=dtype), + np.array([0b0, 0b101, 0b1001], dtype=dtype), + expected=np.array([0b0, 0b101, 0b1000], dtype=dtype)) + self._testSymmetricBinary( + bitwise_ops.bitwise_or, + np.array([0b1, 0b101, 0b1000], dtype=dtype), + np.array([0b0, 0b101, 0b1001], dtype=dtype), + expected=np.array([0b1, 0b101, 0b1001], dtype=dtype)) + + lhs = np.array([0, 5, 3, 14], dtype=dtype) + rhs = np.array([5, 0, 7, 11], dtype=dtype) + self._testBinary( + bitwise_ops.left_shift, lhs, rhs, + expected=np.left_shift(lhs, rhs)) + self._testBinary( + bitwise_ops.right_shift, lhs, rhs, + expected=np.right_shift(lhs, rhs)) + + if dtype in [np.int8, np.int16, np.int32, np.int64]: + lhs = np.array([-1, -5, -3, -14], dtype=dtype) + rhs = np.array([5, 0, 1, 11], dtype=dtype) + self._testBinary( + bitwise_ops.right_shift, lhs, rhs, + expected=np.right_shift(lhs, rhs)) def testNumericOps(self): for dtype in self.numeric_types: @@ -227,37 +268,38 @@ class BinaryOpsTest(XLATestCase): dtype(7), expected=np.array([[-6], [-5]], dtype=dtype)) - self._testBinary( - math_ops.maximum, - np.array([1, 2], dtype=dtype), - np.array([10, 20], dtype=dtype), - expected=np.array([10, 20], dtype=dtype)) - self._testBinary( - math_ops.maximum, - dtype(5), - np.array([1, 20], dtype=dtype), - expected=np.array([5, 20], dtype=dtype)) - self._testBinary( - math_ops.maximum, - np.array([[10], [2]], dtype=dtype), - dtype(7), - expected=np.array([[10], [7]], dtype=dtype)) - - self._testBinary( - math_ops.minimum, - np.array([1, 20], dtype=dtype), - np.array([10, 2], dtype=dtype), - expected=np.array([1, 2], dtype=dtype)) - self._testBinary( - math_ops.minimum, - dtype(5), - np.array([1, 20], dtype=dtype), - expected=np.array([1, 5], dtype=dtype)) - self._testBinary( - math_ops.minimum, - np.array([[10], [2]], dtype=dtype), - dtype(7), - expected=np.array([[7], [2]], dtype=dtype)) + if dtype not in self.complex_types: # min/max not supported for complex + self._testBinary( + math_ops.maximum, + np.array([1, 2], dtype=dtype), + np.array([10, 20], dtype=dtype), + expected=np.array([10, 20], dtype=dtype)) + self._testBinary( + math_ops.maximum, + dtype(5), + np.array([1, 20], dtype=dtype), + expected=np.array([5, 20], dtype=dtype)) + self._testBinary( + math_ops.maximum, + np.array([[10], [2]], dtype=dtype), + dtype(7), + expected=np.array([[10], [7]], dtype=dtype)) + + self._testBinary( + math_ops.minimum, + np.array([1, 20], dtype=dtype), + np.array([10, 2], dtype=dtype), + expected=np.array([1, 2], dtype=dtype)) + self._testBinary( + math_ops.minimum, + dtype(5), + np.array([1, 20], dtype=dtype), + expected=np.array([1, 5], dtype=dtype)) + self._testBinary( + math_ops.minimum, + np.array([[10], [2]], dtype=dtype), + dtype(7), + expected=np.array([[7], [2]], dtype=dtype)) self._testBinary( math_ops.multiply, @@ -275,21 +317,23 @@ class BinaryOpsTest(XLATestCase): dtype(7), expected=np.array([[70], [14]], dtype=dtype)) - self._testBinary( - math_ops.squared_difference, - np.array([1, 2], dtype=dtype), - np.array([10, 20], dtype=dtype), - expected=np.array([81, 324], dtype=dtype)) - self._testBinary( - math_ops.squared_difference, - dtype(5), - np.array([1, 2], dtype=dtype), - expected=np.array([16, 9], dtype=dtype)) - self._testBinary( - math_ops.squared_difference, - np.array([[1], [2]], dtype=dtype), - dtype(7), - expected=np.array([[36], [25]], dtype=dtype)) + # Complex support for squared_difference is incidental, see b/68205550 + if dtype not in self.complex_types: + self._testBinary( + math_ops.squared_difference, + np.array([1, 2], dtype=dtype), + np.array([10, 20], dtype=dtype), + expected=np.array([81, 324], dtype=dtype)) + self._testBinary( + math_ops.squared_difference, + dtype(5), + np.array([1, 2], dtype=dtype), + expected=np.array([16, 9], dtype=dtype)) + self._testBinary( + math_ops.squared_difference, + np.array([[1], [2]], dtype=dtype), + dtype(7), + expected=np.array([[36], [25]], dtype=dtype)) self._testBinary( nn_ops.bias_add, @@ -302,6 +346,139 @@ class BinaryOpsTest(XLATestCase): np.array([2, -1], dtype=dtype), expected=np.array([[[[3, 1], [5, 3]]]], dtype=dtype)) + def testComplexOps(self): + for dtype in self.complex_types: + ctypes = {np.complex64: np.float32} + self._testBinary( + math_ops.complex, + np.array([[[[-1, 2], [2, 0]]]], dtype=ctypes[dtype]), + np.array([[[[2, -3], [0, 4]]]], dtype=ctypes[dtype]), + expected=np.array([[[[-1 + 2j, 2 - 3j], [2, 4j]]]], dtype=dtype)) + + self._testBinary( + lambda x, y: math_ops.approximate_equal(x, y, tolerance=0.0001), + np.array( + [[[[-1 + 2j, 2.00009999 - 3j], [2 - 3j, 3 + 4.01j]]]], + dtype=dtype), + np.array( + [[[[-1.001 + 2j, 2 - 3j], [2 - 3.00009j, 3 + 4j]]]], dtype=dtype), + expected=np.array([[[[False, True], [True, False]]]], dtype=dtype)) + + self._testBinary( + gen_math_ops._real_div, + np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j, 44 + 3j], dtype=dtype), + np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j, 0], dtype=dtype), + expected=np.array( + [ + 1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2, + float("inf") + ], + dtype=dtype)) + + # TODO(b/65408531): support+test pow for cplx + + lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype) + rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype) + self._testBinary( + gen_math_ops._reciprocal_grad, lhs, rhs, expected=-rhs * lhs * lhs) + + self._testBinary( + gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs)) + + # TODO(b/65408531): support+test _rsqrt_grad for cplx (needs pow) + + self._testBinary( + gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs)) + + self._testBinary( + gen_math_ops._tanh_grad, lhs, rhs, expected=rhs * (1 - lhs * lhs)) + + def testComplexMath(self): + for dtype in self.complex_types: + self._testBinary( + math_ops.add, + np.array([1 + 3j, 2 + 7j], dtype=dtype), + np.array([10 - 4j, 20 + 17j], dtype=dtype), + expected=np.array([11 - 1j, 22 + 24j], dtype=dtype)) + self._testBinary( + math_ops.add, + dtype(5 - 7j), + np.array([1 + 2j, 2 + 4j], dtype=dtype), + expected=np.array([6 - 5j, 7 - 3j], dtype=dtype)) + self._testBinary( + math_ops.add, + np.array([[1 - 2j], [2 + 1j]], dtype=dtype), + dtype(7 + 5j), + expected=np.array([[8 + 3j], [9 + 6j]], dtype=dtype)) + + self._testBinary( + math_ops.subtract, + np.array([1 + 3j, 2 + 7j], dtype=dtype), + np.array([10 - 4j, 20 + 17j], dtype=dtype), + expected=np.array([-9 + 7j, -18 - 10j], dtype=dtype)) + self._testBinary( + math_ops.subtract, + dtype(5 - 7j), + np.array([1 + 2j, 2 + 4j], dtype=dtype), + expected=np.array([4 - 9j, 3 - 11j], dtype=dtype)) + self._testBinary( + math_ops.subtract, + np.array([[1 - 2j], [2 + 1j]], dtype=dtype), + dtype(7 + 5j), + expected=np.array([[-6 - 7j], [-5 - 4j]], dtype=dtype)) + + self._testBinary( + math_ops.multiply, + np.array([1 + 3j, 2 + 7j], dtype=dtype), + np.array([10 - 4j, 20 + 17j], dtype=dtype), + expected=np.array( + [(1 + 3j) * (10 - 4j), (2 + 7j) * (20 + 17j)], dtype=dtype)) + self._testBinary( + math_ops.multiply, + dtype(5 - 7j), + np.array([1 + 2j, 2 + 4j], dtype=dtype), + expected=np.array( + [(5 - 7j) * (1 + 2j), (5 - 7j) * (2 + 4j)], dtype=dtype)) + self._testBinary( + math_ops.multiply, + np.array([[1 - 2j], [2 + 1j]], dtype=dtype), + dtype(7 + 5j), + expected=np.array( + [[(7 + 5j) * (1 - 2j)], [(7 + 5j) * (2 + 1j)]], dtype=dtype)) + + self._testBinary( + math_ops.div, + np.array([8 - 1j, 2 + 16j], dtype=dtype), + np.array([2 + 4j, 4 - 8j], dtype=dtype), + expected=np.array( + [(8 - 1j) / (2 + 4j), (2 + 16j) / (4 - 8j)], dtype=dtype)) + self._testBinary( + math_ops.div, + dtype(1 + 2j), + np.array([2 + 4j, 4 - 8j], dtype=dtype), + expected=np.array( + [(1 + 2j) / (2 + 4j), (1 + 2j) / (4 - 8j)], dtype=dtype)) + self._testBinary( + math_ops.div, + np.array([2 + 4j, 4 - 8j], dtype=dtype), + dtype(1 + 2j), + expected=np.array( + [(2 + 4j) / (1 + 2j), (4 - 8j) / (1 + 2j)], dtype=dtype)) + + # TODO(b/68205550): math_ops.squared_difference shouldn't be supported. + + self._testBinary( + nn_ops.bias_add, + np.array([[1 + 2j, 2 + 7j], [3 - 5j, 4 + 2j]], dtype=dtype), + np.array([2 + 6j, -1 - 3j], dtype=dtype), + expected=np.array([[3 + 8j, 1 + 4j], [5 + 1j, 3 - 1j]], dtype=dtype)) + self._testBinary( + nn_ops.bias_add, + np.array([[[[1 + 4j, 2 - 1j], [3 + 7j, 4]]]], dtype=dtype), + np.array([2 + 1j, -1 + 2j], dtype=dtype), + expected=np.array( + [[[[3 + 5j, 1 + 1j], [5 + 8j, 3 + 2j]]]], dtype=dtype)) + def _testDivision(self, dtype): """Test cases for division operators.""" self._testBinary( @@ -320,18 +497,19 @@ class BinaryOpsTest(XLATestCase): dtype(2), expected=np.array([[5], [2]], dtype=dtype)) - self._testBinary( - gen_math_ops._floor_div, - np.array([3, 3, -1, -9, -8], dtype=dtype), - np.array([2, -2, 7, 2, -4], dtype=dtype), - expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) + if dtype not in self.complex_types: # floordiv unsupported for complex. + self._testBinary( + gen_math_ops._floor_div, + np.array([3, 3, -1, -9, -8], dtype=dtype), + np.array([2, -2, 7, 2, -4], dtype=dtype), + expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) def testIntDivision(self): for dtype in self.int_types: self._testDivision(dtype) def testFloatDivision(self): - for dtype in self.float_types: + for dtype in self.float_types + self.complex_types: self._testDivision(dtype) def _testRemainder(self, dtype): @@ -675,6 +853,20 @@ class BinaryOpsTest(XLATestCase): [0, 0, 0, 0, 0, 0]], dtype=dtype)) + self._testBinary( + lambda x, y: array_ops.pad(x, y, constant_values=7), + np.array( + [[1, 2, 3], [4, 5, 6]], dtype=dtype), + np.array( + [[0, 3], [2, 1]], dtype=np.int32), + expected=np.array( + [[7, 7, 1, 2, 3, 7], + [7, 7, 4, 5, 6, 7], + [7, 7, 7, 7, 7, 7], + [7, 7, 7, 7, 7, 7], + [7, 7, 7, 7, 7, 7]], + dtype=dtype)) + def testMirrorPad(self): mirror_pad = lambda t, paddings: array_ops.pad(t, paddings, "REFLECT") for dtype in self.numeric_types: @@ -789,28 +981,30 @@ class BinaryOpsTest(XLATestCase): def testSplit(self): for dtype in self.numeric_types: - self._testBinary( - lambda x, y: array_ops.split(value=y, num_or_size_splits=3, axis=x), - np.int32(0), - np.array([[[1], [2]], [[3], [4]], [[5], [6]]], - dtype=dtype), - expected=[ - np.array([[[1], [2]]], dtype=dtype), - np.array([[[3], [4]]], dtype=dtype), - np.array([[[5], [6]]], dtype=dtype), - ], - equality_test=self.ListsAreClose) - - self._testBinary( - lambda x, y: array_ops.split(value=y, num_or_size_splits=2, axis=x), - np.int32(1), - np.array([[[1], [2]], [[3], [4]], [[5], [6]]], - dtype=dtype), - expected=[ - np.array([[[1]], [[3]], [[5]]], dtype=dtype), - np.array([[[2]], [[4]], [[6]]], dtype=dtype), - ], - equality_test=self.ListsAreClose) + for axis in [0, -3]: + self._testBinary( + lambda x, y: array_ops.split(value=y, num_or_size_splits=3, axis=x), + np.int32(axis), + np.array([[[1], [2]], [[3], [4]], [[5], [6]]], + dtype=dtype), + expected=[ + np.array([[[1], [2]]], dtype=dtype), + np.array([[[3], [4]]], dtype=dtype), + np.array([[[5], [6]]], dtype=dtype), + ], + equality_test=self.ListsAreClose) + + for axis in [1, -2]: + self._testBinary( + lambda x, y: array_ops.split(value=y, num_or_size_splits=2, axis=x), + np.int32(axis), + np.array([[[1], [2]], [[3], [4]], [[5], [6]]], + dtype=dtype), + expected=[ + np.array([[[1]], [[3]], [[5]]], dtype=dtype), + np.array([[[2]], [[4]], [[6]]], dtype=dtype), + ], + equality_test=self.ListsAreClose) def testTile(self): for dtype in self.numeric_types: @@ -890,6 +1084,64 @@ class BinaryOpsTest(XLATestCase): np.array([[4, 5, 6], [40, 50, 60]], dtype=dtype), expected=np.array([[-3, 6, -3], [60, -120, 60]], dtype=dtype)) + def testBroadcastArgs(self): + self._testBinary(array_ops.broadcast_dynamic_shape, + np.array([2, 3, 5], dtype=np.int32), + np.array([1], dtype=np.int32), + expected=np.array([2, 3, 5], dtype=np.int32)) + + self._testBinary(array_ops.broadcast_dynamic_shape, + np.array([1], dtype=np.int32), + np.array([2, 3, 5], dtype=np.int32), + expected=np.array([2, 3, 5], dtype=np.int32)) + + self._testBinary(array_ops.broadcast_dynamic_shape, + np.array([2, 3, 5], dtype=np.int32), + np.array([5], dtype=np.int32), + expected=np.array([2, 3, 5], dtype=np.int32)) + + self._testBinary(array_ops.broadcast_dynamic_shape, + np.array([5], dtype=np.int32), + np.array([2, 3, 5], dtype=np.int32), + expected=np.array([2, 3, 5], dtype=np.int32)) + + self._testBinary(array_ops.broadcast_dynamic_shape, + np.array([2, 3, 5], dtype=np.int32), + np.array([3, 5], dtype=np.int32), + expected=np.array([2, 3, 5], dtype=np.int32)) + + self._testBinary(array_ops.broadcast_dynamic_shape, + np.array([3, 5], dtype=np.int32), + np.array([2, 3, 5], dtype=np.int32), + expected=np.array([2, 3, 5], dtype=np.int32)) + + self._testBinary(array_ops.broadcast_dynamic_shape, + np.array([2, 3, 5], dtype=np.int32), + np.array([3, 1], dtype=np.int32), + expected=np.array([2, 3, 5], dtype=np.int32)) + + self._testBinary(array_ops.broadcast_dynamic_shape, + np.array([3, 1], dtype=np.int32), + np.array([2, 3, 5], dtype=np.int32), + expected=np.array([2, 3, 5], dtype=np.int32)) + + self._testBinary(array_ops.broadcast_dynamic_shape, + np.array([2, 1, 5], dtype=np.int32), + np.array([3, 1], dtype=np.int32), + expected=np.array([2, 3, 5], dtype=np.int32)) + + self._testBinary(array_ops.broadcast_dynamic_shape, + np.array([3, 1], dtype=np.int32), + np.array([2, 1, 5], dtype=np.int32), + expected=np.array([2, 3, 5], dtype=np.int32)) + + with self.assertRaisesWithPredicateMatch(errors.InvalidArgumentError, + "Incompatible shapes"): + self._testBinary(array_ops.broadcast_dynamic_shape, + np.array([1, 2, 3], dtype=np.int32), + np.array([4, 5, 6], dtype=np.int32), + expected=None) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index a56c53de0fb5f76c94064e2bdc2f1a543a207b09..0528a5415d579a844e68403ace1bb8982a10a841 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -49,11 +49,15 @@ def tf_xla_py_test(name, srcs=[], deps=[], tags=[], data=[], main=None, backend_deps = [] backend_data = [] if backend == "cpu": - backend_args += ["--test_device=XLA_CPU", - "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"] + backend_args += [ + "--test_device=XLA_CPU", + "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64" + ] elif backend == "gpu": - backend_args += ["--test_device=XLA_GPU", - "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL"] + backend_args += [ + "--test_device=XLA_GPU", + "--types=DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64" + ] backend_tags += ["requires-gpu-sm35"] elif backend in plugins: backend_args += ["--test_device=" + plugins[backend]["device"], diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5e06f9a72401935b9681c35a164b51f50a8538ae --- /dev/null +++ b/tensorflow/compiler/tests/categorical_op_test.py @@ -0,0 +1,135 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for multinomial generation ops in the XLA JIT compiler.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +import numpy as np + +from tensorflow.compiler.tests.xla_test import XLATestCase +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import googletest + + +# TODO(srvasude): Merge this with +# third_party/tensorflow/python/kernel_tests/random/multinomial_op_test.py. +class CategoricalTest(XLATestCase): + """Test cases for random-number generating operators.""" + + def _chi2(self, expected, actual): + """Returns Chi2 GOF statistic.""" + actual = np.asarray(actual) + expected = np.asarray(expected) + diff = actual - expected + chi2 = np.sum(diff * diff / expected) + return chi2 + + def _do_sampling(self, logits, num_samples): + """Categorical samples from given input. + + Args: + logits: Numpy ndarray of shape [batch_size, num_classes]. + num_samples: Int; number of samples to draw. + + Returns: + Frequencies from sampled classes; shape [batch_size, num_classes]. + """ + with self.test_session() as sess, self.test_scope(): + random_seed.set_random_seed(1618) + op = random_ops.multinomial(logits, num_samples) + d = sess.run(op) + + batch_size, num_classes = logits.shape + freqs_mat = [] + for i in range(batch_size): + cnts = dict(collections.Counter(d[i, :])) + + # Requires drawn class labels be in range. + self.assertLess(max(cnts.keys()), num_classes) + self.assertGreaterEqual(min(cnts.keys()), 0) + + freqs = [(cnts[k] * 1. / num_samples if k in cnts else 0) + for k in range(num_classes)] + freqs_mat.append(freqs) + + return freqs_mat + + def _testRngIsNotConstant(self, rng, dtype): + # Tests that 'rng' does not always return the same value. + with self.test_session() as sess: + with self.test_scope(): + x = rng(dtype) + + # The random-number generator, if working correctly, should produce the + # same output multiple times with low probability. + y = sess.run(x) + z = sess.run(x) + w = sess.run(x) + + # We use exact equality here. If the random-number generator is producing + # deterministic output, all three outputs will be bitwise identical. + self.assertTrue((not np.array_equal(y, z)) or + (not np.array_equal(z, w)) or + (not np.array_equal(y, w))) + + def testCategoricalIsNotConstant(self): + def rng(unused_dtype): + return random_ops.multinomial([[1., 1., 1.]], 10) + + dtype = dtypes.float32 + self._testRngIsNotConstant(rng, dtype) + + def testCategoricalIsInRange(self): + for dtype in [dtypes.float32, dtypes.float64]: + with self.test_session() as sess: + with self.test_scope(): + x = random_ops.multinomial( + array_ops.ones(shape=[1, 20], dtype=dtype), 1000) + y = sess.run(x) + self.assertTrue((y >= 0).sum() == 1000) + self.assertTrue((y < 20).sum() == 1000) + + def testSamplingCorrectness(self): + np.random.seed(1618) # Make it reproducible. + num_samples = 21000 + + rand_probs = np.random.dirichlet([1., 1., 2., 3.]) + rand_probs2 = np.random.dirichlet([1., 4., 5.], size=3) # batched + for probs in [[.5, .5], [.85, .05, .1], rand_probs, rand_probs2]: + probs = np.asarray(probs) + if len(probs.shape) == 1: + probs = probs.reshape(1, probs.size) # singleton batch + + logits = np.log(probs).astype(np.float32) + freqs = self._do_sampling(logits, num_samples) + + # the test here is similar to + # python/kernel_tests/random/multinomial_op_test.py + # Note that df >= 1 in all these cases. Choosing a cutoff of 1e-3 + # corresponds to an alpha value of 2.5% for df = 1, and smaller for larger + # df. + chi2 = self._chi2(probs, freqs) + self.assertLess(chi2, 1e-3) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/compiler/tests/fused_batchnorm_test.py b/tensorflow/compiler/tests/fused_batchnorm_test.py index 936fcf8b6be0f8cd67ba07a8bef9d35a732d30ba..a773b5a94742062511bc8bdc6a202b513ce98db3 100644 --- a/tensorflow/compiler/tests/fused_batchnorm_test.py +++ b/tensorflow/compiler/tests/fused_batchnorm_test.py @@ -36,7 +36,7 @@ class FusedBatchNormTest(XLATestCase): x_square = x * x x_square_sum = np.sum(x_square, (0, 1, 2)) x_sum = np.sum(x, axis=(0, 1, 2)) - element_count = np.size(x) / int(np.shape(x)[0]) + element_count = np.size(x) / int(np.shape(x)[-1]) mean = x_sum / element_count var = x_square_sum / element_count - mean * mean normalized = (x - mean) / np.sqrt(var + epsilon) @@ -64,8 +64,9 @@ class FusedBatchNormTest(XLATestCase): return grad_x, grad_scale, grad_offset def testInference(self): - x_shape = [2, 2, 6, 2] - scale_shape = [2] + channel = 3 + x_shape = [2, 2, 6, channel] + scale_shape = [channel] x_val = np.random.random_sample(x_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32) @@ -74,8 +75,8 @@ class FusedBatchNormTest(XLATestCase): with self.test_session() as sess, self.test_scope(): # To avoid constant folding t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") - scale = array_ops.placeholder(np.float32, shape=[2], name="scale") - offset = array_ops.placeholder(np.float32, shape=[2], name="offset") + scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") + offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset") epsilon = 0.001 y_ref, mean_ref, var_ref = self._reference_training( x_val, scale_val, offset_val, epsilon, data_format) @@ -97,8 +98,9 @@ class FusedBatchNormTest(XLATestCase): self.assertAllClose(y_val, y_ref, atol=1e-3) def _testLearning(self, use_gradient_checker): - x_shape = [2, 2, 6, 2] - scale_shape = [2] + channel = 3 + x_shape = [2, 2, 6, channel] + scale_shape = [channel] x_val = np.random.random_sample(x_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32) @@ -109,8 +111,8 @@ class FusedBatchNormTest(XLATestCase): with self.test_session() as sess, self.test_scope(): # To avoid constant folding t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x") - scale = array_ops.placeholder(np.float32, shape=[2], name="scale") - offset = array_ops.placeholder(np.float32, shape=[2], name="offset") + scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale") + offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset") epsilon = 0.001 y, mean, var = nn.fused_batch_norm( t_val, @@ -154,8 +156,9 @@ class FusedBatchNormTest(XLATestCase): def testGradient(self): # TODO(b/64270657): Use gradient_checker here in addition to comparing with # this reference implementation. - x_shape = [2, 2, 6, 2] - scale_shape = [2] + channel = 3 + x_shape = [2, 2, 6, channel] + scale_shape = [channel] grad_val = np.random.random_sample(x_shape).astype(np.float32) x_val = np.random.random_sample(x_shape).astype(np.float32) scale_val = np.random.random_sample(scale_shape).astype(np.float32) diff --git a/tensorflow/compiler/tests/gather_test.py b/tensorflow/compiler/tests/gather_test.py index 9f752dd072bd90b02c2ab801a09a6d17f8ea0e58..13cbe6f312f5175edaec28fa7a8f28064194b0e9 100644 --- a/tensorflow/compiler/tests/gather_test.py +++ b/tensorflow/compiler/tests/gather_test.py @@ -24,9 +24,11 @@ from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import flags from tensorflow.python.platform import test -_TEST_TYPES = [dtypes.float32] +FLAGS = flags.FLAGS class GatherTest(xla_test.XLATestCase): @@ -42,8 +44,8 @@ class GatherTest(xla_test.XLATestCase): def testScalar1D(self): with self.test_session() as session, self.test_scope(): data = np.array([0, 1, 2, 3, 7, 5]) - for dtype in _TEST_TYPES: - for indices in 4, [1, 2, 2, 4, 5]: + for dtype in self.all_tf_types: + for indices in 4, [4], [1, 2, 2, 4, 5]: params_np = self._buildParams(data, dtype) params = array_ops.placeholder(dtype=dtype) indices_tf = constant_op.constant(indices) @@ -51,55 +53,137 @@ class GatherTest(xla_test.XLATestCase): gather_val = session.run(gather_t, feed_dict={params: params_np}) np_val = params_np[indices] self.assertAllEqual(np_val, gather_val) - self.assertEqual(np_val.shape, gather_val.shape) def testScalar2D(self): with self.test_session() as session, self.test_scope(): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) - for dtype in _TEST_TYPES: - params_np = self._buildParams(data, dtype) - params = array_ops.placeholder(dtype=dtype) - indices = constant_op.constant(2) - gather_t = array_ops.gather(params, indices) - gather_val = session.run(gather_t, feed_dict={params: params_np}) - self.assertAllEqual(np.take(params_np, 2, axis=0), gather_val) - expected_shape = data.shape[:0] + data.shape[1:] - self.assertEqual(expected_shape, gather_val.shape) + for dtype in self.all_tf_types: + for axis in 0, 1, -1: + params_np = self._buildParams(data, dtype) + params = array_ops.placeholder(dtype=dtype) + indices = constant_op.constant(2) + gather_t = array_ops.gather(params, indices, axis=axis) + gather_val = session.run(gather_t, feed_dict={params: params_np}) + expected = np.take(params_np, 2, axis=axis) + self.assertAllEqual(expected, gather_val) def testSimpleTwoD32(self): with self.test_session() as session, self.test_scope(): data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], [12, 13, 14]]) - for dtype in _TEST_TYPES: - params_np = self._buildParams(data, dtype) - params = array_ops.placeholder(dtype=dtype) - # The indices must be in bounds for any axis. - indices = constant_op.constant([0, 1, 0, 2]) - gather_t = array_ops.gather(params, indices) - gather_val = session.run(gather_t, feed_dict={params: params_np}) - self.assertAllEqual( - np.take(params_np, [0, 1, 0, 2], axis=0), gather_val) - expected_shape = data.shape[:0] + (4,) + data.shape[1:] - self.assertEqual(expected_shape, gather_val.shape) + for dtype in self.all_tf_types: + for axis in 0, 1, -1: + params_np = self._buildParams(data, dtype) + params = array_ops.placeholder(dtype=dtype) + # The indices must be in bounds for any axis. + indices = constant_op.constant([0, 1, 0, 2]) + gather_t = array_ops.gather(params, indices, axis=axis) + gather_val = session.run(gather_t, feed_dict={params: params_np}) + expected = np.take(params_np, [0, 1, 0, 2], axis=axis) + self.assertAllEqual(expected, gather_val) + + def testSimpleTwoD32_Int64Indices(self): + if np.int64 not in self.int_types: + return + + with self.test_session() as session, self.test_scope(): + data = np.array([[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11], + [12, 13, 14]]) + # The indices must be in bounds for any axis. + indices_np = np.array([0, 1, 0, 2]) + for dtype in self.all_tf_types: + for axis in 0, 1, -1: + params_np = self._buildParams(data, dtype) + params = array_ops.placeholder(dtype=dtype) + indices = array_ops.placeholder(dtype=dtypes.int64) + gather_t = array_ops.gather(params, indices, axis=axis) + gather_val = session.run( + gather_t, feed_dict={ + params: params_np, + indices: indices_np + }) + expected = np.take(params_np, [0, 1, 0, 2], axis=axis) + self.assertAllEqual(expected, gather_val) def testHigherRank(self): - # Check that scalar and empty indices shapes work as well. + """Check that scalar and empty indices shapes work as well.""" shape = (2, 1, 3, 2) for indices_shape in (), (0,), (2, 0), (2, 3): - for dtype in _TEST_TYPES: - params = self._buildParams(np.random.randn(*shape), dtype) - indices = np.random.randint(shape[0], size=indices_shape) - with self.test_session() as sess, self.test_scope(): - tf_params = array_ops.placeholder(dtype=dtype) - tf_indices = constant_op.constant(indices, dtype=dtypes.int32) - gather = array_ops.gather(tf_params, tf_indices) - gather_value = sess.run(gather, feed_dict={tf_params: params}) - gather_np = np.take(params, indices, 0) - self.assertAllEqual(gather_np, gather_value) - expected_shape = (params.shape[:0] + indices.shape + params.shape[1:]) - self.assertEqual(expected_shape, gather_value.shape) - - -if __name__ == "__main__": + for dtype in self.all_tf_types: + for axis in 0, 1, 2, 3, -1, -2: + params = self._buildParams(np.random.randn(*shape), dtype) + indices = np.random.randint(shape[axis], size=indices_shape) + with self.test_session() as sess, self.test_scope(): + tf_params = array_ops.placeholder(dtype=dtype) + tf_indices = constant_op.constant(indices, dtype=dtypes.int32) + gather = array_ops.gather(tf_params, tf_indices, axis=axis) + gather_value = sess.run(gather, feed_dict={tf_params: params}) + gather_np = np.take(params, indices, axis=axis) + self.assertAllEqual(gather_np, gather_value) + + +class GatherBenchmark(test.Benchmark): + """Microbenchmarks for the gather op.""" + + def _benchmarkGather(self, name, axis, gather_indices, use_xla_jit): + + def BuilderFn(): + inputs = variables.Variable( + array_ops.zeros([100, 100, 10, 100, 50], dtype=dtypes.float32), + dtype=dtypes.float32, + name='input') + indices = variables.Variable( + gather_indices, dtype=dtypes.int32, name='indices') + gather_t = array_ops.gather(inputs, indices, axis=axis) + return '%s.axis%d' % (name, axis), [gather_t] + + xla_test.Benchmark(self, BuilderFn, use_xla_jit=use_xla_jit, device='cpu') + + def _benchmarkSliceGather(self, axis, use_xla_jit): + """Benchmarks a gather op that's really a dynamic slice.""" + self._benchmarkGather('slice_gather', axis, [1], use_xla_jit) + + def _benchmarkNontrivialGather(self, axis, use_xla_jit): + self._benchmarkGather('nontrivial_gather', axis, [9, 1, 0, 2] * 4, + use_xla_jit) + + def benchmarkSliceGatherAxis0(self): + self._benchmarkSliceGather(axis=0, use_xla_jit=False) + + def benchmarkSliceGatherAxis0XLA(self): + self._benchmarkSliceGather(axis=0, use_xla_jit=True) + + def benchmarkSliceGatherAxis1(self): + self._benchmarkSliceGather(axis=1, use_xla_jit=False) + + def benchmarkSliceGatherAxis1XLA(self): + self._benchmarkSliceGather(axis=1, use_xla_jit=True) + + def benchmarkSliceGatherAxis4(self): + self._benchmarkSliceGather(axis=4, use_xla_jit=False) + + def benchmarkSliceGatherAxis4XLA(self): + self._benchmarkSliceGather(axis=4, use_xla_jit=True) + + def benchmarkNontrivialGatherAxis0(self): + self._benchmarkNontrivialGather(axis=0, use_xla_jit=False) + + def benchmarkNontrivialGatherAxis0XLA(self): + self._benchmarkNontrivialGather(axis=0, use_xla_jit=True) + + def benchmarkNontrivialGatherAxis1(self): + self._benchmarkNontrivialGather(axis=1, use_xla_jit=False) + + def benchmarkNontrivialGatherAxis1XLA(self): + self._benchmarkNontrivialGather(axis=1, use_xla_jit=True) + + def benchmarkNontrivialGatherAxis4(self): + self._benchmarkNontrivialGather(axis=4, use_xla_jit=False) + + def benchmarkNontrivialGatherAxis4XLA(self): + self._benchmarkNontrivialGather(axis=4, use_xla_jit=True) + + +if __name__ == '__main__': test.main() diff --git a/tensorflow/compiler/tests/jit_test.py b/tensorflow/compiler/tests/jit_test.py index 11914080eccbf3506e6a17e243bf9f8ba1cbb812..2d8236e2cbdfafb35626cd582ee39b1f917aec7f 100644 --- a/tensorflow/compiler/tests/jit_test.py +++ b/tensorflow/compiler/tests/jit_test.py @@ -21,15 +21,12 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.compiler import jit -from tensorflow.core.framework import function_pb2 -from tensorflow.core.framework import node_def_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session as session_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl @@ -118,31 +115,13 @@ class JitLaunchTest(test.TestCase): def testNoOutputs(self): with session_lib.Session() as sess: - # Build a function with a single Const node, whose output is ignored. - fdef = function_pb2.FunctionDef() - fdef.signature.name = "KernelWithNoOutputs" - node = node_def_pb2.NodeDef() - node.op = "Const" - node.name = "ignored" - node.attr["dtype"].type = dtypes.int32.as_datatype_enum - tensor = tensor_util.make_tensor_proto([0], dtype=dtypes.int32, shape=[]) - node.attr["value"].tensor.CopyFrom(tensor) - fdef.node_def.extend([node]) # Check that calling the result as a compiled kernel doesn't crash. @function.Defun(compiled=True) def KernelWithNoOutputs(): - return constant_op.constant(100) - - # Hack to override the definition. By accessing .definition, we - # force the _DefinedFunction initialized internally. Then, we - # replace it's internal FunctionDef proto. We do this hack here - # because one typically can't construct KernelWithNoOutputs - # function via Defun decorator directly. - _ = KernelWithNoOutputs.definition - foo = KernelWithNoOutputs - foo._definition = fdef - call = KernelWithNoOutputs() + a = constant_op.constant(100) # pylint: disable=unused-variable + + call = KernelWithNoOutputs() # pylint: disable=assignment-from-no-return sess.run(call, {}) def testAliasing(self): diff --git a/tensorflow/compiler/tests/nary_ops_test.py b/tensorflow/compiler/tests/nary_ops_test.py index 2660e1d5728caf88e2b9ae73b3e3fde2aee71ed8..e4843b169b943b63346b783ddc50039030988ca5 100644 --- a/tensorflow/compiler/tests/nary_ops_test.py +++ b/tensorflow/compiler/tests/nary_ops_test.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import unittest + import numpy as np from tensorflow.compiler.tests.xla_test import XLATestCase @@ -29,7 +31,7 @@ from tensorflow.python.platform import googletest class NAryOpsTest(XLATestCase): - def _testNAry(self, op, args, expected): + def _testNAry(self, op, args, expected, equality_fn=None): with self.test_session() as session: with self.test_scope(): placeholders = [ @@ -39,7 +41,17 @@ class NAryOpsTest(XLATestCase): feeds = {placeholders[i]: args[i] for i in range(0, len(args))} output = op(placeholders) result = session.run(output, feeds) - self.assertAllClose(result, expected, rtol=1e-3) + if not equality_fn: + equality_fn = self.assertAllClose + equality_fn(result, expected, rtol=1e-3) + + def _nAryListCheck(self, results, expected, **kwargs): + self.assertEqual(len(results), len(expected)) + for (r, e) in zip(results, expected): + self.assertAllClose(r, e, **kwargs) + + def _testNAryLists(self, op, args, expected): + self._testNAry(op, args, expected, equality_fn=self._nAryListCheck) def testFloat(self): self._testNAry(math_ops.add_n, @@ -56,6 +68,44 @@ class NAryOpsTest(XLATestCase): np.array([42], dtype=np.float32)], expected=np.array([48], dtype=np.float32)) + def testComplex(self): + for dtype in self.complex_types: + self._testNAry( + math_ops.add_n, [np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype)], + expected=np.array([[1 + 2j, 2 - 3j, 3 + 4j]], dtype=dtype)) + + self._testNAry( + math_ops.add_n, [ + np.array([1 + 2j, 2 - 3j], dtype=dtype), + np.array([10j, 20], dtype=dtype) + ], + expected=np.array([1 + 12j, 22 - 3j], dtype=dtype)) + self._testNAry( + math_ops.add_n, [ + np.array([-4, 5j], dtype=dtype), + np.array([2 + 10j, -2], dtype=dtype), + np.array([42j, 3 + 3j], dtype=dtype) + ], + expected=np.array([-2 + 52j, 1 + 8j], dtype=dtype)) + + @unittest.skip("IdentityN is temporarily CompilationOnly as workaround") + def testIdentityN(self): + self._testNAryLists(array_ops.identity_n, + [np.array([[1, 2, 3]], dtype=np.float32)], + expected=[np.array([[1, 2, 3]], dtype=np.float32)]) + self._testNAryLists(array_ops.identity_n, + [np.array([[1, 2], [3, 4]], dtype=np.float32), + np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)], + expected=[ + np.array([[1, 2], [3, 4]], dtype=np.float32), + np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)]) + self._testNAryLists(array_ops.identity_n, + [np.array([[1], [2], [3], [4]], dtype=np.int32), + np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)], + expected=[ + np.array([[1], [2], [3], [4]], dtype=np.int32), + np.array([[3, 2, 1], [6, 5, 1]], dtype=np.float32)]) + def testConcat(self): self._testNAry( lambda x: array_ops.concat(x, 0), [ diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index a17a3f3d6536eea780106d84bcf4ce92c0fd017e..d6c93088d4efff7d8306e262a79ae49d3d8ac722 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -29,6 +29,9 @@ from tensorflow.python.platform import googletest class RandomOpsTest(XLATestCase): """Test cases for random-number generating operators.""" + def _random_types(self): + return set(self.numeric_types) - set(self.complex_types) + def _testRngIsNotConstant(self, rng, dtype): # Tests that 'rng' does not always return the same value. with self.test_session() as sess: @@ -51,7 +54,8 @@ class RandomOpsTest(XLATestCase): def rng(dtype): return random_ops.random_uniform(shape=[2], dtype=dtype, maxval=1000000) - for dtype in self.numeric_types: + + for dtype in self._random_types(): self._testRngIsNotConstant(rng, dtype) def testRandomNormalIsNotConstant(self): @@ -63,7 +67,7 @@ class RandomOpsTest(XLATestCase): self._testRngIsNotConstant(rng, dtype) def testRandomUniformIsInRange(self): - for dtype in self.numeric_types: + for dtype in self._random_types(): with self.test_session() as sess: with self.test_scope(): x = random_ops.random_uniform(shape=[1000], dtype=dtype, minval=-2, diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc index cb6f735a27a2e0fda04826ebfc51c63b342f128a..6a8c3bcd55a6e454a19b6249cf4eb48739c8657f 100644 --- a/tensorflow/compiler/tests/randomized_tests.cc +++ b/tensorflow/compiler/tests/randomized_tests.cc @@ -32,7 +32,6 @@ limitations under the License. // --tf_xla_test_repetitions=20 // TODO(phawkins): add tests for: -// * ArgMax // * DepthwiseConv2DNative // * Gather // * InvertPermutation @@ -76,7 +75,7 @@ namespace { // Command line flags: see main() below. int64 tf_xla_random_seed = 0; int32 tf_xla_test_repetitions = 20; -int64 tf_xla_max_tensor_size = 100000LL; +int64 tf_xla_max_tensor_size = 10000LL; string* tf_xla_test_device_ptr; // initial value set in main() bool tf_xla_test_use_jit = true; @@ -84,8 +83,8 @@ string LocalDeviceToFullDeviceName(const string& device) { return strings::StrCat("/job:localhost/replica:0/task:0/device:", device); } -constexpr std::array kAllXlaTypes = { - {DT_INT32, DT_FLOAT, DT_BOOL}}; +constexpr std::array kAllXlaTypes = { + {DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64}}; // An OpTestBuilder is a graph builder class that takes as input an operator to // test, its inputs and attributes, and builds a graph that executes the @@ -368,11 +367,11 @@ OpTest::OpTest() { void OpTest::Repeatedly(const std::function& fn) { int const max_repetitions = tf_xla_test_repetitions; int valid_test_runs = 0; - // We run up to 20 * max_repetitions times; the idea is that if we roll the + // We run up to 100 * max_repetitions times; the idea is that if we roll the // dice enough times we will find some valid parameters. We want to put an // upper limit on the number iterations just in case the probability of // finding feasible parameters is very low. - for (int i = 0; !HasFailure() && i < max_repetitions * 20 && + for (int i = 0; !HasFailure() && i < max_repetitions * 100 && valid_test_runs < max_repetitions; ++i) { TestResult result = fn(); @@ -450,6 +449,13 @@ Tensor OpTest::RandomTensor(DataType dtype, gtl::ArraySlice shape) { }); break; } + case DT_COMPLEX64: { + std::uniform_real_distribution distribution(-1.0f, 1.0f); + test::FillFn(&tensor, [this, &distribution](int i) { + return complex64(distribution(generator()), distribution(generator())); + }); + break; + } case DT_INT32: { std::uniform_int_distribution distribution(-(1 << 20), 1 << 20); test::FillFn(&tensor, [this, &distribution](int i) -> int32 { @@ -625,11 +631,47 @@ std::vector OpTest::AsInt32s(const std::vector& int64s) { // Functions for comparing tensors. +template +double Abs(T x) { + return std::fabs(x); +} + +template <> +double Abs(complex64 x) { + return std::abs(x); +} + template bool IsClose(const T& x, const T& y, double atol, double rtol) { if (std::isnan(x) && std::isnan(y)) return true; if (x == y) return true; // Allow inf == inf. - return fabs(x - y) < atol + rtol * fabs(x); + return Abs(x - y) < atol + rtol * Abs(x); +} + +template <> +bool IsClose(const complex64& x, const complex64& y, double atol, + double rtol) { + if (std::isnan(x.real()) && std::isnan(y.real())) { + if (std::isnan(x.imag()) && std::isnan(y.imag())) { + return true; + } + if (x.imag() == y.imag()) return true; // Allow inf == inf. + return Abs(x.imag() - y.imag()) < atol + rtol * Abs(x.imag()); + } else if (std::isnan(x.imag()) && std::isnan(y.imag())) { + if (x.real() == y.real()) return true; // Allow inf == inf. + return Abs(x.real() - y.real()) < atol + rtol * Abs(x.real()); + } + if (x == y) return true; // Allow inf == inf. + return Abs(x - y) < atol + rtol * Abs(x); +} + +template +string Str(T x) { + return strings::StrCat(x); +} +template <> +string Str(complex64 x) { + return strings::StrCat("(", x.real(), ", ", x.imag(), ")"); } template @@ -640,9 +682,10 @@ Status TensorsAreCloseImpl(const Tensor& x, const Tensor& y, double atol, for (int i = 0; i < Tx.size(); ++i) { if (!IsClose(Tx(i), Ty(i), atol, rtol)) { return errors::InvalidArgument(strings::StrCat( - i, "-th tensor element isn't close: ", Tx(i), " vs. ", Ty(i), - ". x = ", x.DebugString(), "y = ", y.DebugString(), "atol = ", atol, - " rtol = ", rtol, " tol = ", atol + rtol * std::fabs(Tx(i)))); + i, "-th tensor element isn't close: ", Str(Tx(i)), " vs. ", + Str(Ty(i)), ". x = ", x.DebugString(), "y = ", y.DebugString(), + "atol = ", atol, " rtol = ", rtol, + " tol = ", atol + rtol * Abs(Tx(i)))); } } return Status::OK(); @@ -684,6 +727,8 @@ Status TensorsAreClose(const Tensor& a, const Tensor& b, double atol, return TensorsAreCloseImpl(a, b, atol, rtol); case DT_DOUBLE: return TensorsAreCloseImpl(a, b, atol, rtol); + case DT_COMPLEX64: + return TensorsAreCloseImpl(a, b, atol, rtol); case DT_INT32: return TensorsAreEqualImpl(a, b); case DT_INT64: @@ -823,7 +868,7 @@ Tensor AsIntTensor(DataType dtype, const std::vector& values) { TEST_F(OpTest, Abs) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Abs").RandomInput(type).Attr("T", type)); }); @@ -838,7 +883,7 @@ TEST_F(OpTest, Acosh) { TEST_F(OpTest, Add) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Add") .RandomInput(type, dims.first) @@ -849,7 +894,7 @@ TEST_F(OpTest, Add) { TEST_F(OpTest, AddN) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); int n = std::uniform_int_distribution(1, 5)(generator()); auto shape = RandomDims(); @@ -876,6 +921,14 @@ TEST_F(OpTest, All) { }); } +TEST_F(OpTest, Angle) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Angle") + .RandomInput(DT_COMPLEX64) + .Attr("T", DT_COMPLEX64)); + }); +} + TEST_F(OpTest, Any) { Repeatedly([this]() { std::vector data_dims = RandomDims(); @@ -890,14 +943,47 @@ TEST_F(OpTest, Any) { TEST_F(OpTest, ApproximateEqual) { Repeatedly([this]() { - auto dims = RandomDims(); + auto dims = BroadcastableDims(); + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ApproximateEqual") - .RandomInput(DT_FLOAT, dims) - .RandomInput(DT_FLOAT, dims) + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) .Attr("T", DT_FLOAT)); }); } +TEST_F(OpTest, ArgMax) { + Repeatedly([this]() { + std::vector dims = RandomDims(1, 5, 1); + int num_dims = dims.size(); + int reduce_dim = + std::uniform_int_distribution(-num_dims, num_dims)(generator()); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ArgMax") + .RandomInput(DT_FLOAT, dims) + .Input(test::AsScalar(reduce_dim)) + .Attr("T", DT_FLOAT) + .Attr("Tidx", DT_INT32) + .Attr("output_type", DT_INT32)); + }); +} + +TEST_F(OpTest, ArgMin) { + Repeatedly([this]() { + std::vector dims = RandomDims(1, 5, 1); + int num_dims = dims.size(); + int reduce_dim = + std::uniform_int_distribution(-num_dims, num_dims)(generator()); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("ArgMin") + .RandomInput(DT_FLOAT, dims) + .Input(test::AsScalar(reduce_dim)) + .Attr("T", DT_FLOAT) + .Attr("Tidx", DT_INT32) + .Attr("output_type", DT_INT32)); + }); +} + TEST_F(OpTest, Asinh) { Repeatedly([this]() { return ExpectTfAndXlaOutputsAreClose( @@ -912,6 +998,16 @@ TEST_F(OpTest, Atanh) { }); } +TEST_F(OpTest, Atan2) { + Repeatedly([this]() { + auto dims = BroadcastableDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Atan2") + .RandomInput(DT_FLOAT, dims.first) + .RandomInput(DT_FLOAT, dims.second) + .Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, AvgPool) { Repeatedly([this]() { std::uniform_int_distribution random_int(1, 5); @@ -1007,6 +1103,7 @@ TEST_F(OpTest, AvgPool3DGrad) { TEST_F(OpTest, BatchMatMul) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); std::vector output_dims = RandomDims(2, 5, 0, 7); int64 ndims = output_dims.size(); int64 inner_dim = RandomDim(); @@ -1025,9 +1122,9 @@ TEST_F(OpTest, BatchMatMul) { } return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchMatMul") - .RandomInput(DT_FLOAT, x_dims) - .RandomInput(DT_FLOAT, y_dims) - .Attr("T", DT_FLOAT) + .RandomInput(type, x_dims) + .RandomInput(type, y_dims) + .Attr("T", type) .Attr("adj_x", adj_x) .Attr("adj_y", adj_y)); }); @@ -1059,10 +1156,11 @@ TEST_F(OpTest, BatchToSpace) { CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals), TensorShape({num_block_dims, 2}))); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BatchToSpace") - .RandomInput(DT_FLOAT, input_dims) + .RandomInput(type, input_dims) .Input(crops) - .Attr("T", DT_FLOAT) + .Attr("T", type) .Attr("block_size", block_size)); }); } @@ -1096,13 +1194,14 @@ TEST_F(OpTest, BatchToSpaceND) { CHECK(crops.CopyFrom(AsIntTensor(DT_INT32, crop_vals), TensorShape({num_block_dims, 2}))); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("BatchToSpaceND") - .RandomInput(DT_FLOAT, input_dims) + .RandomInput(type, input_dims) .Input(test::AsTensor( std::vector(block_dims.begin(), block_dims.end()))) .Input(crops) - .Attr("T", DT_FLOAT)); + .Attr("T", type)); }); } @@ -1111,18 +1210,20 @@ TEST_F(OpTest, BiasAdd) { auto x_dims = RandomDims(2, kDefaultMaxRank); auto y_dims = {x_dims[x_dims.size() - 1]}; // TODO(phawkins): test both data formats. + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAdd") - .RandomInput(DT_FLOAT, x_dims) - .RandomInput(DT_FLOAT, y_dims) - .Attr("T", DT_FLOAT)); + .RandomInput(type, x_dims) + .RandomInput(type, y_dims) + .Attr("T", type)); }); } TEST_F(OpTest, BiasAddGrad) { Repeatedly([this]() { // TODO(phawkins): test both data formats. + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("BiasAddGrad").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("BiasAddGrad").RandomInput(type).Attr("T", type)); }); } @@ -1130,17 +1231,54 @@ TEST_F(OpTest, BiasAddV1) { Repeatedly([this]() { auto x_dims = RandomDims(2, kDefaultMaxRank); auto y_dims = {x_dims[x_dims.size() - 1]}; + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BiasAddV1") - .RandomInput(DT_FLOAT, x_dims) - .RandomInput(DT_FLOAT, y_dims) - .Attr("T", DT_FLOAT)); + .RandomInput(type, x_dims) + .RandomInput(type, y_dims) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, BitwiseAnd) { + Repeatedly([this]() { + DataType type = DT_INT32; + auto dims = BroadcastableDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BitwiseAnd") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, BitwiseOr) { + Repeatedly([this]() { + DataType type = DT_INT32; + auto dims = BroadcastableDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("BitwiseOr") + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); + }); +} + +TEST_F(OpTest, BroadcastArgs) { + Repeatedly([this]() { + // TODO(phawkins): only int32 seems to be implemented in Tensorflow. + // auto type = Choose({DT_INT32, DT_INT64}); + DataType type = DT_INT32; + auto dims = BroadcastableDims(); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("BroadcastArgs") + .Input(AsIntTensor(type, dims.first)) + .Input(AsIntTensor(type, dims.second)) + .Attr("T", type)); }); } TEST_F(OpTest, BroadcastGradientArgs) { Repeatedly([this]() { // TODO(phawkins): only int32 seems to be implemented in Tensorflow. - // DataType type = Choose({DT_INT32, DT_INT64}); + // auto type = Choose({DT_INT32, DT_INT64}); DataType type = DT_INT32; auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose( @@ -1154,8 +1292,8 @@ TEST_F(OpTest, BroadcastGradientArgs) { TEST_F(OpTest, Cast) { Repeatedly([this]() { DataType src_type, dst_type; - src_type = Choose({DT_INT32, DT_FLOAT, DT_BOOL}); - dst_type = Choose({DT_INT32, DT_FLOAT, DT_BOOL}); + src_type = Choose({DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64}); + dst_type = Choose({DT_INT32, DT_FLOAT, DT_BOOL, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Cast") .RandomInput(src_type) .Attr("SrcT", src_type) @@ -1170,9 +1308,19 @@ TEST_F(OpTest, Ceil) { }); } +TEST_F(OpTest, Complex) { + Repeatedly([this]() { + auto dims = BroadcastableDims(); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Complex") + .RandomInput(DT_FLOAT, dims.first) + .RandomInput(DT_FLOAT, dims.second) + .Attr("T", DT_FLOAT)); + }); +} + TEST_F(OpTest, Concat) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); int n = std::uniform_int_distribution(2, 5)(generator()); std::vector dims = RandomDims(1); @@ -1212,6 +1360,14 @@ TEST_F(OpTest, ConcatOffset) { }); } +TEST_F(OpTest, Conj) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Conj") + .RandomInput(DT_COMPLEX64) + .Attr("T", DT_COMPLEX64)); + }); +} + TEST_F(OpTest, Conv2D) { Repeatedly([this]() { WindowedSpatialDims d = ChooseWindowedSpatialDims(2); @@ -1226,11 +1382,12 @@ TEST_F(OpTest, Conv2D) { std::vector kernel_dims = {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out}; + DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2D") - .RandomInput(DT_FLOAT, data_dims) - .RandomInput(DT_FLOAT, kernel_dims) - .Attr("T", DT_FLOAT) + .RandomInput(type, data_dims) + .RandomInput(type, kernel_dims) + .Attr("T", type) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") .Attr("data_format", "NHWC")); @@ -1250,12 +1407,13 @@ TEST_F(OpTest, Conv2DBackpropFilter) { ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); Tensor kernel_shape = test::AsTensor(AsInt32s( {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out})); + DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2DBackpropFilter") - .RandomInput(DT_FLOAT, activations) + .RandomInput(type, activations) .Input(kernel_shape) - .RandomInput(DT_FLOAT, backprop) - .Attr("T", DT_FLOAT) + .RandomInput(type, backprop) + .Attr("T", type) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") .Attr("data_format", "NHWC")); @@ -1275,12 +1433,13 @@ TEST_F(OpTest, Conv2DBackpropInput) { ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], features_in, features_out}; + DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv2DBackpropInput") .Input(in_shape) - .RandomInput(DT_FLOAT, kernel) - .RandomInput(DT_FLOAT, backprop) - .Attr("T", DT_FLOAT) + .RandomInput(type, kernel) + .RandomInput(type, backprop) + .Attr("T", type) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") .Attr("data_format", "NHWC")); @@ -1298,11 +1457,12 @@ TEST_F(OpTest, Conv3D) { std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2], features_in, features_out}; + DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv3D") - .RandomInput(DT_FLOAT, data) - .RandomInput(DT_FLOAT, kernel) - .Attr("T", DT_FLOAT) + .RandomInput(type, data) + .RandomInput(type, kernel) + .Attr("T", type) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID")); }); @@ -1322,12 +1482,13 @@ TEST_F(OpTest, Conv3DBackpropFilter) { Tensor kernel_shape = test::AsTensor( AsInt32s({d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2], features_in, features_out})); + DataType type = DT_FLOAT; // TODO(b/65408531): COMPLEX_64 support return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv3DBackpropFilterV2") - .RandomInput(DT_FLOAT, activations) + .RandomInput(type, activations) .Input(kernel_shape) - .RandomInput(DT_FLOAT, backprop) - .Attr("T", DT_FLOAT) + .RandomInput(type, backprop) + .Attr("T", type) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID")); }); @@ -1346,17 +1507,49 @@ TEST_F(OpTest, Conv3DBackpropInput) { ImageDims(FORMAT_NHWC, batch, features_out, d.output_dims); std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], d.kernel_dims[2], features_in, features_out}; + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Conv3DBackpropInputV2") .Input(in_shape) - .RandomInput(DT_FLOAT, kernel) - .RandomInput(DT_FLOAT, backprop) - .Attr("T", DT_FLOAT) + .RandomInput(type, kernel) + .RandomInput(type, backprop) + .Attr("T", type) .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) .Attr("padding", d.padding == SAME ? "SAME" : "VALID")); }); } +TEST_F(OpTest, Cos) { + Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Cos").RandomInput(type).Attr("T", type)); + }); +} + +TEST_F(OpTest, Cosh) { + Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Cosh").RandomInput(type).Attr("T", type)); + }); +} + +TEST_F(OpTest, DepthToSpace) { + Repeatedly([this]() { + int64 block = RandomDim(2, 5); + std::vector input_dims = RandomDims(4, 4); + input_dims[1] = (input_dims[1] + (block - 1)) / block; + input_dims[2] = (input_dims[2] + (block - 1)) / block; + input_dims[3] *= block * block; + auto type = Choose(kAllXlaTypes); + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("DepthToSpace") + .RandomInput(type, input_dims) + .Attr("T", type) + .Attr("block_size", block)); + }); +} + TEST_F(OpTest, DepthwiseConv2DNative) { Repeatedly([this]() { WindowedSpatialDims d = ChooseWindowedSpatialDims(2); @@ -1368,12 +1561,14 @@ TEST_F(OpTest, DepthwiseConv2DNative) { std::vector kernel_dims = {d.kernel_dims[0], d.kernel_dims[1], features_in, depth_multiplier}; + std::vector strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims); + strides[2] = strides[1]; // Current impl only supports equal strides return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("DepthwiseConv2dNative") .RandomInput(DT_FLOAT, input_dims) .RandomInput(DT_FLOAT, kernel_dims) .Attr("T", DT_FLOAT) - .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) + .Attr("strides", strides) .Attr("padding", d.padding == SAME ? "SAME" : "VALID")); }); } @@ -1391,32 +1586,20 @@ TEST_F(OpTest, DepthwiseConv2DBackpropFilter) { FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims); Tensor kernel_shape = test::AsTensor(AsInt32s( {d.kernel_dims[0], d.kernel_dims[1], features_in, depth_multiplier})); + std::vector strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims); + strides[2] = strides[1]; // Current impl only supports equal strides return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("DepthwiseConv2dNativeBackpropFilter") .RandomInput(DT_FLOAT, activations) .Input(kernel_shape) .RandomInput(DT_FLOAT, backprop) .Attr("T", DT_FLOAT) - .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) + .Attr("strides", strides) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") .Attr("data_format", "NHWC")); }); } -TEST_F(OpTest, Cos) { - Repeatedly([this]() { - return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Cos").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); - }); -} - -TEST_F(OpTest, Cosh) { - Repeatedly([this]() { - return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Cosh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); - }); -} - TEST_F(OpTest, DepthwiseConv2DBackpropInput) { Repeatedly([this]() { WindowedSpatialDims d = ChooseWindowedSpatialDims(2); @@ -1430,13 +1613,15 @@ TEST_F(OpTest, DepthwiseConv2DBackpropInput) { FORMAT_NHWC, batch, features_in * depth_multiplier, d.output_dims); std::vector kernel = {d.kernel_dims[0], d.kernel_dims[1], features_in, depth_multiplier}; + std::vector strides = ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims); + strides[2] = strides[1]; // Current impl only supports equal strides return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("DepthwiseConv2dNativeBackpropInput") .Input(in_shape) .RandomInput(DT_FLOAT, kernel) .RandomInput(DT_FLOAT, backprop) .Attr("T", DT_FLOAT) - .Attr("strides", ImageDims(FORMAT_NHWC, 1, 1, d.stride_dims)) + .Attr("strides", strides) .Attr("padding", d.padding == SAME ? "SAME" : "VALID") .Attr("data_format", "NHWC")); }); @@ -1444,7 +1629,7 @@ TEST_F(OpTest, DepthwiseConv2DBackpropInput) { TEST_F(OpTest, Diag) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose(kAllXlaTypes); std::vector dims; // Diag causes a quadratic blowup in output size. int64 size; @@ -1459,7 +1644,7 @@ TEST_F(OpTest, Diag) { TEST_F(OpTest, DiagPart) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose(kAllXlaTypes); auto dims = RandomDims(1, 3); // Duplicate the random dims. std::vector doubled_dims(dims.size() * 2); @@ -1473,7 +1658,7 @@ TEST_F(OpTest, DiagPart) { TEST_F(OpTest, Div) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Div") .RandomInput(type, dims.first) @@ -1484,7 +1669,7 @@ TEST_F(OpTest, Div) { TEST_F(OpTest, DynamicStitch) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); int n = std::uniform_int_distribution(2, 5)(generator()); OpTestBuilder builder("DynamicStitch"); builder.Attr("T", type); @@ -1569,7 +1754,7 @@ TEST_F(OpTest, SeluGrad) { TEST_F(OpTest, Equal) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Equal") .RandomInput(type, dims.first) @@ -1580,21 +1765,23 @@ TEST_F(OpTest, Equal) { TEST_F(OpTest, Exp) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Exp").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Exp").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, Expm1) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Expm1").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Expm1").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, ExpandDims) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector in_dims = RandomDims(); Tensor dim(DT_INT32, TensorShape()); std::uniform_int_distribution d(-1 - in_dims.size(), in_dims.size()); @@ -1608,7 +1795,7 @@ TEST_F(OpTest, ExpandDims) { TEST_F(OpTest, Fill) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector dims = RandomDims(); std::vector shape(dims.begin(), dims.end()); return ExpectTfAndXlaOutputsAreClose( @@ -1639,7 +1826,7 @@ TEST_F(OpTest, FloorDiv) { TEST_F(OpTest, FloorMod) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("FloorMod") .RandomInput(type, dims.first) @@ -1650,7 +1837,7 @@ TEST_F(OpTest, FloorMod) { TEST_F(OpTest, Greater) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Greater") .RandomInput(type, dims.first) @@ -1661,7 +1848,7 @@ TEST_F(OpTest, Greater) { TEST_F(OpTest, GreaterEqual) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("GreaterEqual") .RandomInput(type, dims.first) @@ -1670,6 +1857,22 @@ TEST_F(OpTest, GreaterEqual) { }); } +TEST_F(OpTest, Imag) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Imag") + .RandomInput(DT_COMPLEX64) + .Attr("T", DT_COMPLEX64)); + }); +} + +TEST_F(OpTest, Invert) { + Repeatedly([this]() { + DataType type = DT_INT32; + return ExpectTfAndXlaOutputsAreClose( + OpTestBuilder("Invert").RandomInput(type).Attr("T", type)); + }); +} + TEST_F(OpTest, L2Loss) { Repeatedly([this]() { DataType type = DT_FLOAT; @@ -1680,7 +1883,7 @@ TEST_F(OpTest, L2Loss) { TEST_F(OpTest, Less) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Less") .RandomInput(type, dims.first) @@ -1691,7 +1894,7 @@ TEST_F(OpTest, Less) { TEST_F(OpTest, LessEqual) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("LessEqual") .RandomInput(type, dims.first) @@ -1707,7 +1910,7 @@ TEST_F(OpTest, LinSpace) { return test::AsScalar(x); }; std::uniform_int_distribution distribution(-50, 50); - DataType type = Choose({DT_INT32, DT_INT64}); + auto type = Choose({DT_INT32, DT_INT64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("LinSpace") .RandomInput(DT_FLOAT, {}) @@ -1720,15 +1923,17 @@ TEST_F(OpTest, LinSpace) { TEST_F(OpTest, Log) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Log").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Log").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, Log1p) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Log1p").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Log1p").RandomInput(type).Attr("T", DT_FLOAT)); }); } @@ -1825,10 +2030,11 @@ TEST_F(OpTest, MatMul) { std::swap(b_dims[0], b_dims[1]); } + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatMul") - .RandomInput(DT_FLOAT, a_dims) - .RandomInput(DT_FLOAT, b_dims) - .Attr("T", DT_FLOAT) + .RandomInput(type, a_dims) + .RandomInput(type, b_dims) + .Attr("T", type) .Attr("transpose_a", transpose_a) .Attr("transpose_b", transpose_b)); }); @@ -1836,7 +2042,7 @@ TEST_F(OpTest, MatMul) { TEST_F(OpTest, MatrixDiag) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiag") .RandomInput(type, RandomDims(1)) .Attr("T", type)); @@ -1845,7 +2051,7 @@ TEST_F(OpTest, MatrixDiag) { TEST_F(OpTest, MatrixDiagPart) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("MatrixDiagPart") .RandomInput(type, RandomDims(2)) .Attr("T", type)); @@ -1854,7 +2060,7 @@ TEST_F(OpTest, MatrixDiagPart) { TEST_F(OpTest, Max) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); std::vector data_dims = RandomDims(); Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); @@ -1868,7 +2074,7 @@ TEST_F(OpTest, Max) { TEST_F(OpTest, Maximum) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Maximum") .RandomInput(type, dims.first) @@ -1936,7 +2142,7 @@ TEST_F(OpTest, MaxPool3D) { TEST_F(OpTest, Mean) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); // TODO(phawkins): CPU and XLA differ output for reducing across a // size-0 dimension (nan vs 0). For now, require size >= 1. std::vector data_dims = RandomDims(0, kDefaultMaxRank, 1); @@ -1952,7 +2158,7 @@ TEST_F(OpTest, Mean) { TEST_F(OpTest, Min) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); std::vector data_dims = RandomDims(); Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); @@ -1966,7 +2172,7 @@ TEST_F(OpTest, Min) { TEST_F(OpTest, Minimum) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Minimum") .RandomInput(type, dims.first) @@ -1987,7 +2193,7 @@ TEST_F(OpTest, Mod) { TEST_F(OpTest, Mul) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Mul") .RandomInput(type, dims.first) @@ -1998,7 +2204,7 @@ TEST_F(OpTest, Mul) { TEST_F(OpTest, Neg) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Neg").RandomInput(type).Attr("T", type)); }); @@ -2006,7 +2212,7 @@ TEST_F(OpTest, Neg) { TEST_F(OpTest, NotEqual) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("NotEqual") .RandomInput(type, dims.first) @@ -2017,7 +2223,7 @@ TEST_F(OpTest, NotEqual) { TEST_F(OpTest, OneHot) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector dims = RandomDims(); int num_dims = dims.size(); @@ -2047,7 +2253,7 @@ TEST_F(OpTest, OneHot) { TEST_F(OpTest, OnesLike) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("OnesLike").RandomInput(type).Attr("T", type)); }); @@ -2055,7 +2261,7 @@ TEST_F(OpTest, OnesLike) { TEST_F(OpTest, Pack) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); int n = std::uniform_int_distribution(1, 5)(generator()); std::vector dims = RandomDims(); @@ -2077,7 +2283,7 @@ TEST_F(OpTest, Pack) { // TODO(b/31741898): crashes on GPU. TEST_F(OpTest, Pad) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector t_dims = RandomDims(); // TODO(b/31741996): re-enable DT_INT64 when bug is fixed. @@ -2106,16 +2312,17 @@ TEST_F(OpTest, Pow) { // nontermination. Repeatedly([this]() { auto dims = BroadcastableDims(); + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Pow") - .RandomInput(DT_FLOAT, dims.first) - .RandomInput(DT_FLOAT, dims.second) - .Attr("T", DT_FLOAT)); + .RandomInput(type, dims.first) + .RandomInput(type, dims.second) + .Attr("T", type)); }); } TEST_F(OpTest, Prod) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); std::vector data_dims = RandomDims(); Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); @@ -2149,15 +2356,23 @@ TEST_F(OpTest, Range) { TEST_F(OpTest, Rank) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Rank").RandomInput(type).Attr("T", type)); }); } +TEST_F(OpTest, Real) { + Repeatedly([this]() { + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Real") + .RandomInput(DT_COMPLEX64) + .Attr("T", DT_COMPLEX64)); + }); +} + TEST_F(OpTest, RealDiv) { Repeatedly([this]() { - DataType type = DT_FLOAT; + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RealDiv") .RandomInput(type, dims.first) @@ -2168,18 +2383,20 @@ TEST_F(OpTest, RealDiv) { TEST_F(OpTest, Reciprocal) { Repeatedly([this]() { + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Reciprocal").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Reciprocal").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, ReciprocalGrad) { Repeatedly([this]() { std::vector dims = RandomDims(); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReciprocalGrad") - .RandomInput(DT_FLOAT, dims) - .RandomInput(DT_FLOAT, dims) - .Attr("T", DT_FLOAT)); + .RandomInput(type, dims) + .RandomInput(type, dims) + .Attr("T", type)); }); } TEST_F(OpTest, Relu) { @@ -2218,7 +2435,7 @@ TEST_F(OpTest, ReluGrad) { TEST_F(OpTest, Reshape) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector dims = RandomDims(); std::bernoulli_distribution random_bool; std::vector dims_before, dims_after; @@ -2246,24 +2463,24 @@ TEST_F(OpTest, Reshape) { TEST_F(OpTest, Reverse) { Repeatedly([this]() { std::vector dims = RandomDims(1); - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose(kAllXlaTypes); int64 rank = dims.size(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Reverse") .RandomInput(type, dims) .RandomInput(DT_BOOL, {rank}) - .Attr("T", DT_FLOAT)); + .Attr("T", type)); }); } TEST_F(OpTest, ReverseV2) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose(kAllXlaTypes); std::vector data_dims = RandomDims(); Tensor indices = RandomReductionIndices(data_dims.size()); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("ReverseV2") .RandomInput(type, data_dims) .Input(indices) - .Attr("T", DT_FLOAT)); + .Attr("T", type)); }); } @@ -2283,24 +2500,26 @@ TEST_F(OpTest, Round) { TEST_F(OpTest, Rsqrt) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Rsqrt").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Rsqrt").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, RsqrtGrad) { Repeatedly([this]() { auto dims = RandomDims(); + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("RsqrtGrad") - .RandomInput(DT_FLOAT, dims) - .RandomInput(DT_FLOAT, dims) - .Attr("T", DT_FLOAT)); + .RandomInput(type, dims) + .RandomInput(type, dims) + .Attr("T", type)); }); } TEST_F(OpTest, Shape) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Shape").RandomInput(type).Attr("T", type)); }); @@ -2308,7 +2527,7 @@ TEST_F(OpTest, Shape) { TEST_F(OpTest, ShapeN) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); int n = std::uniform_int_distribution(1, 5)(generator()); OpTestBuilder builder("ShapeN"); builder.Attr("T", type); @@ -2322,24 +2541,26 @@ TEST_F(OpTest, ShapeN) { TEST_F(OpTest, Sigmoid) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Sigmoid").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Sigmoid").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, SigmoidGrad) { Repeatedly([this]() { auto dims = RandomDims(); + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SigmoidGrad") - .RandomInput(DT_FLOAT, dims) - .RandomInput(DT_FLOAT, dims) - .Attr("T", DT_FLOAT)); + .RandomInput(type, dims) + .RandomInput(type, dims) + .Attr("T", type)); }); } TEST_F(OpTest, Sign) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Sign").RandomInput(type).Attr("T", type)); }); @@ -2347,21 +2568,23 @@ TEST_F(OpTest, Sign) { TEST_F(OpTest, Sin) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Sin").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Sin").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, Sinh) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Sinh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Sinh").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, Size) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose(kAllXlaTypes); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Size").RandomInput(type).Attr("T", type)); }); @@ -2369,7 +2592,7 @@ TEST_F(OpTest, Size) { TEST_F(OpTest, Slice) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector data_dims = RandomDims(); std::vector begin(data_dims.size()), size(data_dims.size()); @@ -2473,10 +2696,11 @@ TEST_F(OpTest, SpaceToBatch) { CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals), TensorShape({num_block_dims, 2}))); + auto type = Choose(kAllXlaTypes); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToBatch") - .RandomInput(DT_FLOAT, input_dims) + .RandomInput(type, input_dims) .Input(paddings) - .Attr("T", DT_FLOAT) + .Attr("T", type) .Attr("block_size", block_size)); }); } @@ -2514,13 +2738,28 @@ TEST_F(OpTest, SpaceToBatchND) { CHECK(paddings.CopyFrom(AsIntTensor(DT_INT32, padding_vals), TensorShape({num_block_dims, 2}))); + auto type = Choose(kAllXlaTypes); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("SpaceToBatchND") - .RandomInput(DT_FLOAT, input_dims) + .RandomInput(type, input_dims) .Input(test::AsTensor( std::vector(block_dims.begin(), block_dims.end()))) .Input(paddings) - .Attr("T", DT_FLOAT)); + .Attr("T", type)); + }); +} + +TEST_F(OpTest, SpaceToDepth) { + Repeatedly([this]() { + int64 block = RandomDim(2, 5); + std::vector input_dims = RandomDims(4, 4); + // Round spatial dimensions up to a multiple of the block size + input_dims[1] = (input_dims[1] + (block - 1)) / block * block; + input_dims[2] = (input_dims[2] + (block - 1)) / block * block; + return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SpaceToDepth") + .RandomInput(DT_FLOAT, input_dims) + .Attr("T", DT_FLOAT) + .Attr("block_size", block)); }); } @@ -2576,11 +2815,12 @@ TEST_F(OpTest, SparseSoftmaxCrossEntropyWithLogits) { TEST_F(OpTest, Split) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector dims = RandomDims(1); std::uniform_int_distribution ud; int32 dim = std::uniform_int_distribution( - 0, static_cast(dims.size()) - 1)(generator()); + -static_cast(dims.size()), + static_cast(dims.size()) - 1)(generator()); int n = std::uniform_int_distribution(1, 5)(generator()); // Ensure 'dim' is evenly divisible by 'n'. dims[dim] /= n; @@ -2595,18 +2835,20 @@ TEST_F(OpTest, Split) { TEST_F(OpTest, Sqrt) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Sqrt").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Sqrt").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, SqrtGrad) { Repeatedly([this]() { auto dims = RandomDims(); + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("SqrtGrad") - .RandomInput(DT_FLOAT, dims) - .RandomInput(DT_FLOAT, dims) - .Attr("T", DT_FLOAT)); + .RandomInput(type, dims) + .RandomInput(type, dims) + .Attr("T", type)); }); } @@ -2622,7 +2864,7 @@ TEST_F(OpTest, SquaredDifference) { TEST_F(OpTest, Square) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("Square").RandomInput(type).Attr("T", type)); }); @@ -2630,7 +2872,7 @@ TEST_F(OpTest, Square) { TEST_F(OpTest, Squeeze) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector t_dims = RandomDims(0, kDefaultMaxRank, 0, 5); std::bernoulli_distribution random_bool; std::vector squeeze_dims; @@ -2648,7 +2890,7 @@ TEST_F(OpTest, Squeeze) { TEST_F(OpTest, Sub) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("Sub") .RandomInput(type, dims.first) @@ -2659,7 +2901,7 @@ TEST_F(OpTest, Sub) { TEST_F(OpTest, Sum) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); std::vector data_dims = RandomDims(); Tensor indices = RandomReductionIndices(data_dims.size()); bool keep_dims = Choose({false, true}); @@ -2673,7 +2915,7 @@ TEST_F(OpTest, Sum) { TEST_F(OpTest, StridedSlice) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector data_dims = RandomDims(); std::vector begin(data_dims.size()), end(data_dims.size()); std::vector strides(data_dims.size()); @@ -2718,7 +2960,7 @@ TEST_F(OpTest, StridedSlice) { TEST_F(OpTest, StridedSliceGrad) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); // Dimensions of the forward input. std::vector dims = RandomDims(); @@ -2771,31 +3013,34 @@ TEST_F(OpTest, StridedSliceGrad) { TEST_F(OpTest, Tan) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Tan").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Tan").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, Tanh) { Repeatedly([this]() { + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( - OpTestBuilder("Tanh").RandomInput(DT_FLOAT).Attr("T", DT_FLOAT)); + OpTestBuilder("Tanh").RandomInput(type).Attr("T", type)); }); } TEST_F(OpTest, TanhGrad) { Repeatedly([this]() { auto dims = RandomDims(); + auto type = Choose({DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TanhGrad") - .RandomInput(DT_FLOAT, dims) - .RandomInput(DT_FLOAT, dims) - .Attr("T", DT_FLOAT)); + .RandomInput(type, dims) + .RandomInput(type, dims) + .Attr("T", type)); }); } TEST_F(OpTest, Tile) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector t_dims = RandomDims(1); std::vector multiples(t_dims.size()); for (int i = 0; i < t_dims.size(); ++i) { @@ -2811,7 +3056,7 @@ TEST_F(OpTest, Tile) { TEST_F(OpTest, Transpose) { Repeatedly([this]() { - DataType type = Choose(kAllXlaTypes); + auto type = Choose(kAllXlaTypes); std::vector data_dims = RandomDims(); std::vector perm(data_dims.size()); std::iota(perm.begin(), perm.end(), 0); @@ -2836,7 +3081,7 @@ TEST_F(OpTest, TruncateDiv) { TEST_F(OpTest, TruncateMod) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT}); auto dims = BroadcastableDims(); return ExpectTfAndXlaOutputsAreClose(OpTestBuilder("TruncateMod") .RandomInput(type, dims.first) @@ -2847,7 +3092,7 @@ TEST_F(OpTest, TruncateMod) { TEST_F(OpTest, ZerosLike) { Repeatedly([this]() { - DataType type = Choose({DT_INT32, DT_FLOAT}); + auto type = Choose({DT_INT32, DT_FLOAT, DT_COMPLEX64}); return ExpectTfAndXlaOutputsAreClose( OpTestBuilder("ZerosLike").RandomInput(type).Attr("T", type)); }); diff --git a/tensorflow/compiler/tests/slice_ops_test.py b/tensorflow/compiler/tests/slice_ops_test.py index 4ddf2ee0dcb2b5f514ff9820c07f7cc10609ff66..a7cbfb04003c397212a35e16c6b23d7c2a18f7df 100644 --- a/tensorflow/compiler/tests/slice_ops_test.py +++ b/tensorflow/compiler/tests/slice_ops_test.py @@ -18,15 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import googletest - class SliceTest(XLATestCase): def test1D(self): @@ -63,6 +60,53 @@ class SliceTest(XLATestCase): self.assertAllEqual([[[6, 5, 4, 3]]], result) + def test3DWithDynamicBegin(self): + """Tests a slice where the start offset is not known at compile time.""" + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[3, 3, 10]) + begin = array_ops.placeholder(dtypes.int32, shape=[3]) + with self.test_scope(): + o = array_ops.slice(i, begin, [1, 1, 4]) + params = { + i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + [5, 3, 1, 7, 9, 2, 4, 6, 8, 0]], + [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [8, 7, 6, 5, 4, 3, 2, 1, 8, 7]], + [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5], + [1, 2, 1, 2, 1, 2, 1, 2, 1, 2], + [9, 8, 7, 9, 8, 7, 9, 8, 7, 9]]], + begin: [1, 2, 2] + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([[[6, 5, 4, 3]]], result) + + def test3DWithDynamicBeginAndNegativeSize(self): + """Tests a slice where `begin` is fed dynamically and `size` contains -1.""" + for dtype in self.numeric_types: + with self.test_session(): + i = array_ops.placeholder(dtype, shape=[3, 3, 10]) + begin = array_ops.placeholder(dtypes.int32, shape=[3]) + with self.test_scope(): + o = array_ops.slice(i, begin, [1, -1, 4]) + params = { + i: [[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [9, 8, 7, 6, 5, 4, 3, 2, 1, 0], + [5, 3, 1, 7, 9, 2, 4, 6, 8, 0]], + [[5, 5, 5, 5, 5, 5, 5, 5, 5, 5], + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [8, 7, 6, 5, 4, 3, 2, 1, 8, 7]], + [[7, 5, 7, 5, 7, 5, 7, 5, 7, 5], + [1, 2, 1, 2, 1, 2, 1, 2, 1, 2], + [9, 8, 7, 9, 8, 7, 9, 8, 7, 9]]], + begin: [1, 1, 2] + } + result = o.eval(feed_dict=params) + + self.assertAllEqual([[[1, 1, 1, 1], [6, 5, 4, 3]]], result) class StridedSliceTest(XLATestCase): @@ -80,7 +124,7 @@ class StridedSliceTest(XLATestCase): self.assertAllEqual([2, 4], result) - def test1DNegtiveStride(self): + def test1DNegativeStride(self): for dtype in self.numeric_types: with self.test_session(): i = array_ops.placeholder(dtype, shape=[10]) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index ce319d6e69b20b5f1a86ae33c24563394debeea1..76644380bdf2e0c24f6d363ddfaabdff836495d7 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -26,6 +26,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests.xla_test import XLATestCase from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops +from tensorflow.python.ops import bitwise_ops from tensorflow.python.ops import gen_nn_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops @@ -309,11 +310,6 @@ class UnaryOpsTest(XLATestCase): [0.032058604, 0.087144323, 0.23688284, 0.64391428]], dtype=dtype)) - self._assertOpOutputMatchesExpected( - nn_ops.softplus, - np.array([[-2, 0, 8]], dtype=dtype), - expected=np.array([[0.126928, 0.6931472, 8.0003354]], dtype=dtype)) - self._assertOpOutputMatchesExpected( nn_ops.softsign, np.array([[-2, -1, 0, 1, 2]], dtype=dtype), @@ -332,6 +328,138 @@ class UnaryOpsTest(XLATestCase): np.array([-1, -0.5, 0, 0.3], dtype=dtype), expected=np.array([-1, -64.0 / 127, 0, 38.0 / 127], dtype=dtype)) + def testComplexOps(self): + for dtype in self.complex_types: + # TODO(b/65408531): math_ops.acosh (needs pow) + # TODO(b/65408531): math_ops.asinh (needs pow) + + # TODO(b/65408531): Wider support for log (needs atan2). + atan2_supported = self.device == "XLA_GPU" + if atan2_supported: + self._assertOpOutputMatchesExpected( + math_ops.atanh, + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype), + expected=np.arctanh( + np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.cosh, + np.array([1j, 2 - 3j, 3, 4 + 2j], dtype=dtype), + expected=np.cosh(np.array([1j, 2 - 3j, 3, 4 + 2j], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.sinh, + np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), + expected=np.sinh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.exp, + np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype), + expected=np.exp(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.expm1, + np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype), + expected=np.expm1(np.array([[-1 + 2j, 3j, 2 - 3j]], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.reciprocal, + np.array([[1, 2j, 2 + 3j]], dtype=dtype), + expected=1.0 / np.array([[1, 2j, 2 + 3j]], dtype=dtype)) + + if atan2_supported: + self._assertOpOutputMatchesExpected( + math_ops.log, + np.array([[5j, 3 - 2j]], dtype=dtype), + expected=np.log(np.array([[5j, 3 - 2j]], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.sin, + np.array([[5j, 3 - 2j]], dtype=dtype), + expected=np.sin(np.array([[5j, 3 - 2j]], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.cos, + np.array([[5j, 3 - 2j]], dtype=dtype), + expected=np.cos(np.array([[5j, 3 - 2j]], dtype=dtype))) + + # TODO(b/34703906): improve log1p implementation and make tolerance + # tighter. + if atan2_supported: # TODO(b/34703906): log support + self._assertOpOutputMatchesExpected( + math_ops.log1p, + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype), + expected=np.log1p( + np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype))) + + # TODO(b/34703906): math_ops.rsqrt (needs pow) + + # TODO(b/34703906): math_ops.sigmoid (needs tanh) + + # TODO(b/34703906): math_ops.sqrt (needs pow) + + self._assertOpOutputMatchesExpected( + math_ops.tan, + np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype), + expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype))) + + # TODO(b/34703906): math_ops.tanh (as itself) + + ctypes = {np.complex64: np.float32} + self._assertOpOutputMatchesExpected( + math_ops.abs, + np.array([[3 - 4j, -1j, np.inf]], dtype=dtype), + expected=np.array([[5, 1, np.inf]], dtype=ctypes[dtype])) + + self._assertOpOutputMatchesExpected( + math_ops.negative, + np.array([[-1 + 2j, -3j]], dtype=dtype), + expected=np.array([[1 - 2j, 3j]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + math_ops.square, + np.array([[-2 - 3j, 3 + 4j, 5j]], dtype=dtype), + expected=np.array([[-2 - 3j, 3 + 4j, 5j]], dtype=dtype)**2) + + self._assertOpOutputMatchesExpected( + array_ops.zeros_like, + np.array([[4j, 3 - 2j], [2, -1j]], dtype=dtype), + expected=np.array([[0, 0], [0, 0]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + array_ops.ones_like, + np.array([[-4j, 3 + 2j], [2, -1j]], dtype=dtype), + expected=np.array([[1, 1], [1, 1]], dtype=dtype)) + + if atan2_supported: # TODO(b/34703906): atan2 support + self._assertOpOutputMatchesExpected( + math_ops.angle, + np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), + expected=np.angle( + np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype))) + + self._assertOpOutputMatchesExpected( + math_ops.conj, + np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), + expected=np.array([1 - 3j, -4 - 7j, 2.7, 3j], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + math_ops.imag, + np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), + expected=np.array([3, 7, 0, -3], dtype=ctypes[dtype])) + + self._assertOpOutputMatchesExpected( + math_ops.real, + np.array([1 + 3j, -4 + 7j, 2.7, -3j], dtype=dtype), + expected=np.array([1, -4, 2.7, 0], dtype=ctypes[dtype])) + + def testIntOps(self): + for dtype in self.int_types: + self._assertOpOutputMatchesExpected( + bitwise_ops.invert, + np.array([0, -1, 1, 16, 42], dtype=dtype), + expected=np.array([-1, 0, -2, -17, -43], dtype=dtype)) + def testNumericOps(self): for dtype in self.numeric_types: self._assertOpOutputMatchesExpected( @@ -396,11 +524,14 @@ class UnaryOpsTest(XLATestCase): def testCast(self): shapes = [[], [4], [2, 3], [2, 0, 4]] - types = [dtypes.bool, dtypes.int32, dtypes.float32] + types = [dtypes.bool, dtypes.int32, dtypes.float32] + self.complex_tf_types for shape in shapes: for src_type in types: for dst_type in types: src = np.arange(np.prod(shape)).astype(src_type.as_numpy_dtype) + if src_type in self.complex_tf_types: + src += (np.arange(np.prod(shape)) * 2j).astype( + src_type.as_numpy_dtype) src = src.reshape(shape) dst = src.astype(dst_type.as_numpy_dtype) @@ -492,6 +623,77 @@ class UnaryOpsTest(XLATestCase): ], equality_test=self.ListsAreClose) + def testDepthToSpace(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + lambda x: array_ops.depth_to_space(x, block_size=2), + np.array([[[[1, 2, 3, 4]]]], dtype=dtype), + expected=np.array([[[[1], [2]], + [[3], [4]]]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + lambda x: array_ops.depth_to_space(x, block_size=2), + np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype), + expected=np.array([[[[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]]]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + lambda x: array_ops.depth_to_space(x, block_size=2), + np.array([[[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[9, 10, 11, 12], + [13, 14, 15, 16]]]], dtype=dtype), + expected=np.array([[[[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[9], [10], [13], [14]], + [[11], [12], [15], [16]]]], dtype=dtype)) + + def testSpaceToDepth(self): + for dtype in self.numeric_types: + self._assertOpOutputMatchesExpected( + lambda x: array_ops.space_to_depth(x, block_size=2), + np.array([[[[1], [2]], + [[3], [4]]]], dtype=dtype), + expected=np.array([[[[1, 2, 3, 4]]]], dtype=dtype)) + + self._assertOpOutputMatchesExpected( + lambda x: array_ops.space_to_depth(x, block_size=2), + np.array([[[[1, 2, 3], [4, 5, 6]], + [[7, 8, 9], [10, 11, 12]]]], dtype=dtype), + expected=np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], + dtype=dtype)) + + self._assertOpOutputMatchesExpected( + lambda x: array_ops.space_to_depth(x, block_size=2), + np.array([[[[1], [2], [5], [6]], + [[3], [4], [7], [8]], + [[9], [10], [13], [14]], + [[11], [12], [15], [16]]]], dtype=dtype), + expected=np.array([[[[1, 2, 3, 4], + [5, 6, 7, 8]], + [[9, 10, 11, 12], + [13, 14, 15, 16]]]], dtype=dtype)) + + def _assertSoftplusMatchesExpected(self, features, dtype): + features = np.array(features, dtype=dtype) + zero = np.asarray(0).astype(dtype) + expected = np.logaddexp(zero, features) + self._assertOpOutputMatchesExpected( + nn_ops.softplus, features, expected=expected) + + def testSoftplus(self): + for dtype in self.float_types: + self._assertSoftplusMatchesExpected([[-2, 0, 8]], dtype) + self._assertSoftplusMatchesExpected( + [[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]], dtype) + log_eps = np.log(np.finfo(dtype).eps) + one = dtype(1) + ten = dtype(10) + self._assertSoftplusMatchesExpected([ + log_eps, log_eps - one, log_eps + one, log_eps - ten, + log_eps + ten, -log_eps, -log_eps - one, -log_eps + one, + -log_eps - ten, -log_eps + ten], dtype) + if __name__ == "__main__": googletest.main() diff --git a/tensorflow/compiler/tests/variable_ops_test.py b/tensorflow/compiler/tests/variable_ops_test.py index fdf3f9fb6ada762751f8639af29bec0b0d9a8b01..c50342dee45eba6ae54f01653ecc81ef096b547b 100644 --- a/tensorflow/compiler/tests/variable_ops_test.py +++ b/tensorflow/compiler/tests/variable_ops_test.py @@ -43,7 +43,7 @@ class VariableOpsTest(XLATestCase): # Regression test for a bug where computations with one non-constant # output and one variable update were mishandled. for dtype in self.numeric_types: - init = np.array([[1, 2], [3, 4]], dtype=dtype) + init = np.array([[1, 2j], [3, 4]]).astype(dtype) with self.test_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) sess.run(variables.variables_initializer([v])) @@ -51,82 +51,91 @@ class VariableOpsTest(XLATestCase): x = v.assign_add(p) with ops.control_dependencies([x]): y = v.read_value() - self.assertAllClose(np.array([[2, 3], [4, 5]], dtype=dtype), - sess.run(y, {p: 1})) + self.assertAllClose( + np.array([[2, 1 + 2j], [4, 5]]).astype(dtype), sess.run(y, { + p: 1 + })) def testSparseRead0DIndices(self): for dtype in self.numeric_types: - init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype) + init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8j, 9, 10, + 11]]).astype(dtype) with self.test_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) sess.run(variables.variables_initializer([v])) x = v.sparse_read(2) - self.assertAllClose(np.array([8, 9, 10, 11], dtype=dtype), sess.run(x)) + self.assertAllClose( + np.array([8j, 9, 10, 11]).astype(dtype), sess.run(x)) def testSparseRead1DIndices(self): for dtype in self.numeric_types: - init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype) + init = np.array([[0, 1, 2, 3], [4, 5, 6j, 7], [8, 9, 10, + 11]]).astype(dtype) with self.test_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) sess.run(variables.variables_initializer([v])) x = v.sparse_read([2, 1]) self.assertAllClose( - np.array([[8, 9, 10, 11], [4, 5, 6, 7]], dtype=dtype), sess.run(x)) + np.array([[8, 9, 10, 11], [4, 5, 6j, 7]]).astype(dtype), + sess.run(x)) def testSparseRead2DIndices(self): for dtype in self.numeric_types: - init = np.array([[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11]], dtype=dtype) + init = np.array([[0, 1, 2j, 3], [4, 5, 6, 7], [8, 9, 10, + 11]]).astype(dtype) with self.test_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) sess.run(variables.variables_initializer([v])) x = v.sparse_read([[2, 1], [0, 2]]) self.assertAllClose( - np.array( - [[[8, 9, 10, 11], [4, 5, 6, 7]], [[0, 1, 2, 3], [8, 9, 10, - 11]]], - dtype=dtype), sess.run(x)) + np.array([[[8, 9, 10, 11], [4, 5, 6, 7]], + [[0, 1, 2j, 3], [8, 9, 10, 11]]]).astype(dtype), + sess.run(x)) def testSparseRead2DIndices3DTensor(self): for dtype in self.numeric_types: - init = np.array( - [[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]], - [[20, 21, 22], [23, 24, 25]], [[30, 31, 32], [33, 34, 35]]], - dtype=dtype) + init = np.array([[[0, 1, 2], [3, 4, 5]], [[10, 11, 12], [13, 14, 15]], + [[20, 21, 22], [23, 24j, 25]], + [[30, 31, 32], [33, 34, 35]]]).astype(dtype) with self.test_session() as sess, self.test_scope(): v = resource_variable_ops.ResourceVariable(init) sess.run(variables.variables_initializer([v])) x = v.sparse_read([[2, 1], [3, 0]]) self.assertAllClose( np.array( - [[[[20, 21, 22], [23, 24, 25]], [[10, 11, 12], [13, 14, 15]]], + [[[[20, 21, 22], [23, 24j, 25]], [[10, 11, 12], [13, 14, 15]]], [[[30, 31, 32], [33, 34, 35]], [[0, 1, 2], [3, 4, 5]]]], - dtype=dtype), sess.run(x)) + ).astype(dtype), sess.run(x)) def testReadWrite(self): """Tests initialization, reading, and writing a resource variable.""" - with self.test_session() as session: - with self.test_scope(): - with variable_scope.variable_scope("ascope", use_resource=True): - x = variable_scope.get_variable( - "x", - shape=[], - dtype=dtypes.float32, - initializer=init_ops.constant_initializer(2)) - a = x.read_value() - with ops.control_dependencies([a]): - b = state_ops.assign(x, 47) - with ops.control_dependencies([b]): - c = x.read_value() - with ops.control_dependencies([c]): - d = state_ops.assign_add(x, 3) - with ops.control_dependencies([d]): - e = x.read_value() - - session.run(variables.global_variables_initializer()) - v1, v2, v3 = session.run([a, c, e]) - self.assertAllClose(2.0, v1) - self.assertAllClose(47.0, v2) - self.assertAllClose(50.0, v3) + for dtype in self.numeric_types: + with self.test_session() as session: + print(ops.get_default_graph()) + with self.test_scope(): + with variable_scope.variable_scope("ascope", use_resource=True): + x = variable_scope.get_variable( + "x", + shape=[], + dtype=dtype, + initializer=init_ops.constant_initializer(2)) + a = x.read_value() + with ops.control_dependencies([a]): + b = state_ops.assign(x, dtype(47)) + with ops.control_dependencies([b]): + c = x.read_value() + with ops.control_dependencies([c]): + d = state_ops.assign_add(x, np.array(6 + 2j).astype(dtype)) + with ops.control_dependencies([d]): + e = state_ops.assign_sub(x, dtype(3)) + with ops.control_dependencies([e]): + f = x.read_value() + + session.run(variables.global_variables_initializer()) + v1, v2, v3 = session.run([a, c, f]) + self.assertAllClose(dtype(2), v1) + self.assertAllClose(dtype(47), v2) + self.assertAllClose(np.array(50 + 2j).astype(dtype), v3) def testTraining(self): """Tests a gradient descent step for a simple model.""" diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index da6dc88f1fb07200799f8ee231fc04628b265e24..0be127997e5211f810ca791187486760881fe172 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -63,12 +63,19 @@ class XLATestCase(test.TestCase): self.float_tf_types = [ dtype for dtype in self.all_tf_types if dtype.is_floating ] - self.numeric_tf_types = self.int_tf_types + self.float_tf_types + self.complex_tf_types = [ + dtype for dtype in self.all_tf_types if dtype.is_complex + ] + self.numeric_tf_types = ( + self.int_tf_types + self.float_tf_types + self.complex_tf_types) self.all_types = [dtype.as_numpy_dtype for dtype in self.all_tf_types] self.int_types = [dtype.as_numpy_dtype for dtype in self.int_tf_types] self.float_types = [dtype.as_numpy_dtype for dtype in self.float_tf_types] - self.numeric_types = self.int_types + self.float_types + self.complex_types = [ + dtype.as_numpy_dtype for dtype in self.complex_tf_types + ] + self.numeric_types = self.int_types + self.float_types + self.complex_types # Parse the manifest file, if any, into a regex identifying tests to # disable diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 0769b13718821f7528b7e2620e50787e55aa20f6..912e819d8d63886c663aaabd3cbe3bd76a1ced07 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -58,6 +58,42 @@ cc_library( ], ) +cc_library( + name = "xla_compiled_cpu_function", + srcs = ["xla_compiled_cpu_function.cc"], + hdrs = ["xla_compiled_cpu_function.h"], + visibility = ["//visibility:public"], + deps = [ + # Keep dependencies to a minimum here; this library is used in every AOT + # binary produced by tfcompile. + "//tensorflow/compiler/aot:runtime", + "//tensorflow/compiler/tf2xla:xla_local_runtime_context", + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:framework_lite", + ], +) + +cc_library( + name = "xla_jit_compiled_cpu_function", + srcs = ["xla_jit_compiled_cpu_function.cc"], + hdrs = ["xla_jit_compiled_cpu_function.h"], + visibility = ["//visibility:public"], + deps = [ + ":tf2xla", + ":tf2xla_proto", + ":xla_compiled_cpu_function", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:client_library", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:cpu_plugin", + "//tensorflow/compiler/xla/service/cpu:cpu_executable", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + cc_library( name = "xla_compiler", srcs = [ @@ -67,11 +103,13 @@ cc_library( "xla_helpers.cc", "xla_op_kernel.cc", "xla_op_registry.cc", + "graph_compiler.cc", "xla_cpu_backend.cc", ] + if_cuda_is_configured([ "xla_gpu_backend.cc", ]), hdrs = [ + "graph_compiler.h", "xla_compilation_device.h", "xla_compiler.h", "xla_context.h", @@ -82,8 +120,11 @@ cc_library( visibility = [":friends"], deps = [ ":common", + ":const_analysis", ":dump_graph", ":functionalize_control_flow", + ":sharding_util", + ":tf2xla_util", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", @@ -130,6 +171,36 @@ cc_library( ], ) +cc_library( + name = "sharding_util", + srcs = ["sharding_util.cc"], + hdrs = ["sharding_util.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla/client:computation_builder", + "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:framework", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + ], +) + +tf_cc_test( + name = "sharding_util_test", + srcs = ["sharding_util_test.cc"], + deps = [ + ":sharding_util", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + # Internal targets below this point. cc_library( @@ -137,7 +208,9 @@ cc_library( srcs = ["tf2xla_util.cc"], hdrs = ["tf2xla_util.h"], deps = [ + ":sharding_util", ":tf2xla_proto", + "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", @@ -151,8 +224,14 @@ tf_cc_test( name = "tf2xla_util_test", srcs = ["tf2xla_util_test.cc"], deps = [ + ":sharding_util", ":tf2xla_util", + "//tensorflow/cc:cc_ops", + "//tensorflow/cc:function_ops", + "//tensorflow/cc:ops", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", + "//tensorflow/core:math_ops_op_lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -178,6 +257,25 @@ tf_cc_test( ], ) +tf_cc_test( + name = "xla_jit_compiled_cpu_function_test", + srcs = ["xla_jit_compiled_cpu_function_test.cc"], + deps = [ + ":tf2xla_proto", + ":xla_jit_compiled_cpu_function", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "xla_compiler_test", srcs = ["xla_compiler_test.cc"], @@ -198,6 +296,7 @@ tf_cc_test( "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", + "//tensorflow/core:testlib", ], ) @@ -291,7 +390,9 @@ cc_library( srcs = ["functionalize_control_flow.cc"], hdrs = ["functionalize_control_flow.h"], deps = [ + ":tf2xla_util", "//tensorflow/compiler/jit:graph_to_functiondef", + "//tensorflow/compiler/jit:union_find", "//tensorflow/compiler/tf2xla:dump_graph", "//tensorflow/compiler/tf2xla/ops:functional_ops", "//tensorflow/compiler/xla:status_macros", @@ -299,6 +400,7 @@ cc_library( "//tensorflow/core:core_cpu", "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", + "//tensorflow/core:lib", ], ) @@ -316,6 +418,7 @@ tf_cc_test( "//tensorflow/compiler/tf2xla/cc:functional_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/core:core_cpu", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", "//tensorflow/core:ops", diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc index e4e1689a2de5780525a1e20c6a22911633845fdf..d57273d84442c17565a6ace1c29170a0f3ba583b 100644 --- a/tensorflow/compiler/tf2xla/const_analysis.cc +++ b/tensorflow/compiler/tf2xla/const_analysis.cc @@ -33,12 +33,15 @@ Status BackwardsConstAnalysis(const Graph& g, const std::unordered_multimap compile_time_const_inputs = { {"All", "reduction_indices"}, {"Any", "reduction_indices"}, + {"ArgMin", "dimension"}, {"ArgMax", "dimension"}, {"AvgPoolGrad", "orig_input_shape"}, {"AvgPool3DGrad", "orig_input_shape"}, {"BatchToSpace", "crops"}, {"BatchToSpaceND", "block_shape"}, {"BatchToSpaceND", "crops"}, + {"BroadcastArgs", "s0"}, + {"BroadcastArgs", "s1"}, {"BroadcastGradientArgs", "s0"}, {"BroadcastGradientArgs", "s1"}, {"Concat", "concat_dim"}, @@ -54,6 +57,7 @@ Status BackwardsConstAnalysis(const Graph& g, {"DynamicStitch", "indices"}, {"ExpandDims", "dim"}, {"Fill", "dims"}, + {"GatherV2", "axis"}, {"InvertPermutation", "x"}, {"LinSpace", "start"}, {"LinSpace", "stop"}, @@ -63,7 +67,9 @@ Status BackwardsConstAnalysis(const Graph& g, {"Min", "reduction_indices"}, {"OneHot", "depth"}, {"Pad", "paddings"}, + {"PadV2", "paddings"}, {"MirrorPad", "paddings"}, + {"Multinomial", "num_samples"}, {"Prod", "reduction_indices"}, {"RandomStandardNormal", "shape"}, {"RandomUniform", "shape"}, diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc index 1c7a2046aa549beb2de58d21f517363d4fe8aea7..6ef4860f35835e59be3452b57204d42c82d0816b 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc @@ -17,15 +17,20 @@ limitations under the License. #include #include +#include #include #include #include "tensorflow/compiler/jit/graph_to_functiondef.h" +#include "tensorflow/compiler/jit/union_find.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/graph/control_flow.h" +#include "tensorflow/core/lib/gtl/optional.h" namespace tensorflow { @@ -70,11 +75,24 @@ struct Frame { std::unordered_set nodes; }; +// Returns a textual representation of the names of the nodes in the input. +template +string NodesToString(const T& nodes) { + return strings::StrCat("{", + str_util::Join(nodes, ",", + [](string* output, const Node* node) { + strings::StrAppend(output, + node->name()); + }), + "}"); +} + // Copies a subgraph from `graph` to `output` by performing a reverse DFS // starting at nodes in vector `stack`. // `node_map` is a vector indexed by source node ID to dest nodes. // Does not traverse into nodes in `node_map`, so by adding nodes to `node_map` -// before the traversal clients can cut the graph. Returns an error if the +// before the traversal clients can cut the graph. If a frame is provided (frame +// != nullptr), then this functions will return an error if the // traversal leaves 'frame'; the client must add enough nodes to `node_map` to // cut the graph and prevent the traversal from escaping. // @@ -84,25 +102,26 @@ struct Frame { // taking from the Switch node was not necessarily the first output, but _Arg // nodes only have one output. By adding the Switch node to `squash_src_outputs` // we rewrite the src_output of the corresponding edge to be 0. -Status CopySubgraph(const Graph& graph, const Frame& frame, +Status CopySubgraph(const Graph& graph, const Frame* frame, std::vector stack, const std::vector& squash_src_outputs, std::vector* node_map, Graph* output) { + VLOG(3) << "Stack: " << NodesToString(stack); std::vector visited(graph.num_node_ids(), false); while (!stack.empty()) { Node* n = stack.back(); stack.pop_back(); - VLOG(3) << "Copying node " << n->name(); + VLOG(5) << "Copying node " << n->name(); if (visited[n->id()]) continue; visited[n->id()] = true; for (const Edge* e : n->in_edges()) { Node* src = e->src(); - if (frame.nodes.find(src) == frame.nodes.end()) { + if (frame != nullptr && frame->nodes.find(src) == frame->nodes.end()) { // We traversed out of the loop frame, without encountering a cut node. - return errors::Internal("Graph traversal of loop frame ", frame.name, + return errors::Internal("Graph traversal of loop frame ", frame->name, " escaped frame at ", src->name(), " without encountering an argument node."); } @@ -111,7 +130,9 @@ Status CopySubgraph(const Graph& graph, const Frame& frame, stack.push_back(src); } Node* src_copy = (*node_map)[e->src()->id()]; - int src_output = squash_src_outputs[e->src()->id()] ? 0 : e->src_output(); + int src_output = squash_src_outputs[e->src()->id()] && !e->IsControlEdge() + ? 0 + : e->src_output(); Node* dst_copy = (*node_map)[e->dst()->id()]; output->AddEdge(src_copy, src_output, dst_copy, e->dst_input()); } @@ -119,27 +140,31 @@ Status CopySubgraph(const Graph& graph, const Frame& frame, return Status::OK(); } -Status BuildArgNode(Graph* graph, DataType type, int index, Node** arg_node) { +xla::StatusOr AddNode(const NodeDef& node_def, Graph* graph) { + Status status; + Node* inserted_node = graph->AddNode(node_def, &status); + if (!status.ok()) { + return status; + } + return inserted_node; +} + +xla::StatusOr BuildArgNode(Graph* graph, DataType type, int index) { NodeDef arg_def; - NodeDefBuilder builder(strings::StrCat("_Arg", index), kArgOp); + NodeDefBuilder builder(strings::StrCat(kArgOp, index), kArgOp); builder.Attr("T", type); builder.Attr("index", index); TF_RETURN_IF_ERROR(builder.Finalize(&arg_def)); - Status status; - *arg_node = graph->AddNode(arg_def, &status); - return status; + return AddNode(arg_def, graph); } -Status BuildRetvalNode(Graph* graph, DataType type, int index, - Node** retval_node) { +xla::StatusOr BuildRetvalNode(Graph* graph, DataType type, int index) { NodeDef ret_def; ret_def.set_op(kRetValOp); - ret_def.set_name(strings::StrCat("_Retval", index)); + ret_def.set_name(strings::StrCat(kRetValOp, index)); AddNodeAttr("T", type, &ret_def); AddNodeAttr("index", index, &ret_def); - Status status; - *retval_node = graph->AddNode(ret_def, &status); - return status; + return AddNode(ret_def, graph); } // Builds a graph for the loop condition. @@ -157,9 +182,8 @@ Status BuildLoopCondition(const Graph& graph, Frame* frame, for (int i = 0; i < frame->args.size(); ++i) { const Arg& arg = frame->args[i]; - Node* arg_node; - TF_RETURN_IF_ERROR( - BuildArgNode(output, arg.enter->input_type(0), i, &arg_node)); + TF_ASSIGN_OR_RETURN(Node * arg_node, + BuildArgNode(output, arg.enter->input_type(0), i)); if (arg.is_loop_invariant) { node_map[arg.enter->id()] = arg_node; } else { @@ -169,16 +193,14 @@ Status BuildLoopCondition(const Graph& graph, Frame* frame, // Build a Retval node for the loop condition. The LoopCond nodes are always // boolean because of the type constraints on the LoopCond op. - TF_RETURN_IF_ERROR( - BuildRetvalNode(output, DT_BOOL, 0, &node_map[frame->loop_cond->id()])); + TF_ASSIGN_OR_RETURN(node_map[frame->loop_cond->id()], + BuildRetvalNode(output, DT_BOOL, 0)); // Performs a reverse DFS, copying nodes and edges to the output graph. // The _Arg and _Retval nodes were added unconditionally above, so we are // guaranteed to get the correct function signature. - TF_RETURN_IF_ERROR(CopySubgraph(graph, *frame, {frame->loop_cond}, - squash_src_outputs, &node_map, output)); - - return Status::OK(); + return CopySubgraph(graph, frame, {frame->loop_cond}, squash_src_outputs, + &node_map, output); } // Builds a graph for the loop body. @@ -202,8 +224,8 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, DataType dtype = arg.enter->input_type(0); arg_types->push_back(dtype); - Node* arg_node; - TF_RETURN_IF_ERROR(BuildArgNode(output, dtype, i, &arg_node)); + + TF_ASSIGN_OR_RETURN(Node * arg_node, BuildArgNode(output, dtype, i)); if (dtype == DT_RESOURCE) { // The convention of the XLA bridge is that resource variable arguments @@ -213,8 +235,8 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, TF_RET_CHECK(arg.is_loop_invariant); node_map[arg.enter->id()] = arg_node; } else { - Node* retval_node; - TF_RETURN_IF_ERROR(BuildRetvalNode(output, dtype, i, &retval_node)); + TF_ASSIGN_OR_RETURN(Node * retval_node, + BuildRetvalNode(output, dtype, i)); if (arg.is_loop_invariant) { // Argument is loop-invariant. Forward it from the Arg to the Retval. @@ -237,7 +259,7 @@ Status BuildLoopBody(const Graph& graph, Frame* frame, // Performs a reverse DFS, copying nodes and edges to the output graph. // The _Arg and _Retval nodes were added unconditionally above, so we are // guaranteed to get the correct function signature. - TF_RETURN_IF_ERROR(CopySubgraph(graph, *frame, std::move(next_iterations), + TF_RETURN_IF_ERROR(CopySubgraph(graph, frame, std::move(next_iterations), squash_src_outputs, &node_map, output)); return Status::OK(); @@ -386,7 +408,15 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, arg.merge->name()); } - // Find the Exit successor of the Switch. + // Update the device on the Identity outputs of the switch to match their + // target. These Identity outputs do not + + // Loop over the switch node's output to: + // - Find the Exit successor. + // - Set the sharding on all Identity outputs of the switch. These + // identity nodes are values used by the loop body or condition. + // The Identity node may have the wrong device so copy the device from + // one of its outputs instead. for (const Edge* edge : arg.switch_node->out_edges()) { if (edge->src_output() == 0 && IsExit(edge->dst())) { if (arg.exit != nullptr) { @@ -394,12 +424,11 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, arg.switch_node->name()); } arg.exit = edge->dst(); + } else if (StringPiece(edge->dst()->type_string()) == "Identity") { + TF_RETURN_IF_ERROR( + SetNodeShardingFromNeighbors(edge->dst(), /*out_edges=*/true)); } } - if (arg.exit == nullptr) { - return errors::InvalidArgument("Missing Exit successor to ", - arg.switch_node->name()); - } } } @@ -450,12 +479,7 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, } builder.Input(inputs); TF_RETURN_IF_ERROR(builder.Finalize(&while_def)); - - Status status; - Node* while_node = graph->AddNode(while_def, &status); - if (!status.ok()) { - return status; - } + TF_ASSIGN_OR_RETURN(Node * while_node, AddNode(while_def, graph)); // Copies edges to the Enter nodes and from the Exit nodes onto the While. for (int i = 0; i < frame->args.size(); ++i) { @@ -469,16 +493,21 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, } if (!arg.is_loop_invariant) { - std::vector edges(arg.exit->out_edges().begin(), - arg.exit->out_edges().end()); - for (const Edge* edge : edges) { - Node* dst = edge->dst(); - int dst_input = edge->dst_input(); - graph->RemoveEdge(edge); - - int src_output = - dst_input == Graph::kControlSlot ? Graph::kControlSlot : i; - graph->AddEdge(while_node, src_output, dst, dst_input); + // Add output edges if the output of the loop is consumed. + if (arg.exit != nullptr) { + std::vector edges(arg.exit->out_edges().begin(), + arg.exit->out_edges().end()); + for (const Edge* edge : edges) { + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + graph->RemoveEdge(edge); + + if (dst_input == Graph::kControlSlot) { + graph->AddControlEdge(while_node, dst); + } else { + graph->AddEdge(while_node, i, dst, dst_input); + } + } } } } @@ -488,6 +517,7 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, for (Node* node : frame->nodes) { graph->RemoveNode(node); } + frame->nodes.clear(); frame->parent->nodes.insert(while_node); VLOG(2) << "Frame " << frame->name << " after: " @@ -496,13 +526,863 @@ Status FunctionalizeLoop(Graph* graph, Frame* frame, return Status::OK(); } +class FunctionalizeCond { + public: + // Identifies the connected parts of the tf.Cond. + struct ClusterHandle { + explicit ClusterHandle(int representative = -1) + : representative(representative) {} + + bool operator==(const ClusterHandle& other) const { + return representative == other.representative; + } + + bool operator!=(const ClusterHandle& other) const { + return !(*this == other); + } + + bool operator<(const ClusterHandle& other) const { + return representative < other.representative; + } + + bool operator>(const ClusterHandle& other) const { + return representative > other.representative; + } + + string ToString() const { + return strings::StrCat("Cluster_", representative); + } + + // Vector of UnionFind indexable by ClusterHandle and Node*. + struct Vector { + explicit Vector(size_t size) : clusters(size) {} + + UnionFind& at(const ClusterHandle& cluster) { + return clusters.at(cluster.representative); + } + + UnionFind& at(const Node* node) { + return clusters.at(node->id()); + } + + UnionFind& operator[](const Node* node) { + return clusters.at(node->id()); + } + + size_t size() const { return clusters.size(); } + + void resize(size_t count) { return clusters.resize(count); } + + private: + std::vector> clusters; + }; + + private: + int representative; + }; + + // Represents a node in the clustered graph consisting of switch_nodes, + // merge_nodes as well as the edges into and out of this node to other + // Clusters. Each Cluster corresponds to a ClusterHandle and has a + // corresponding representative. + struct Cluster { + std::unordered_set switch_nodes; + std::unordered_set merge_nodes; + std::unordered_set in_nodes; + std::unordered_set out_nodes; + + // A member of the ClusterHandle corresponding to this Cluster. + ClusterHandle representative; + bool visited = false; + }; + + // Represent the clustered graph as map from cluster representative to + // Cluster. + using ClusteredGraph = std::map; + + // The arguments and condition of a XlaIf. The arguments are ordered by node + // id in the original graph. + struct CondArgs { + struct CondCmp { + bool operator()(const Node* lhs, const Node* rhs) const { + bool lhs_is_resource = + lhs->num_inputs() > 0 ? (lhs->input_type(0) == DT_RESOURCE) : false; + bool rhs_is_resource = + rhs->num_inputs() > 0 ? (rhs->input_type(0) == DT_RESOURCE) : false; + return std::tie(lhs_is_resource, lhs->name()) < + std::tie(rhs_is_resource, rhs->name()); + } + }; + Node* conditional = nullptr; + std::set args; + }; + + static Status Functionalize(Graph* graph, FunctionLibraryDefinition* library); + + private: + FunctionalizeCond(Graph* graph, FunctionLibraryDefinition* library) + : clusters_(graph->num_node_ids()), library_(library), graph_(graph) {} + + // Returns a vector of Merge nodes from the clustered graph where the nodes + // are sorted by the number of switch nodes minus number of merge nodes + // from a root of the clustered graph to the given Merge node, with ties + // broken by the representative of the Cluster. + std::vector> SortedMergeNodes(); + + // Returns whether the graph has no conditionals. + bool NoConditionals() const { return merge_nodes_.empty(); } + + // Construct the clustered graph by creating nodes for each cluster and the + // connections between the clusters. Switch and Merge nodes partition + // clusters, so iterate over those. Note: a Cluster may have neither a + // Merge or Switch but will have an in/out edge from a Cluster that has. + void CreateClusters(); + + // Creates the clustered graph by identifying all the edges between different + // clusters and collecting all switch and merge nodes that correspond to a + // cluster. + void CreateClusteredGraph(); + + // If `from` and `to` correspond to different clusters, then merge the nodes + // in the clustered graph corresponding to `from` and `to`. + // + // If `remove_from_graph` is specified then the `from` node is also removed + // from the clustered graph post contracting the edge. + void ContractEdge(Cluster* from, Cluster* to, bool remove_from_graph = false); + + // Converts a Merge node to a XlaIf. This encapsulates the process of + // extracting the bodies needed for the then and else branch, creates a XlaIf + // node, removing the nodes of the branches from the graph and replacing the + // merge node with a XlaIf. + Status ConvertMergeToXlaIf(Cluster* merge_cluster); + + // Removes a Switch cluster feeding directly into a Merge cluster by removing + // the Switch and Merge nodes and collapsing into a single cluster. + Status RemoveTrivialMerge(Cluster* merge_cluster); + + // Returns the switch cluster corresponding to the merge node. This function + // only returns the switch cluster in the simple case where we have a switch + // node is the entry of a diamond corresponding to a conditional: + // + // Switch + // / \ + // Branch Branch + // \ / + // merge_cluster + // + // Note: either of the branches may be empty. The case where both branches are + // empty is handled by RemoveTrivialMerge. + gtl::optional GetSwitchCluster(const Cluster& merge_cluster); + + // Determines the arguments needed as input to the Merge cluster originating + // from the Switch cluster. + xla::StatusOr DetermineCondArgs(const Cluster& merge_cluster, + const Cluster& switch_cluster); + + // Builds a XlaIfOp to replace the Merge node with. + xla::StatusOr BuildAndAddXlaIfOp(const CondArgs& cond_args, + const Cluster& merge_cluster, + const std::vector& outputs); + + // Extracts a function body corresponding to the given input edge of the merge + // node. + Status ExtractBody(const CondArgs& cond_args, const Cluster& merge_cluster, + const std::vector& outputs, int input_edge, + Graph* body); + + // Adds all the input edges to `if_node` corresponding to the arguments. + Status AddInputEdges(const CondArgs& cond_args, Node* if_node); + + // Adds all output edges from the `if_node`. + Status AddOutputEdges(const std::vector& outputs, Node* if_node); + + // Removes all nodes from the graph that are part of cluster. + void RemoveClusterNodes(Cluster* cluster); + + // Removes all argument nodes that are unused. + template + void RemoveUnusedArgs(const T& args); + + // Removes all Merge nodes in merge_cluster. + void RemoveMergeNodes(Cluster* merge_cluster); + + // Returns the representative member of the corresponding cluster. + ClusterHandle Representative(const Node* node) { + return clusters_.at(node).Get(); + } + + ClusteredGraph clustered_graph_; + ClusterHandle::Vector clusters_; + std::unordered_set merge_nodes_; + std::unordered_set switch_nodes_; + FunctionLibraryDefinition* library_; + Graph* graph_; +}; + +std::ostream& operator<<(std::ostream& os, + const FunctionalizeCond::ClusterHandle& c) { + os << c.ToString(); + return os; +} + +// Returns a dot representation of the clustered graph showing the connections +// between the nodes and the nodes in each cluster. +string DebugString(const Graph& graph, + FunctionalizeCond::ClusterHandle::Vector* clusters) { + string ret = "digraph {\ncompound=true;labeljust=\"r\";ranksep=0.24\n"; + std::map subgraphs; + for (Node* n : graph.nodes()) { + if (n->IsOp()) { + strings::StrAppend(&subgraphs[clusters->at(n).Get()], n->id(), + " [label=\"", n->name(), "\"];\n"); + } + } + for (auto kv : subgraphs) { + strings::StrAppend(&ret, "subgraph cluster_", kv.first.ToString(), " {\n", + "style=filled; color=lightgrey;", "label = \"", + kv.first.ToString(), "\";\n", kv.second, "}\n"); + } + for (Node* n : graph.nodes()) { + if (!n->IsOp()) { + continue; + } + for (Node* in : n->in_nodes()) { + if (in->IsOp()) { + strings::StrAppend(&ret, in->id(), " -> ", n->id(), ";\n"); + } + } + } + return strings::StrCat(ret, "}"); +} + +string DebugString(const FunctionalizeCond::ClusteredGraph& clustered_graph) { + string ret = "digraph {\ncompound=true;labeljust=\"r\";\n"; + auto name = [](const FunctionalizeCond::Cluster& cluster) { + return cluster.representative.ToString(); + }; + for (auto kv : clustered_graph) { + strings::StrAppend(&ret, kv.first.ToString(), " [label=\"", name(kv.second), + " (", kv.second.switch_nodes.size(), ", ", + kv.second.merge_nodes.size(), ")\"];\n"); + } + for (auto kv : clustered_graph) { + for (auto in : kv.second.in_nodes) { + strings::StrAppend(&ret, name(*in), " -> ", name(kv.second), ";\n"); + } + } + return strings::StrCat(ret, "}"); +} + +bool IsDeadSwitch(const Node* node) { + for (const Edge* e : node->out_edges()) { + const Node* dst = e->dst(); + if (!dst->IsIdentity()) { + return false; + } + for (const Edge* ee : dst->out_edges()) { + if (!ee->IsControlEdge() || !ee->dst()->IsSink()) { + return false; + } + } + } + return true; +} + +void FunctionalizeCond::CreateClusters() { + for (Node* node : graph_->nodes()) { + if (!node->IsOp()) { + continue; + } + if (IsSwitch(node)) { + switch_nodes_.insert(node); + } else if (IsMerge(node)) { + merge_nodes_.insert(node); + } + ClusterHandle& cluster = clusters_.at(node).Get(); + cluster = ClusterHandle(node->id()); + } + + // If there are no Merge nodes, then terminate. + if (merge_nodes_.empty()) { + return; + } + + // Remove all dead Switch nodes. + RemoveUnusedArgs(switch_nodes_); + + // All parent_'s are still nullptr so clusters_ may still be resized. Resize + // conservatively assuming all merge nodes become XlaIf nodes. + clusters_.resize(clusters_.size() + merge_nodes_.size()); + + // Merge a cluster with its input, unless the input is a Switch node or + // the node is a Merge node. + for (const Node* node : graph_->nodes()) { + if (IsMerge(node) || IsSwitch(node) || !node->IsOp()) { + continue; + } + for (const Node* in : node->in_nodes()) { + if (in->IsOp() && !IsSwitch(in) && !IsMerge(in)) { + clusters_.at(node).Merge(&clusters_.at(in)); + } + } + } +} + +void FunctionalizeCond::ContractEdge(Cluster* from, Cluster* to, + bool remove_from_graph) { + VLOG(3) << "ContractEdge from = " << from->representative + << " to = " << to->representative; + if (from->representative == to->representative) { + return; + } + to->merge_nodes.insert(from->merge_nodes.begin(), from->merge_nodes.end()); + from->merge_nodes.clear(); + to->switch_nodes.insert(from->switch_nodes.begin(), from->switch_nodes.end()); + from->switch_nodes.clear(); + + for (Cluster* from_out : from->out_nodes) { + from_out->in_nodes.erase(from); + if (from_out->representative != to->representative) { + from_out->in_nodes.insert(to); + to->out_nodes.insert(from_out); + } + } + from->out_nodes.clear(); + + for (Cluster* from_in : from->in_nodes) { + from_in->out_nodes.erase(from); + if (from_in->representative != to->representative) { + from_in->out_nodes.insert(to); + to->in_nodes.insert(from_in); + } + } + from->in_nodes.clear(); + + to->in_nodes.erase(from); + to->out_nodes.erase(from); + clusters_.at(to->representative).Merge(&clusters_.at(from->representative)); + from->visited = true; + + if (remove_from_graph) { + clustered_graph_.erase(from->representative); + } +} + +void FunctionalizeCond::CreateClusteredGraph() { + auto update_cluster_for_node = [this](Node* node) -> Cluster& { + ClusterHandle repr = Representative(node); + Cluster& cluster_node = clustered_graph_[repr]; + cluster_node.representative = repr; + for (const Node* in : node->in_nodes()) { + ClusterHandle other_repr = Representative(in); + // Skip source, sink and internal edges. + if (!in->IsOp() || other_repr == repr) { + continue; + } + Cluster& cluster_node_in = clustered_graph_[other_repr]; + cluster_node.in_nodes.insert(&cluster_node_in); + cluster_node_in.out_nodes.insert(&cluster_node); + cluster_node_in.representative = other_repr; + } + for (const Node* out : node->out_nodes()) { + ClusterHandle other_repr = Representative(out); + // Skip source, sink and internal edges. + if (!out->IsOp() || other_repr == repr) { + continue; + } + Cluster& cluster_node_out = clustered_graph_[other_repr]; + cluster_node.out_nodes.insert(&cluster_node_out); + cluster_node_out.in_nodes.insert(&cluster_node); + cluster_node_out.representative = other_repr; + } + return cluster_node; + }; + for (Node* node : switch_nodes_) { + update_cluster_for_node(node).switch_nodes.insert(node); + } + for (Node* node : merge_nodes_) { + update_cluster_for_node(node).merge_nodes.insert(node); + } + + // Merge Switch nodes with common predicate. + std::unordered_map> predicate_to_switch; + for (Node* node : switch_nodes_) { + Node* tmp; + TF_CHECK_OK(node->input_node(1, &tmp)); + predicate_to_switch[tmp].push_back(node); + } + for (auto kv : predicate_to_switch) { + Cluster& first = clustered_graph_.at(Representative(kv.second.front())); + for (Node* switch_node : kv.second) { + ClusterHandle handle = Representative(switch_node); + Cluster& cluster = clustered_graph_.at(handle); + ContractEdge(&cluster, &first, /*remove_from_graph=*/true); + } + } + + // Merge Merge nodes with common input together. + for (Node* node : merge_nodes_) { + Cluster& cluster = clustered_graph_.at(Representative(node)); + for (const Node* in : node->in_nodes()) { + if (!in->IsOp()) { + continue; + } + Cluster& cluster_node_in = clustered_graph_.at(Representative(in)); + // ContractEdge can modify out_nodes of cluster_node_in, so traverse + // over out_nodes assuming it does. + for (auto it = cluster_node_in.out_nodes.begin(); + it != cluster_node_in.out_nodes.end();) { + if (!(*it)->merge_nodes.empty()) { + ContractEdge(*it++, &cluster, /*remove_from_graph=*/true); + } else { + ++it; + } + } + } + } + + VLOG(3) << "Graph with clusters: " << DebugString(*graph_, &clusters_); + VLOG(3) << "ClusteredGraph: " << DebugString(clustered_graph_); +} + +gtl::optional FunctionalizeCond::GetSwitchCluster( + const Cluster& merge_cluster) { + VLOG(3) << "GetSwitchCluster for " << merge_cluster.representative; + gtl::optional switch_cluster; + if (merge_cluster.in_nodes.size() > 2) { + return gtl::nullopt; + } + for (Cluster* in : merge_cluster.in_nodes) { + Cluster* cluster = in; + if (in->switch_nodes.empty()) { + if (in->in_nodes.size() != 1) { + return gtl::nullopt; + } + // There is only a single `in` cluster. + cluster = *in->in_nodes.begin(); + } + if (cluster->switch_nodes.empty()) { + return gtl::nullopt; + } + + if (switch_cluster.has_value() && *switch_cluster != cluster) { + return gtl::nullopt; + } else { + switch_cluster = cluster; + } + } + return switch_cluster; +} + +xla::StatusOr FunctionalizeCond::DetermineCondArgs( + const Cluster& merge_cluster, const Cluster& switch_cluster) { + VLOG(2) << "DetermineCondArgs for " << merge_cluster.representative + << " with switch cluster " << switch_cluster.representative; + CondArgs ret; + auto feeds_into_branch_cluster = [&](Node* switch_cluster) { + for (Node* out : switch_cluster->out_nodes()) { + ClusterHandle repr = Representative(out); + if (repr == merge_cluster.representative) { + return true; + } + for (Cluster* in : merge_cluster.in_nodes) { + if (repr == in->representative) { + return true; + } + } + } + return false; + }; + for (Node* switch_cluster_node : switch_cluster.switch_nodes) { + if (!feeds_into_branch_cluster(switch_cluster_node)) { + continue; + } + + Node* tmp; + TF_RETURN_IF_ERROR(switch_cluster_node->input_node(1, &tmp)); + if (ret.conditional == nullptr) { + ret.conditional = tmp; + } else if (ret.conditional != tmp) { + return errors::Unimplemented( + "Switch statements with different conditionals cannot be " + "converted into functional conditional."); + } + ret.args.insert(switch_cluster_node); + } + return ret; +} + +xla::StatusOr FunctionalizeCond::BuildAndAddXlaIfOp( + const CondArgs& cond_args, const Cluster& merge_cluster, + const std::vector& outputs) { + VLOG(2) << "Build if op for " << NodesToString(merge_cluster.merge_nodes) + << " with input " << NodesToString(cond_args.args); + + NodeDef if_def; + // Create a new If node using the name of the merge node. + NodeDefBuilder builder( + strings::StrCat((*merge_cluster.merge_nodes.begin())->name(), "_If"), + "XlaIf"); + string branch[] = {"else_branch", "then_branch"}; + for (int i = 0; i < 2; ++i) { + static std::atomic sequence_num(0LL); + int64 id = ++sequence_num; + + NameAttrList body_name; + body_name.set_name( + strings::StrCat("_functionalize_if_", branch[i], "_", id)); + auto body = xla::MakeUnique(graph_->op_registry()); + TF_RETURN_IF_ERROR( + ExtractBody(cond_args, merge_cluster, outputs, i, body.get())); + VLOG(3) << "Body " << branch[i] << ": " << DebugString(body.get()); + FunctionDef body_fdef; + TF_RETURN_IF_ERROR(GraphToFunctionDef(*body, body_name.name(), &body_fdef)); + TF_RETURN_IF_ERROR(library_->AddFunctionDef(body_fdef)); + builder.Attr(branch[i], body_name); + } + + // Build input type. + std::vector inputs; + DataTypeVector in_arg_types; + for (const Node* arg : cond_args.args) { + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + builder.ControlInput(in_edge->src()->name()); + } else { + DataType dtype = arg->input_type(0); + inputs.emplace_back(NodeDefBuilder::NodeOut( + in_edge->src()->name(), in_edge->src_output(), dtype)); + in_arg_types.push_back(dtype); + } + } + builder.Attr("Tin", in_arg_types); + + // Build output type. + DataTypeVector out_type; + for (const Node* merge : merge_cluster.merge_nodes) { + DataType dtype = merge->output_type(0); + out_type.push_back(dtype); + } + builder.Attr("Tout", out_type); + + builder.Attr("Tcond", DT_BOOL); + builder.Device(cond_args.conditional->assigned_device_name()); + // Conditional should be the first input ... + builder.Input(NodeDefBuilder::NodeOut(cond_args.conditional->name(), 0, + cond_args.conditional->output_type(0))); + // ... followed by the other inputs. + builder.Input(inputs); + + TF_RETURN_IF_ERROR(builder.Finalize(&if_def)); + TF_ASSIGN_OR_RETURN(Node * if_node, AddNode(if_def, graph_)); + return if_node; +} + +void FunctionalizeCond::RemoveClusterNodes(Cluster* cluster) { + VLOG(3) << "RemoveClusterNodes for " << cluster->representative; + ClusterHandle repr = cluster->representative; + std::deque to_delete; + for (Node* node : graph_->nodes()) { + if (Representative(node) == repr) { + to_delete.push_back(node); + } + } + for (Node* n : to_delete) { + graph_->RemoveNode(n); + } +} + +template +void FunctionalizeCond::RemoveUnusedArgs(const T& args) { + VLOG(2) << "RemoveUnusedArgs among: " << NodesToString(args); + + std::deque to_delete; + for (Node* arg : args) { + if (IsDeadSwitch(arg)) { + to_delete.push_back(arg); + for (Node* n : arg->out_nodes()) { + to_delete.push_back(n); + } + } + } + for (Node* n : to_delete) { + switch_nodes_.erase(n); + auto it = clustered_graph_.find(Representative(n)); + if (it != clustered_graph_.end()) { + it->second.switch_nodes.erase(n); + } + graph_->RemoveNode(n); + } +} + +Status FunctionalizeCond::ExtractBody(const CondArgs& cond_args, + const Cluster& merge_cluster, + const std::vector& outputs, + int input_edge, Graph* body) { + VLOG(2) << "ExtractBody for " << merge_cluster.representative + << " along edge " << input_edge; + std::vector squash_src_outputs(graph_->num_node_ids(), false); + std::vector node_map(graph_->num_node_ids(), nullptr); + int arg_count = 0; + for (const auto* arg : cond_args.args) { + DataType dtype = arg->input_type(0); + TF_ASSIGN_OR_RETURN(Node * arg_node, + BuildArgNode(body, dtype, arg_count++)); + node_map.at(arg->id()) = arg_node; + squash_src_outputs.at(arg->id()) = true; + } + + std::vector stack; + stack.reserve(outputs.size()); + for (int j = 0; j < outputs.size(); ++j) { + Node* node = outputs[j]; + TF_ASSIGN_OR_RETURN(node_map.at(node->id()), + BuildRetvalNode(body, node->output_type(0), + /*index=*/j)); + const Edge* in_edge; + TF_RETURN_IF_ERROR(node->input_edge(input_edge, &in_edge)); + Node* in = in_edge->src(); + if (node_map.at(in->id()) == nullptr) { + node_map.at(in->id()) = body->CopyNode(in); + } + + if (cond_args.args.find(in) == cond_args.args.end()) { + body->AddEdge(node_map.at(in->id()), in_edge->src_output(), + node_map.at(node->id()), 0); + } else { + body->AddEdge(node_map.at(in->id()), 0, node_map.at(node->id()), 0); + // Don't include input nodes that are already just returned in stack. + continue; + } + stack.push_back(in); + } + + return CopySubgraph(*graph_, nullptr, stack, squash_src_outputs, &node_map, + body); +} + +Status FunctionalizeCond::AddInputEdges(const CondArgs& cond_args, + Node* if_node) { + VLOG(3) << "AddInputEdges for " << if_node->name(); + int i = 0; + graph_->AddEdge(cond_args.conditional, 0, if_node, i++); + for (const Node* arg : cond_args.args) { + const Edge* in_edge; + TF_RETURN_IF_ERROR(arg->input_edge(0, &in_edge)); + if (in_edge->IsControlEdge()) { + graph_->AddControlEdge(in_edge->src(), if_node); + } else { + graph_->AddEdge(in_edge->src(), in_edge->src_output(), if_node, i++); + } + } + return Status::OK(); +} + +Status FunctionalizeCond::AddOutputEdges(const std::vector& outputs, + Node* if_node) { + VLOG(3) << "AddOutputEdges for " << if_node->name(); + for (int i = 0; i < outputs.size(); ++i) { + Node* node = outputs[i]; + std::vector edges(node->out_edges().begin(), + node->out_edges().end()); + for (const Edge* edge : edges) { + Node* dst = edge->dst(); + int dst_input = edge->dst_input(); + + if (edge->src_output() > 0) { + return errors::Unimplemented("Output of index (", edge->src_output(), + ") of merge node ", node->name()); + } + graph_->RemoveEdge(edge); + + int src_output = + dst_input == Graph::kControlSlot ? Graph::kControlSlot : i; + graph_->AddEdge(if_node, src_output, dst, dst_input); + } + } + return Status::OK(); +} + +void FunctionalizeCond::RemoveMergeNodes(Cluster* merge_cluster) { + VLOG(3) << "RemoveMergeNodes for " << merge_cluster->representative; + // Remove all merge nodes now dead post extraction of If. + for (auto it = merge_cluster->merge_nodes.begin(); + it != merge_cluster->merge_nodes.end();) { + Node* node = *it; + graph_->RemoveNode(node); + merge_cluster->merge_nodes.erase(*it++); + } +} + +Status FunctionalizeCond::RemoveTrivialMerge(Cluster* merge_cluster) { + Cluster* switch_cluster = *merge_cluster->in_nodes.begin(); + if (switch_cluster->switch_nodes.empty()) { + return errors::FailedPrecondition( + "Not a trivial merge: no Switch node feeding into Merge node"); + } + + for (auto it = merge_cluster->merge_nodes.begin(); + it != merge_cluster->merge_nodes.end();) { + // We have the following structure: + // Op -> Switch -> Merge -> Consumer + // and we want to transform it to: + // Op -> Consumer + Node* merge_node = *it; + Node* switch_node; + const Edge* in = nullptr; + TF_RETURN_IF_ERROR(merge_node->input_node(0, &switch_node)); + TF_RETURN_IF_ERROR(switch_node->input_edge(0, &in)); + for (auto out : merge_node->out_edges()) { + int src_output = out->dst_input() == Graph::kControlSlot + ? Graph::kControlSlot + : in->src_output(); + graph_->AddEdge(in->src(), src_output, out->dst(), out->dst_input()); + } + graph_->RemoveNode(*it++); + } + RemoveUnusedArgs(switch_cluster->switch_nodes); + + return Status::OK(); +} + +Status FunctionalizeCond::ConvertMergeToXlaIf(Cluster* merge_cluster) { + VLOG(1) << "ConvertMergeToXlaIf for " << merge_cluster->representative; + gtl::optional switch_cluster = GetSwitchCluster(*merge_cluster); + if (!switch_cluster.has_value()) { + return errors::FailedPrecondition( + "Merge cluster was not part of a simple conditional in the clustered " + "graph. Graph nodes in merge cluster ", + NodesToString(merge_cluster->merge_nodes)); + } + TF_ASSIGN_OR_RETURN(auto cond_args, + DetermineCondArgs(*merge_cluster, **switch_cluster)); + + // Sort the outputs by ID to produce more stable output. + std::vector outputs(merge_cluster->merge_nodes.begin(), + merge_cluster->merge_nodes.end()); + std::sort(outputs.begin(), outputs.end(), CondArgs::CondCmp()); + + // Extract bodies and builds a If operator. + TF_ASSIGN_OR_RETURN(Node * if_node, + BuildAndAddXlaIfOp(cond_args, *merge_cluster, outputs)); + TF_RETURN_IF_ERROR(AddInputEdges(cond_args, if_node)); + TF_RETURN_IF_ERROR(AddOutputEdges(outputs, if_node)); + + // Remove the old nodes from the graph_ and contract the edges of the + // clustered graph. + for (auto in : merge_cluster->in_nodes) { + if (in != *switch_cluster) { + RemoveClusterNodes(in); + } + } + RemoveMergeNodes(merge_cluster); + RemoveUnusedArgs(cond_args.args); + auto in_nodes = merge_cluster->in_nodes; + for (auto it = in_nodes.begin(); it != in_nodes.end();) { + ContractEdge(*it++, merge_cluster); + } + ContractEdge(*switch_cluster, merge_cluster); + clusters_[if_node].Get() = ClusterHandle(merge_cluster->representative); + + return Status::OK(); +} + +std::vector> +FunctionalizeCond::SortedMergeNodes() { + VLOG(2) << "ProcessClusteredGraph"; + std::stack> stack; + for (auto& c : clustered_graph_) { + if (c.second.in_nodes.empty()) { + stack.push({0, &c.second}); + } + } + + // Perform a depth-first traversal of the clustered graph computing the + // switch-merge depth. + std::vector> queue; + std::unordered_set visited; + while (!stack.empty()) { + Cluster* n = stack.top().second; + size_t depth = stack.top().first; + stack.pop(); + + auto inserted = visited.insert(n); + if (!inserted.second) { + continue; + } + + size_t new_depth = depth; + if (!n->merge_nodes.empty()) { + queue.emplace_back(depth, n); + --new_depth; + } + if (!n->switch_nodes.empty()) { + ++new_depth; + } + for (Cluster* e : n->out_nodes) { + stack.emplace(new_depth, e); + } + } + + // Sort in reverse order of switch-merge depth with ties broken by the + // ClusterHandle. + std::sort(queue.begin(), queue.end(), + [](const std::pair& lhs, + const std::pair& rhs) { + return std::tie(lhs.first, lhs.second->representative) > + std::tie(rhs.first, rhs.second->representative); + }); + + return queue; +} + +Status FunctionalizeCond::Functionalize(Graph* graph, + FunctionLibraryDefinition* library) { + VLOG(1) << "FunctionalizeCond::Functionalize"; + FunctionalizeCond fc(graph, library); + fc.CreateClusters(); + if (fc.NoConditionals()) { + return Status::OK(); + } + fc.CreateClusteredGraph(); + + auto queue = fc.SortedMergeNodes(); + for (auto it = queue.begin(); it != queue.end();) { + Cluster* merge_cluster = (*it).second; + ++it; + if (merge_cluster->in_nodes.size() == 1) { + TF_RETURN_IF_ERROR(fc.RemoveTrivialMerge(merge_cluster)); + } else { + TF_RETURN_IF_ERROR(fc.ConvertMergeToXlaIf(merge_cluster)); + } + + // Contract newly Merge free merge_cluster with incoming nodes without + // Switch or Merge nodes. + std::vector in_nodes(merge_cluster->in_nodes.begin(), + merge_cluster->in_nodes.end()); + for (auto in : in_nodes) { + if (in->merge_nodes.empty() && in->switch_nodes.empty()) { + fc.ContractEdge(in, merge_cluster); + } + } + } + + if (!fc.switch_nodes_.empty()) { + return errors::Internal( + "Failed to functionalize control flow with Switch nodes remaining: ", + NodesToString(fc.switch_nodes_)); + } + return Status::OK(); +} + } // namespace // Transformation that converts Tensorflow's graph control flow constructs into // functional equivalents. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library) { - VLOG(2) << "FunctionalizeControlFlow: " + VLOG(2) << "FunctionalizeControlFlow (initial): " << dump_graph::DumpGraphToFile("functionalize_initial", *graph); // Note: BuildControlFlowInfo() requires that the graph's source node is // connected to all source nodes in the graph. Many graphs violate this @@ -577,6 +1457,13 @@ Status FunctionalizeControlFlow(Graph* graph, } } + // FunctionalizeControlFlow is invoked for every function, so the loops's + // bodies and conditionals that were extracted into functions will be handled + // in successive invocations. + TF_RETURN_IF_ERROR(FunctionalizeCond::Functionalize(graph, library)); + + VLOG(2) << "FunctionalizeControlFlow (final): " + << dump_graph::DumpGraphToFile("functionalize_final", *graph); return Status::OK(); } diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.h b/tensorflow/compiler/tf2xla/functionalize_control_flow.h index 1535dc80b0ccdba38c57b534ed7473fc8632e33f..4d4ee3054c2914bb614bf75f7a51be8f6292683e 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow.h +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.h @@ -23,7 +23,6 @@ namespace tensorflow { // Transformation that converts tf.while_loop() loops into functional While // operators, suitable for XLA compilation. -// TODO(b/36470387): add support for conditionals. Status FunctionalizeControlFlow(Graph* graph, FunctionLibraryDefinition* library); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 914c8999a6f13f5f2dc4e3cecc38c91afd432131..01d2b282751f387cfa9c8887cdeb48090c96bff4 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -23,6 +23,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/cc/ops/functional_ops.h" #include "tensorflow/compiler/tf2xla/test_util.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/function.h" #include "tensorflow/core/framework/node_def_util.h" #include "tensorflow/core/framework/op.h" @@ -35,6 +36,134 @@ limitations under the License. namespace tensorflow { namespace { +// Returns the names of the "then" and "else" functions for the XlaIf node in a +// graph. +Status FindIfThenAndElse(const GraphDef& graph, NameAttrList* then_fn, + NameAttrList* else_fn) { + for (const NodeDef& node : graph.node()) { + if (node.op() == "XlaIf") { + const NameAttrList* result; + TF_RETURN_IF_ERROR(GetNodeAttr(node, "then_branch", &result)); + *then_fn = *result; + TF_RETURN_IF_ERROR(GetNodeAttr(node, "else_branch", &result)); + *else_fn = *result; + return Status::OK(); + } + } + return errors::NotFound("No XlaIf node found in graph"); +} + +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// y = array_ops.placeholder(dtypes.int32) +// z = control_flow_ops.cond( +// math_ops.less(y, x), lambda: math_ops.multiply(y, 17), +// lambda: math_ops.add(x, 23)) +TEST(FunctionalizeControlFlow, Conditional) { + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto switch_1 = ops::Switch(scope.WithOpName("cond/Switch"), less, less); + + auto identity_t = + ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_true); + auto seventeen = ops::Const( + scope.WithOpName("cond").WithControlDependencies(identity_t), 17); + auto switch_2 = ops::Switch(scope.WithOpName("cond/Switch"), y, less); + auto mul = ops::Multiply(scope.WithOpName("cond/Mul"), switch_2.output_true, + seventeen); + + auto identity_f = + ops::Identity(scope.WithOpName("cond/Identity"), switch_1.output_false); + auto twenty_three = ops::Const( + scope.WithOpName("cond").WithControlDependencies(identity_f), 23); + auto switch_3 = ops::Switch(scope.WithOpName("cond/Switch"), x, less); + auto add = ops::Add(scope.WithOpName("cond/false/add"), + switch_3.output_false, twenty_three); + + auto merge = ops::Merge(scope.WithOpName("cond/Merge"), + std::initializer_list{add, mul}); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + NameAttrList then_fn; + NameAttrList else_fn; + TF_EXPECT_OK(FindIfThenAndElse(graph_def, &then_fn, &else_fn)); + InstantiationResultForTest else_result; + TF_EXPECT_OK( + InstantiateFunctionForTest(else_fn.name(), library, &else_result)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto y = ops::Placeholder(scope.WithOpName("y"), DT_INT32); + auto x = ops::Placeholder(scope.WithOpName("x"), DT_INT32); + auto less = ops::Less(scope.WithOpName("cond/Less"), y, x); + auto if_op = ops::XlaIf(scope.WithOpName("cond/Merge_If"), less, + std::initializer_list{less, y, x}, then_fn, + else_fn, {DT_INT32}); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // then body. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity"), arg_0); + auto cond = ops::Const( + scope.WithOpName("cond").WithControlDependencies(identity), 17); + auto mul = ops::Mul(scope.WithOpName("cond/Mul"), arg_1, cond); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), mul, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(then_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // else body. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg_0 = ops::_Arg(scope.WithOpName("_arg0"), DT_BOOL, 0); + auto arg_1 = ops::_Arg(scope.WithOpName("_arg1"), DT_INT32, 1); + auto arg_2 = ops::_Arg(scope.WithOpName("_arg2"), DT_INT32, 2); + auto identity = ops::Identity(scope.WithOpName("cond/Identity_1"), arg_0); + auto cond_1 = ops::Const( + scope.WithOpName("cond_1").WithControlDependencies(identity), 23); + auto add = ops::Add(scope.WithOpName("cond/false/add"), arg_2, cond_1); + auto retval0 = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(else_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + EXPECT_EQ((DataTypeVector{DT_BOOL, DT_INT32, DT_INT32}), result.arg_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + // Returns the names of the "cond" and "body" functions for the While node // in a graph. Status FindWhileCondAndBody(const GraphDef& graph, NameAttrList* cond, @@ -168,6 +297,108 @@ TEST(FunctionalizeControlFlow, OneLoopVar) { } } +// Tests functionalizing OneLoopVar where the loop value is not used post the +// loop. +// Graph: +// x = array_ops.placeholder(dtypes.int32) +// control_flow_ops.while_loop(lambda i: i < 10, lambda i: i + 1, [x]) +TEST(FunctionalizeControlFlow, OneLoopVarWithoutExit) { + Graph graph(OpRegistry::Global()); + { + Scope scope = Scope::NewRootScope().ExitOnError(); + + auto dummy = ops::Placeholder(scope.WithOpName("Dummy"), DT_INT32); + + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto enter = + ops::internal::Enter(scope.WithOpName("while/Enter"), source, "aloop"); + auto merge = ops::Merge(scope.WithOpName("while/Merge"), + std::initializer_list{enter, dummy}); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(merge.output), + 10); + auto less = ops::Less(scope.WithOpName("while/Less"), merge.output, ten); + auto loop_cond = ops::LoopCond(scope.WithOpName("while/LoopCond"), less); + auto switch_ = + ops::Switch(scope.WithOpName("while/Switch"), merge.output, loop_cond); + auto identity = + ops::Identity(scope.WithOpName("while/Identity"), switch_.output_true); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto next_iteration = + ops::NextIteration(scope.WithOpName("while/NextIteration"), add); + + // Remove the dummy node and add the loop backedge. + scope.graph()->RemoveNode(dummy.node()); + scope.graph()->AddEdge(next_iteration.node(), 0, merge.output.node(), 1); + + TF_EXPECT_OK(scope.ToGraph(&graph)); + } + + FunctionLibraryDefinition library(OpRegistry::Global(), {}); + TF_ASSERT_OK(FunctionalizeControlFlow(&graph, &library)); + + GraphDef graph_def; + graph.ToGraphDef(&graph_def); + + NameAttrList cond_fn, body_fn; + TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); + + // Outer graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto source = ops::Placeholder(scope.WithOpName("source"), DT_INT32); + auto while_op = + ops::XlaWhile(scope.WithOpName("while/LoopCond"), + std::initializer_list{source}, cond_fn, body_fn); + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + TF_EXPECT_GRAPH_EQ(expected, graph_def); + } + + // Condition graph + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto ten = ops::Const( + scope.WithOpName("while/Less/y").WithControlDependencies(arg), 10); + auto less = ops::Less(scope.WithOpName("while/Less"), arg, ten); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), less, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(cond_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_BOOL}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } + + // Body graph. + { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto arg = ops::_Arg(scope.WithOpName("_arg0"), DT_INT32, 0); + auto identity = ops::Identity(scope.WithOpName("while/Identity"), arg); + auto one = ops::Const( + scope.WithOpName("while/add/y").WithControlDependencies(identity), 1); + auto add = ops::Add(scope.WithOpName("while/add"), identity, one); + auto retval = ops::_Retval(scope.WithOpName("_retval0_RetVal"), add, 0); + + GraphDef expected; + TF_EXPECT_OK(scope.ToGraphDef(&expected)); + + InstantiationResultForTest result; + TF_EXPECT_OK(InstantiateFunctionForTest(body_fn.name(), library, &result)); + + EXPECT_EQ(DataTypeVector{DT_INT32}, result.arg_types); + EXPECT_EQ(DataTypeVector{DT_INT32}, result.ret_types); + TF_EXPECT_GRAPH_EQ(expected, result.gdef); + } +} + // Graph: // x = array_ops.placeholder(dtypes.int32) // y = array_ops.placeholder(dtypes.int32) diff --git a/tensorflow/compiler/tf2xla/graph_compiler.cc b/tensorflow/compiler/tf2xla/graph_compiler.cc new file mode 100644 index 0000000000000000000000000000000000000000..8062f0c03ca60e88bd5c021092dceb105232219f --- /dev/null +++ b/tensorflow/compiler/tf2xla/graph_compiler.cc @@ -0,0 +1,245 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/graph_compiler.h" + +#include +#include +#include +#include "tensorflow/compiler/tf2xla/const_analysis.h" +#include "tensorflow/compiler/tf2xla/dump_graph.h" +#include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" +#include "tensorflow/compiler/tf2xla/literal_util.h" +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/executor.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/common_runtime/graph_optimizer.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/graph/algorithm.h" +#include "tensorflow/core/graph/graph_constructor.h" +#include "tensorflow/core/graph/node_builder.h" +#include "tensorflow/core/lib/gtl/cleanup.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +namespace { +Status PrepareArguments(XlaOpKernelContext* ctx, Graph* graph, + const std::vector& expressions, + std::vector* args) { + auto builder = ctx->builder(); + std::vector compile_time_constant_flags(expressions.size()); + + TF_RETURN_IF_ERROR( + BackwardsConstAnalysis(*graph, &compile_time_constant_flags)); + + args->resize(expressions.size()); + for (int i = 0; i < args->size(); ++i) { + XlaCompiler::Argument& arg = (*args)[i]; + arg.type = ctx->input_type(i); + + TF_RETURN_IF_ERROR( + TensorShapeToXLAShape(arg.type, ctx->InputShape(i), &arg.shape)); + + if (arg.type == DT_RESOURCE) { + return errors::InvalidArgument( + "Resource as function argument is not yet implemented."); + } else if (expressions[i]->has_constant_value()) { + arg.kind = XlaCompiler::Argument::kConstant; + arg.constant_value = expressions[i]->constant_value(); + } else if (compile_time_constant_flags[i]) { + arg.kind = XlaCompiler::Argument::kConstant; + TF_RET_CHECK(expressions[i]->resource() == nullptr) + << "Input with resource is not yet implemented."; + TF_ASSIGN_OR_RETURN(auto literal, + builder->ComputeConstant(expressions[i]->handle())); + TF_RETURN_IF_ERROR( + LiteralToHostTensor(*literal, arg.type, &arg.constant_value)); + } else { + arg.kind = XlaCompiler::Argument::kParameter; + } + } + return Status::OK(); +} +} // namespace +Status GraphCompiler::Compile() { + // Maintain a mapping from node id to node outputs. + using NodeOutputs = std::vector; + std::vector output_registry(graph_->num_node_ids()); + auto output_registry_cleanup = gtl::MakeCleanup([&output_registry] { + for (const NodeOutputs& outputs : output_registry) { + for (const TensorValue& value : outputs) { + CHECK(!value.is_ref()); + delete value.tensor; + } + } + }); + + // XLA requires determinism, generate a stable ordering from DFS. + std::vector topo_sorted_nodes; + GetReversePostOrder(*graph_, &topo_sorted_nodes, + /*stable_comparator=*/NodeComparatorName()); + + OpKernelContext::Params params; + PartiallySetupParams(¶ms); + + for (Node* n : topo_sorted_nodes) { + OpKernel* op_kernel_raw = nullptr; + Status s = flib_->CreateKernel(n->def(), &op_kernel_raw); + // Transfer ownership of the kernel to a local smart pointer. + std::unique_ptr op_kernel(op_kernel_raw); + + if (!s.ok()) { + s = AttachDef(s, *n); + LOG(ERROR) << "Executor failed to create kernel. " << s; + return s; + } + + TF_RET_CHECK(!n->IsRecv() && !n->IsSend() && !n->IsSwitch()) + << "Not supported node: " << n->DebugString(); + params.op_kernel = op_kernel.get(); + gtl::InlinedVector output_attr(n->num_outputs()); + params.output_attr_array = output_attr.data(); + + // tensor_inputs_ is a buffer reused across graph traversal. We clean up and + // reinitialize the buffer before we visit a new node. + tensor_inputs_.clear(); + tensor_inputs_.resize(n->num_inputs()); + + // Set up inputs from outputs of previous nodes. + for (auto* e : n->in_edges()) { + if (e->IsControlEdge()) continue; + Node* src = e->src(); + TF_RET_CHECK(src->id() < output_registry.size()); + const NodeOutputs& src_outputs = output_registry[src->id()]; + + tensor_inputs_[e->dst_input()] = src_outputs[e->src_output()]; + } + + OpKernelContext op_context(¶ms, n->num_outputs()); + if (IsFunctional(n)) { + TF_RETURN_IF_ERROR(CompileFunctionalNode(n, &op_context)); + } else { + device_->Compute(CHECK_NOTNULL(params.op_kernel), &op_context); + Status s = op_context.status(); + TF_RETURN_IF_ERROR(s); + } + + // Set up outputs. Also check if outputs from the previous computation is + // valid. + NodeOutputs& outputs = output_registry[n->id()]; + outputs.resize(n->num_outputs()); + for (int o = 0; o < n->num_outputs(); ++o) { + outputs[o] = op_context.release_output(o); + if (*op_context.is_output_dead() || outputs[o].tensor == nullptr) { + return errors::Internal("Missing xla_context ", o, "-th output from ", + (*op_context.is_output_dead() ? "(dead)" : ""), + SummarizeNode(*n)); + } + } + } + return Status::OK(); +} + +bool GraphCompiler::IsFunctional(Node* n) { + return n->type_string() == FunctionLibraryDefinition::kGradientOp || + (flib_->GetFunctionLibraryDefinition()->Find(n->def().op()) != + nullptr); +} + +Status GraphCompiler::CompileFunctionalNode(Node* n, + OpKernelContext* op_context) { + TF_RET_CHECK(IsFunctional(n)); + // For functional nodes, compile them using compiler from the context and call + // into the functions. + XlaOpKernelContext xla_op_context(op_context); + + XlaCompiler* compiler = xla_op_context.compiler(); + + NameAttrList func; + if (flib_->GetFunctionLibraryDefinition()->Find(n->def().op())) { + func.set_name(n->def().op()); + } else { + func.set_name(FunctionLibraryDefinition::kGradientOp); + } + *func.mutable_attr() = n->def().attr(); + + std::vector expressions; + + for (auto tensor : tensor_inputs_) { + auto expression = + reinterpret_cast(tensor->tensor_data().data()); + expressions.push_back(expression); + } + + // Prepare the arguments and compile the function. + std::vector arguments; + const FunctionBody* fbody; + TF_RETURN_IF_ERROR(compiler->FindFunctionBody(func, &fbody)); + + auto graph = compiler->GetGraph(fbody); + + TF_RETURN_IF_ERROR( + PrepareArguments(&xla_op_context, graph.get(), expressions, &arguments)); + + XlaCompiler::CompilationResult result; + + TF_RETURN_IF_ERROR(compiler->CompileFunction(XlaCompiler::CompileOptions(), + func, arguments, &result)); + + TF_RET_CHECK(arguments.size() == expressions.size()); + + std::vector handles; + for (int64 i = 0; i < expressions.size(); ++i) { + if (arguments[i].kind == XlaCompiler::Argument::kConstant) { + continue; + } + handles.push_back(expressions[i]->handle()); + } + + XlaContext& context = XlaContext::Get(op_context); + auto* b = context.builder(); + + auto output_handle = b->Call(*result.computation, handles); + // The output handle of `Call` computation is a tuple type. Unzip it so + // that it can fit into future computations. + for (int64 i = 0; i < n->num_outputs(); ++i) { + if (result.outputs[i].is_constant) { + xla_op_context.SetConstantOutput(i, result.outputs[i].constant_value); + } else { + xla_op_context.SetOutput(i, b->GetTupleElement(output_handle, i)); + } + } + return b->first_error(); +} + +void GraphCompiler::PartiallySetupParams(OpKernelContext::Params* params) { + params->device = device_; + params->inputs = &tensor_inputs_; + params->step_container = step_container_; + params->resource_manager = device_->resource_manager(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/graph_compiler.h b/tensorflow/compiler/tf2xla/graph_compiler.h new file mode 100644 index 0000000000000000000000000000000000000000..ba00160b6d78c1e55cc2e053cd5285344e0179fb --- /dev/null +++ b/tensorflow/compiler/tf2xla/graph_compiler.h @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_ +#define TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_ + +#include "tensorflow/compiler/tf2xla/xla_compilation_device.h" +#include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/core/common_runtime/device.h" +#include "tensorflow/core/common_runtime/device_mgr.h" +#include "tensorflow/core/common_runtime/function.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/notification.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/public/version.h" + +namespace tensorflow { + +// GraphCompiler compiles the graph in topological order in the current +// thread. It also resolves the nondeterminism in the graph by enforcing a +// total order on all inputs to a node. This abstraction helps us create the +// same XLA computation given two structurally equivalent TensorFlow graphs. +// If a function call is visited during the graph traversal, it is then +// compiled through the xla_context into a computation and a `Call` operation +// is inserted to call into that computation. +// +// Note: GraphCompiler was created to remove our dependency to TF Executor in +// the history. There are still some todos so that we can completely decouple +// from Executor. +// +// TODO(yunxing): Remove usage of XlaCompilationDevice. +// +// TODO(yunxing): Remove the hack that wraps XlaExpression within a tensor now +// that we don't use TF Executor to pass around a tensor. +// +// TODO(yunxing): Make XlaOpkernel not a subclass of OpKernel so that it can +// handle a XlaExpression directly instead of a Tensor. This may require our own +// op registration infrastructure instead of FunctionLibraryRuntime. +class GraphCompiler { + public: + GraphCompiler(XlaContext* xla_context, XlaCompilationDevice* device, + Graph* graph, FunctionLibraryRuntime* flib, + ScopedStepContainer* step_container) + : xla_context_(xla_context), + device_(device), + graph_(graph), + flib_(flib), + step_container_(step_container) {} + + // Compiles the graph. The results are written in `xla_context` that is passed + // into the compiler. + Status Compile(); + + private: + // Partially sets params. This partially set params can be reused + // across multple nodes visit. + void PartiallySetupParams(OpKernelContext::Params* params); + + // Tests if a node is a functional node. A functional node represents a + // defined computation and should be compiled using `compiler_`. + bool IsFunctional(Node* n); + + // Compiles a functional node and writes result to OpkernelContext. A + // functional node represents a defined computation and should be compiled + // using `compiler_`. + Status CompileFunctionalNode(Node* n, OpKernelContext* op_context); + + XlaContext* xla_context_; + XlaCompilationDevice* device_; + Graph* graph_; + FunctionLibraryRuntime* flib_; + ScopedStepContainer* step_container_; + // A buffer to hold tensor inputs to a node, this is reused across the graph + // traversal. + gtl::InlinedVector tensor_inputs_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_GRAPH_COMPILER_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD index 4cff41a516901e9f0b0c58c9ce13522c0916e3cf..13d06177f0fe2eb1a71e5cf684d74d87e263cfc5 100644 --- a/tensorflow/compiler/tf2xla/kernels/BUILD +++ b/tensorflow/compiler/tf2xla/kernels/BUILD @@ -5,7 +5,6 @@ package( ) load("//tensorflow:tensorflow.bzl", "tf_kernel_library") -load("//tensorflow/compiler/xla:xla.bzl", "export_dynamic_linkopts") tf_kernel_library( name = "xla_ops", @@ -19,18 +18,23 @@ tf_kernel_library( "bias_ops.cc", "binary_ops.cc", "cast_op.cc", + "categorical_op.cc", "concat_op.cc", "const_op.cc", "conv_ops.cc", "cross_op.cc", "cwise_ops.cc", + "cwise_ops.h", + "depthtospace_op.cc", "diag_op.cc", "dynamic_stitch_op.cc", "elu_op.cc", "fill_op.cc", "function_ops.cc", "gather_op.cc", + "gather_op_helpers.h", "identity_op.cc", + "index_ops.cc", "l2loss_op.cc", "lrn_ops.cc", "matmul_op.cc", @@ -43,6 +47,7 @@ tf_kernel_library( "quantize_and_dequantize_op.cc", "random_ops.cc", "reduction_ops.cc", + "reduction_ops.h", "reduction_ops_common.cc", "relu_op.cc", "reshape_op.cc", @@ -56,6 +61,7 @@ tf_kernel_library( "slice_op.cc", "softmax_op.cc", "spacetobatch_op.cc", + "spacetodepth_op.cc", "split_op.cc", "stack_ops.cc", "strided_slice_op.cc", @@ -68,10 +74,8 @@ tf_kernel_library( "variable_ops.cc", ], hdrs = [ - "cwise_ops.h", "gather_op.h", - "gather_op_helpers.h", - "reduction_ops.h", + "index_ops.h", ], deps = [ ":while_op", @@ -79,23 +83,30 @@ tf_kernel_library( "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/ops:sendrecv_ops", "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/core:all_kernels", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:concat_lib", + "//tensorflow/core/kernels:constant_op", + "//tensorflow/core/kernels:control_flow_ops", "//tensorflow/core/kernels:conv_ops", "//tensorflow/core/kernels:cwise_op", "//tensorflow/core/kernels:no_op", "//tensorflow/core/kernels:ops_util", "//tensorflow/core/kernels:pooling_ops", + "//tensorflow/core/kernels:random_op", + "//tensorflow/core/kernels:resource_variable_ops", "//tensorflow/core/kernels:sendrecv_ops", + "//tensorflow/core/kernels:sparse_to_dense_op", + "//tensorflow/core/kernels:stack_ops", + "//tensorflow/core/kernels:training_ops", "//tensorflow/core/kernels:transpose_op", ], ) @@ -118,17 +129,10 @@ tf_kernel_library( # Kernels that only work on CPU, because they use XLA custom calls. # Only link this when using the CPU backend for XLA. -# -# TODO(cwhipkey): move into xla_ops when ops can be registered for -# CPU compilation only (b/31363654). tf_kernel_library( name = "xla_cpu_only_ops", - srcs = [ - "index_ops.cc", - ], + srcs = ["index_ops_cpu.cc"], deps = [ - ":gather_op_kernel_float_int32", - ":gather_op_kernel_float_int64", ":index_ops_kernel_argmax_float_1d", ":index_ops_kernel_argmax_float_2d", "//tensorflow/compiler/tf2xla:common", @@ -137,46 +141,19 @@ tf_kernel_library( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client/lib:arithmetic", - "//tensorflow/core:all_kernels", "//tensorflow/core:framework", "//tensorflow/core:lib", + "//tensorflow/core/kernels:argmax_op", "//tensorflow/core/kernels:bounds_check", ], ) -cc_library( - name = "gather_op_kernel_float_int32", - srcs = ["gather_op_kernel_float_int32.cc"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", - "//tensorflow/core:framework_lite", - "//tensorflow/core/kernels:bounds_check", - "//tensorflow/core/kernels:gather_functor_hdr", - "//third_party/eigen3", - ], - alwayslink = 1, -) - -cc_library( - name = "gather_op_kernel_float_int64", - srcs = ["gather_op_kernel_float_int64.cc"], - visibility = ["//visibility:public"], - deps = [ - "//tensorflow/compiler/tf2xla:xla_local_runtime_context", - "//tensorflow/core:framework_lite", - "//tensorflow/core/kernels:bounds_check", - "//tensorflow/core/kernels:gather_functor_hdr", - "//third_party/eigen3", - ], - alwayslink = 1, -) - cc_library( name = "index_ops_kernel_argmax_float_1d", srcs = ["index_ops_kernel_argmax_float_1d.cc"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], @@ -188,6 +165,7 @@ cc_library( srcs = ["index_ops_kernel_argmax_float_2d.cc"], visibility = ["//visibility:public"], deps = [ + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/core:framework_lite", "//third_party/eigen3", ], diff --git a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc index 16b778bca439b9236498945f132e8095baeb71c1..73ccc151c1d6bdf70105badd962903297f090abe 100644 --- a/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batch_matmul_op.cc @@ -77,7 +77,13 @@ class BatchMatMulOp : public XlaOpKernel { xla::ComputationBuilder* builder = ctx->builder(); xla::ComputationDataHandle x_handle = ctx->Input(0); + if (BaseType(input_type(0)) == DT_COMPLEX64 && adj_x_) { + x_handle = builder->Conj(x_handle); + } xla::ComputationDataHandle y_handle = ctx->Input(1); + if (BaseType(input_type(1)) == DT_COMPLEX64 && adj_y_) { + y_handle = builder->Conj(y_handle); + } // Reshape input tensors into 3D tensors by flattening the batch // dimensions. This makes it easier to unroll the batch dimension. diff --git a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc index bc2cd31230dfe9ca35540341d225dcb768fa34f6..bb031b8c471e08ba90c554e309b850a26c3edae0 100644 --- a/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/bcast_ops.cc @@ -27,6 +27,46 @@ limitations under the License. namespace tensorflow { namespace { +// Given shapes of two tensors, computes the broadcast shape. +class BCastArgsOp : public XlaOpKernel { + public: + explicit BCastArgsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->MatchSignature({DT_INT32, DT_INT32}, {DT_INT32})); + } + + void Compile(XlaOpKernelContext* ctx) override { + OP_REQUIRES( + ctx, ctx->num_inputs() == 2, + errors::Unimplemented("Broadcast for n-ary operations (n > 2)")); + gtl::InlinedVector shapes; + for (int i = 0; i < ctx->num_inputs(); ++i) { + const TensorShape in_shape = ctx->InputShape(i); + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(in_shape), + errors::InvalidArgument("In[", i, "] must be a vector.", + in_shape.DebugString())); + std::vector shape; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(i, &shape)); + shapes.push_back(BCast::Vec(shape.begin(), shape.end())); + } + BCast bcast(shapes[0], shapes[1]); + OP_REQUIRES(ctx, bcast.IsValid(), + errors::InvalidArgument( + "Incompatible shapes: [", str_util::Join(shapes[0], ","), + "] vs. [", str_util::Join(shapes[1], ","), "]")); + + const int64 len = bcast.output_shape().size(); + Tensor output(DT_INT32, TensorShape({len})); + for (int64 i = 0; i < len; ++i) { + output.flat()(i) = static_cast(bcast.output_shape()[i]); + } + ctx->SetConstantOutput(0, output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(BCastArgsOp); +}; +REGISTER_XLA_OP(Name("BroadcastArgs"), BCastArgsOp); + // Given shapes of two tensors, computes the reduction indices for the // gradient computation. // diff --git a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc index 58538b45137b26ed5aa296eb6c1077e88aea72b9..1de91924326464338352b1ac9edf77141f25ad35 100644 --- a/tensorflow/compiler/tf2xla/kernels/binary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/binary_ops.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Native XLA implementations of simple unary Ops +// Native XLA implementations of simple binary Ops #include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { namespace { @@ -50,6 +51,9 @@ XLA_MAKE_BINARY(Sub, b->Sub(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Mul, b->Mul(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Div, b->Div(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Atan2, b->Atan2(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(Complex, b->Complex(lhs, rhs, extend_dimensions)); + // Implementation of FloorDiv. Pseudo-code: // if ((x < 0) != (y < 0)) { // T abs_x = std::abs(x); @@ -96,8 +100,17 @@ static xla::ComputationDataHandle FloorModImpl(xla::ComputationBuilder* b, XLA_MAKE_BINARY(FloorMod, FloorModImpl(b, input_type(0), lhs, rhs, broadcast_helper)); -XLA_MAKE_BINARY(LogicalAnd, b->LogicalAnd(lhs, rhs, extend_dimensions)); -XLA_MAKE_BINARY(LogicalOr, b->LogicalOr(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseAnd, b->And(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(BitwiseOr, b->Or(lhs, rhs, extend_dimensions)); + +XLA_MAKE_BINARY(LeftShift, b->ShiftLeft(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(RightShift, + (DataTypeIsUnsigned(ctx->input_type(0)) + ? b->ShiftRightLogical(lhs, rhs, extend_dimensions) + : b->ShiftRightArithmetic(lhs, rhs, extend_dimensions))); + +XLA_MAKE_BINARY(LogicalAnd, b->And(lhs, rhs, extend_dimensions)); +XLA_MAKE_BINARY(LogicalOr, b->Or(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Mod, b->Rem(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Maximum, b->Max(lhs, rhs, extend_dimensions)); XLA_MAKE_BINARY(Minimum, b->Min(lhs, rhs, extend_dimensions)); @@ -162,8 +175,12 @@ class ApproximateEqualOp : public XlaOpKernel { // Computes the max of the scalar input x and 0. void Compile(XlaOpKernelContext* ctx) override { xla::ComputationBuilder* b = ctx->builder(); - auto result = b->Lt(b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))), - XlaHelpers::FloatLiteral(b, input_type(0), tolerance_)); + auto abs = b->Abs(b->Sub(ctx->Input(0), ctx->Input(1))); + auto abs_shape = b->GetShape(abs); + OP_REQUIRES_OK(ctx, abs_shape.status()); + auto abs_type = abs_shape.ValueOrDie()->element_type(); + auto result = b->Lt( + abs, b->ConvertElementType(b->ConstantR0(tolerance_), abs_type)); ctx->SetOutput(0, result); } diff --git a/tensorflow/compiler/tf2xla/kernels/cast_op.cc b/tensorflow/compiler/tf2xla/kernels/cast_op.cc index 2331520230176fce7646d89140851fe37aee5fda..43a6a747c6bcc441f33f276fde4a66f367d99731 100644 --- a/tensorflow/compiler/tf2xla/kernels/cast_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/cast_op.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" namespace tensorflow { @@ -40,6 +41,11 @@ class CastOp : public XlaOpKernel { output = input; } else if (dst_dtype_ == DT_BOOL) { output = builder->Ne(input, XlaHelpers::Zero(builder, src_dtype_)); + } else if (xla::primitive_util::IsComplexType(src_type_) && + !xla::primitive_util::IsComplexType(dst_type_)) { + // As in cast_op.h, we replicate the numpy behavior of truncating the + // imaginary part. + output = builder->ConvertElementType(builder->Real(input), dst_type_); } else { output = builder->ConvertElementType(input, dst_type_); } diff --git a/tensorflow/compiler/tf2xla/kernels/categorical_op.cc b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..592f3ecc3ce2abf33ddffe8b0e59c4e12e73e956 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/categorical_op.cc @@ -0,0 +1,98 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// XLA implementations of Categorical op. + +#include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" + +namespace tensorflow { +namespace { + +class CategoricalOp : public XlaOpKernel { + public: + explicit CategoricalOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + // Get the logits + const xla::ComputationDataHandle& logits = ctx->Input(0); + TensorShape logits_shape = ctx->InputShape(0); + int64 num_samples; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &num_samples)); + OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_shape), + errors::InvalidArgument("logits should be a matrix, got shape ", + logits_shape.DebugString())); + OP_REQUIRES(ctx, num_samples >= 0, + errors::InvalidArgument( + "num_samples should be nonnegative, got ", num_samples)); + + for (int i = 0; i < 2; i++) { + const int64 dim = logits_shape.dim_size(i); + OP_REQUIRES( + ctx, static_cast(dim) == dim, + errors::InvalidArgument("logits.shape = ", logits_shape.DebugString(), + " too large for int")); + } + + const int64 batch_size = logits_shape.dim_size(0); + const int64 num_classes = logits_shape.dim_size(1); + + xla::ComputationBuilder* builder = ctx->builder(); + + std::array uniform_shape_array = { + {batch_size, num_samples, num_classes}}; + xla::PrimitiveType uniform_xla_type; + OP_REQUIRES_OK(ctx, + DataTypeToPrimitiveType(input_type(0), &uniform_xla_type)); + xla::Shape uniform_shape = + xla::ShapeUtil::MakeShape(uniform_xla_type, uniform_shape_array); + auto uniforms = builder->RngUniform( + XlaHelpers::Zero(builder, input_type(0)), + XlaHelpers::One(builder, input_type(0)), uniform_shape); + + // Use Gumbel softmax trick to generate categorical samples. + // See: + // https://hips.seas.harvard.edu/blog/2013/04/06/the-gumbel-max-trick-for-discrete-distributions/ + // TODO(b/68769470): Switch to using a cumulative sum approach. + auto softmax_entries = + builder->Sub(logits, builder->Log(builder->Neg(builder->Log(uniforms))), + /*broadcast_dimensions=*/{0, 2}); + + TensorShape softmax_shape(uniform_shape_array); + xla::ComputationDataHandle argmax; + OP_REQUIRES_OK( + ctx, + XlaHelpers::ArgMax(builder, ctx, softmax_entries, softmax_shape, + input_type(0), output_type(0), /*axis=*/2, &argmax)); + + ctx->SetOutput(0, argmax); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(CategoricalOp); +}; + +// TODO(b/68769717): Rename this sampler to Categorical. +REGISTER_XLA_OP(Name("Multinomial"), CategoricalOp); + +} // anonymous namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 0091b66d28ad62fcd5c0f3b09e90fed8347bb661..885f716afafca7ba23770e38f6693eed1ba50982 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -179,8 +179,10 @@ class ConvOp : public XlaOpKernel { xla::ConvolutionDimensionNumbers dims; std::vector window_strides; - dims.set_batch_dimension(GetTensorBatchDimIndex(num_dims(), data_format_)); - dims.set_feature_dimension(feature_dim); + dims.set_input_batch_dimension(batch_dim); + dims.set_output_batch_dimension(batch_dim); + dims.set_input_feature_dimension(feature_dim); + dims.set_output_feature_dimension(feature_dim); for (int i = 0; i < num_spatial_dims_; ++i) { int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i); dims.add_spatial_dimensions(input_dim); @@ -285,8 +287,10 @@ class ConvBackpropInputOp : public XlaOpKernel { // comment at the top of conv_grad_ops.h for details. xla::ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(batch_dim); - dnums.set_feature_dimension(feature_dim); + dnums.set_input_batch_dimension(batch_dim); + dnums.set_output_batch_dimension(batch_dim); + dnums.set_input_feature_dimension(feature_dim); + dnums.set_output_feature_dimension(feature_dim); // TF filter shape is [ H, W, ..., inC, outC ] // Transpose the input and output features for computing the gradient. @@ -419,8 +423,10 @@ class ConvBackpropFilterOp : public XlaOpKernel { // Each spatial entry has size in_depth * batch // Swap n_dim and c_dim in the activations. - dnums.set_batch_dimension(c_dim); - dnums.set_feature_dimension(n_dim); + dnums.set_input_batch_dimension(c_dim); + dnums.set_output_batch_dimension(c_dim); + dnums.set_input_feature_dimension(n_dim); + dnums.set_output_feature_dimension(n_dim); // The gradients become the RHS of the convolution. // The gradients have shape [batch, out_rows, out_cols, ..., out_depth] diff --git a/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4ea65ea89e348cb77412efb0c5c0fcb1a9f33f3 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/depthtospace_op.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +class DepthToSpaceOp : public XlaOpKernel { + public: + explicit DepthToSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); + OP_REQUIRES( + ctx, block_size_ > 1, + errors::InvalidArgument("Block size should be > 1: ", block_size_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_tensor_shape = ctx->InputShape(0); + // The input is presumed to be [batch, height, width, depth] + int input_rank = input_tensor_shape.dims(); + static const int kRequiredDims = 4; + OP_REQUIRES(ctx, kRequiredDims == input_rank, + errors::InvalidArgument("Input rank should be: ", kRequiredDims, + " instead of: ", input_rank)); + const gtl::InlinedVector input_shape = + input_tensor_shape.dim_sizes(); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle input = ctx->Input(0); + + // 1. Reshape `input` to `reshaped` of shape: + // + // [batch, + // input_shape[1], + // input_shape[2], + // block_size_, + // block_size_, + // depth / (block_size_ * block_size_)] + OP_REQUIRES(ctx, input_shape[3] % (block_size_ * block_size_) == 0, + errors::InvalidArgument( + "Input depth dimension (", input_shape[3], + ") is not divisible by square of the block size (", + block_size_, ")")); + xla::ComputationDataHandle reshaped = b->Reshape( + input, {input_shape[0], input_shape[1], input_shape[2], block_size_, + block_size_, input_shape[3] / (block_size_ * block_size_)}); + + // 2. Permute dimensions of `reshaped` to produce + // `permuted_reshaped` of shape: + // + // [batch, + // input_shape[1], + // block_size_, + // input_shape[2], + // block_size_, + // depth / (block_size_ * block_size_)] + xla::ComputationDataHandle permuted_reshaped = + b->Transpose(reshaped, {0, 1, 3, 2, 4, 5}); + + // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the + // batch dimension, producing an output tensor of shape: + // + // [batch, + // input_shape[1] * block_size_, + // input_shape[2] * block_size_, + // depth / (block_size_ * block_size_)] + // + xla::ComputationDataHandle output = b->Reshape( + permuted_reshaped, {input_shape[0], input_shape[1] * block_size_, + input_shape[2] * block_size_, + input_shape[3] / (block_size_ * block_size_)}); + + ctx->SetOutput(0, output); + } + + private: + int block_size_; +}; +REGISTER_XLA_OP(Name("DepthToSpace"), DepthToSpaceOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc index dde7898015e73190c96fa6effddfd3fc892264ea..7349dcb987cd88c423570889c0502d1a0bd12c52 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_stitch_op.cc @@ -199,6 +199,7 @@ class DynamicStitchOp : public XlaOpKernel { }; REGISTER_XLA_OP(Name("DynamicStitch"), DynamicStitchOp); +REGISTER_XLA_OP(Name("ParallelDynamicStitch"), DynamicStitchOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 17de565f2cb5db4a9ea1f0272c2deebcd052138e..e420f21ca33fe7de9b33f404ce04eae62d9c041e 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -29,18 +29,22 @@ namespace tensorflow { xla::ComputationDataHandle XlaComputeGatherDynamicSlice( XlaOpKernelContext* context, const xla::ComputationDataHandle& input, const TensorShape& input_shape, const xla::ComputationDataHandle& indices, - const TensorShape& indices_shape, DataType dtype, - xla::ComputationBuilder* builder) { + const TensorShape& indices_shape, int64 axis, DataType dtype, + DataType index_type, xla::ComputationBuilder* builder) { // Although the indices Tensor is flattened into rank 1 during the lookup, // and each scalar entry is used as an index into the first dimension of the - // input, the output is returned with shape indices.shape + input.shape[1:] + // input, the output is returned with shape: + // input.shape[:axis] + indices.shape + input.shape[axis+1:] const int num_indices = indices_shape.num_elements(); - TensorShape input_shape_1(input_shape); - input_shape_1.RemoveDim(0); + TensorShape input_shape_pre_axis(input_shape); + input_shape_pre_axis.RemoveDimRange(axis, input_shape.dims()); + TensorShape input_shape_post_axis(input_shape); + input_shape_post_axis.RemoveDimRange(0, axis + 1); - // Each slice of the input tensor is [1, ] + // Each slice of the input tensor has shape: + // [, 1, ] TensorShape slice_shape(input_shape); - slice_shape.set_dim(0, 1); + slice_shape.set_dim(axis, 1); // TODO(b/37575001) The tensor in which we construct the output during // the loop must have rank >= 3 as a workaround for lowering issues. @@ -49,19 +53,23 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( TensorShape loop_out_shape; for (int64 k = 0; k < extra_dims; ++k) loop_out_shape.AddDim(1); + loop_out_shape.AppendShape(input_shape_pre_axis); loop_out_shape.AddDim(num_indices); - loop_out_shape.AppendShape(input_shape_1); + loop_out_shape.AppendShape(input_shape_post_axis); // Slices are reshaped into the rank >= 3 shape of the loop carried output. TensorShape loop_out_slice_shape; for (int64 k = 0; k < extra_dims; ++k) loop_out_slice_shape.AddDim(1); + loop_out_slice_shape.AppendShape(input_shape_pre_axis); loop_out_slice_shape.AddDim(1); - loop_out_slice_shape.AppendShape(input_shape_1); + loop_out_slice_shape.AppendShape(input_shape_post_axis); // Finally, the loop-carried rank >= 3 output is reshaped to the op's // specified result shape. - TensorShape out_shape(indices_shape); - out_shape.AppendShape(input_shape_1); + TensorShape out_shape; + out_shape.AppendShape(input_shape_pre_axis); + out_shape.AppendShape(indices_shape); + out_shape.AppendShape(input_shape_post_axis); // Degenerate case: empty indices. if (num_indices == 0) { @@ -72,22 +80,23 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( // Specify the shape of the loop-carried Tensor tuple. xla::PrimitiveType ptype; TF_CHECK_OK(DataTypeToPrimitiveType(dtype, &ptype)); + xla::PrimitiveType idxtype; + TF_CHECK_OK(DataTypeToPrimitiveType(index_type, &idxtype)); std::vector tuple_shapes( {// The iteration counter i is a scalar, incremented each iteration. - xla::ShapeUtil::MakeShape(xla::S32, {}), + xla::ShapeUtil::MakeShape(idxtype, {}), // The input array has shape input_shape. Loop invariant. xla::ShapeUtil::MakeShape(ptype, input_shape.dim_sizes()), // The gather indices are reshaped to rank 1. Loop invariant. - xla::ShapeUtil::MakeShape(xla::S32, {num_indices}), + xla::ShapeUtil::MakeShape(idxtype, {num_indices}), // The output array is rank >= 3, and is updated on each loop iteration. xla::ShapeUtil::MakeShape(ptype, loop_out_shape.dim_sizes())}); xla::Shape tuple_shape = xla::ShapeUtil::MakeTupleShape(tuple_shapes); // Construct the initial values of the loop-carried Tensors. - auto init_i = builder->ConstantR0(0); - auto init_out = - builder->Broadcast(builder->ConstantLiteral(xla::Literal::Zero(ptype)), - loop_out_shape.dim_sizes()); + auto init_i = XlaHelpers::Zero(builder, index_type); + auto init_out = builder->Broadcast(XlaHelpers::Zero(builder, dtype), + loop_out_shape.dim_sizes()); // Flatten the indices into 1-D for ease of iteration. auto indices_1d = builder->Reshape(indices, {num_indices}); auto init = builder->Tuple({init_i, input, indices_1d, init_out}); @@ -97,7 +106,7 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( "GatherWhileCond"); condb.Lt(condb.GetTupleElement( condb.Parameter(0, tuple_shape, "GatherWhileTuple"), 0), - condb.ConstantR0(num_indices)); + XlaHelpers::IntegerLiteral(&condb, index_type, num_indices)); auto cond_status = condb.Build(); auto cond = cond_status.ConsumeValueOrDie(); @@ -118,24 +127,27 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( // Slice from the input array. auto index = bodyb.DynamicSlice(indices, bodyb.Reshape(i, {1}), {1}); - auto start_indices = - bodyb.Pad(bodyb.Reshape(index, {1}), bodyb.ConstantR0(0), - xla::MakeEdgePaddingConfig({{0, input_shape.dims() - 1}})); + auto start_indices = bodyb.Pad( + bodyb.Reshape(index, {1}), XlaHelpers::Zero(&bodyb, index_type), + xla::MakeEdgePaddingConfig( + {{input_shape_pre_axis.dims(), input_shape_post_axis.dims()}})); auto slice_i = bodyb.Reshape( bodyb.DynamicSlice(input, start_indices, slice_shape.dim_sizes()), loop_out_slice_shape.dim_sizes()); // Construct the index into the R3+ output Tensor 0, ..., , 0, ... std::vector out_index_vals( - loop_out_shape.dims(), bodyb.ConstantR1({0})); - out_index_vals[extra_dims] = bodyb.Reshape(i, {1}); + loop_out_shape.dims(), + bodyb.Reshape(XlaHelpers::Zero(&bodyb, index_type), {1})); + out_index_vals[input_shape_pre_axis.dims() + extra_dims] = + bodyb.Reshape(i, {1}); auto out_index = bodyb.ConcatInDim(out_index_vals, 0); // Update the output Tensor auto updated_output = bodyb.DynamicUpdateSlice(output, slice_i, out_index); - bodyb.Tuple({bodyb.Add(i, bodyb.ConstantR0(1)), input, indices, - updated_output}); + bodyb.Tuple({bodyb.Add(i, XlaHelpers::One(&bodyb, index_type)), input, + indices, updated_output}); } auto body_status = bodyb.Build(); auto body = body_status.ConsumeValueOrDie(); @@ -146,124 +158,6 @@ xla::ComputationDataHandle XlaComputeGatherDynamicSlice( return builder->Reshape(gather_output, out_shape.dim_sizes()); } -namespace { - -class GatherOpCustomCall : public XlaOpKernel { - public: - explicit GatherOpCustomCall(OpKernelConstruction* context) - : XlaOpKernel(context) {} - - void Compile(XlaOpKernelContext* context) override { - const TensorShape params_shape = context->InputShape(0); - const auto params_dims = params_shape.dims(); - const TensorShape indices_shape = context->InputShape(1); - OP_REQUIRES( - context, TensorShapeUtils::IsVectorOrHigher(params_shape), - errors::InvalidArgument("params must be at least 1 dimensional")); - - DataType index_type = input_type(1); - OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, - errors::InvalidArgument("index must be int32 or int64")); - - // GatherV2 added an axis argument. We support both Gather and GatherV2 in - // this kernel by defaulting axis to 0 if there are 2 inputs. - int64 axis = 0; - if (context->num_inputs() == 3) { - const TensorShape axis_shape = context->InputShape(2); - OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape), - errors::InvalidArgument("axis must be scalar")); - DataType axis_type = input_type(2); - OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64, - errors::InvalidArgument("axis must be int32 or int64")); - - xla::Literal literal; - OP_REQUIRES_OK(context, context->ConstantInput(2, &literal)); - int64 axis_input = axis_type == DT_INT32 ? literal.Get({}) - : literal.Get({}); - axis = axis_input < 0 ? axis_input + params_dims : axis_input; - OP_REQUIRES(context, 0 <= axis && axis < params_dims, - errors::InvalidArgument("Expected axis in the range [", - -params_dims, ", ", params_dims, - "), but got ", axis_input)); - } - - // Check that we have enough index space. - const int64 limit = index_type == DT_INT32 - ? std::numeric_limits::max() - : std::numeric_limits::max(); - OP_REQUIRES(context, params_shape.dim_size(axis) <= limit, - errors::InvalidArgument( - "params.shape[", axis, "] too large for ", - DataTypeString(index_type), - " indexing: ", params_shape.dim_size(axis), " > ", limit)); - - // The result shape is params.shape[0:axis] + indices.shape + - // params.shape[axis + 1:]. - TensorShape result_shape; - int64 outer_size = 1; - int64 inner_size = 1; - for (int i = 0; i < axis; i++) { - result_shape.AddDim(params_shape.dim_size(i)); - outer_size *= params_shape.dim_size(i); - } - result_shape.AppendShape(indices_shape); - for (int i = axis + 1; i < params_dims; i++) { - result_shape.AddDim(params_shape.dim_size(i)); - inner_size *= params_shape.dim_size(i); - } - - XlaContext& tc = XlaContext::Get(context); - OP_REQUIRES( - context, tc.allow_cpu_custom_calls(), - errors::InvalidArgument("Gather op requires CustomCall on CPU")); - - xla::ComputationBuilder& b = *context->builder(); - - // Call gather_xla_float_kernel (from gather_op_kernel_float.cc). - // XLA passes to the function, so it is not included here. - std::vector args; - args.push_back(tc.GetOrCreateRuntimeContextParameter()); - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR0(indices_shape.num_elements()))); - args.push_back( - b.ConstantLiteral(*xla::Literal::CreateR0(outer_size))); - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR0(params_shape.dim_size(axis)))); - args.push_back( - b.ConstantLiteral(*xla::Literal::CreateR0(inner_size))); - args.push_back(context->Input(0)); - args.push_back(context->Input(1)); - - xla::Shape xla_out_shape; - OP_REQUIRES_OK( - context, TensorShapeToXLAShape(DT_FLOAT, result_shape, &xla_out_shape)); - - // Call the custom code with args: - xla::ComputationDataHandle output; - if (index_type == DT_INT32) { - output = b.CustomCall("gather_float_int32_xla_impl", args, xla_out_shape); - } else { - output = b.CustomCall("gather_float_int64_xla_impl", args, xla_out_shape); - } - - context->SetOutput(0, output); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(GatherOpCustomCall); -}; - -REGISTER_XLA_OP(Name("Gather") - .TypeConstraint("Tparams", DT_FLOAT) - .Device(DEVICE_CPU_XLA_JIT), - GatherOpCustomCall); -REGISTER_XLA_OP(Name("GatherV2") - .TypeConstraint("Tparams", DT_FLOAT) - .Device(DEVICE_CPU_XLA_JIT), - GatherOpCustomCall); - -} // namespace - GatherOpDynamicSlice::GatherOpDynamicSlice(OpKernelConstruction* context) : XlaOpKernel(context) {} @@ -273,14 +167,37 @@ void GatherOpDynamicSlice::Compile(XlaOpKernelContext* context) { auto input_shape = context->InputShape(0); auto indices = context->Input(1); auto indices_shape = context->InputShape(1); + int64 axis = 0; + if (context->num_inputs() == 3) { + const TensorShape axis_shape = context->InputShape(2); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(axis_shape), + errors::InvalidArgument("axis must be scalar")); + DataType axis_type = input_type(2); + OP_REQUIRES(context, axis_type == DT_INT32 || axis_type == DT_INT64, + errors::InvalidArgument("axis must be int32 or int64")); + + OP_REQUIRES_OK(context, context->ConstantInputAsIntScalar(2, &axis)); + const auto params_dims = input_shape.dims(); + if (axis < 0) { + axis += params_dims; + } + OP_REQUIRES( + context, 0 <= axis && axis < params_dims, + errors::InvalidArgument("Expected axis in the range [", -params_dims, + ", ", params_dims, "), but got ", axis)); + } + + DataType index_type = input_type(1); + OP_REQUIRES(context, index_type == DT_INT32 || index_type == DT_INT64, + errors::InvalidArgument("indices must be int32 or int64")); + xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( - context, input, input_shape, indices, indices_shape, DT_FLOAT, builder); + context, input, input_shape, indices, indices_shape, axis, input_type(0), + index_type, builder); context->SetOutput(0, gather); } -REGISTER_XLA_OP(Name("Gather") - .TypeConstraint("Tparams", DT_FLOAT) - .Device(DEVICE_GPU_XLA_JIT), - GatherOpDynamicSlice); +REGISTER_XLA_OP(Name("Gather"), GatherOpDynamicSlice); +REGISTER_XLA_OP(Name("GatherV2"), GatherOpDynamicSlice); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h index 4e8d505e12ff7f377de44e1c077a34d6311fd662..2c80395c56d73adad7dc1679ba6423fbe103605a 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h +++ b/tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h @@ -28,11 +28,13 @@ namespace tensorflow { // Adds to builder an XLA computation that performs a gather on input (of // shape input_shape) keyed on indices (of shape indices_shape). +// +// index_type must be must be DT_INT32 or DT_INT64. xla::ComputationDataHandle XlaComputeGatherDynamicSlice( XlaOpKernelContext* ctx, const xla::ComputationDataHandle& input, const TensorShape& input_shape, const xla::ComputationDataHandle& indices, - const TensorShape& indices_shape, DataType dtype, - xla::ComputationBuilder* builder); + const TensorShape& indices_shape, int64 axis, DataType dtype, + DataType index_type, xla::ComputationBuilder* builder); } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc deleted file mode 100644 index 33b1b087d00d8263cd80f7d5d879401e4ed6c0fb..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int32.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/gather_functor.h" -#include "tensorflow/core/platform/dynamic_annotations.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { - -EIGEN_STRONG_INLINE void gather_float_int32_xla_impl(float* out, void** data) { - // data is managed by the JIT code so msan can't tell it's initialized. - TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 7 * sizeof(void*)); - - int64 indices_size = *static_cast(data[1]); - int64 params_x = *static_cast(data[2]); - int64 params_y = *static_cast(data[3]); - int64 params_z = *static_cast(data[4]); - - float* in = static_cast(data[5]); - - int32* indices = static_cast(data[6]); - Eigen::DSizes in_eig_sizes; - in_eig_sizes[0] = params_x; - in_eig_sizes[1] = params_y; - in_eig_sizes[2] = params_z; - tensorflow::TTypes::ConstTensor in_eig(in, in_eig_sizes); - - Eigen::DSizes indices_eig_sizes; - indices_eig_sizes[0] = indices_size; - tensorflow::TTypes::ConstFlat indices_eig(indices, indices_eig_sizes); - - Eigen::DSizes out_eig_sizes; - out_eig_sizes[0] = params_x; - out_eig_sizes[1] = indices_size; - out_eig_sizes[2] = params_z; - tensorflow::TTypes::Tensor out_eig(out, out_eig_sizes); - - tensorflow::functor::GatherFunctorCPU f; - const int64 bad_i = f(in_eig, indices_eig, out_eig); - if (bad_i != -1) { - tensorflow::XlaLocalRuntimeContext* runtime_context = - static_cast(data[0]); - runtime_context->error = true; - runtime_context->error_msg = "Invalid index for gather"; - for (int i = 0; i < out_eig.size(); ++i) out[i] = 0; - } -} - -} // namespace tensorflow - -// Implements gather on CPU. This is called by an XLA custom call, set up by -// gather_op.cc. -extern "C" void TF_EXPORT gather_float_int32_xla_impl(float* out, void** data) { - tensorflow::gather_float_int32_xla_impl(out, data); -} diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc b/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc deleted file mode 100644 index 5e2d872ce0b28ab479c73ed1fea5f32804c21e22..0000000000000000000000000000000000000000 --- a/tensorflow/compiler/tf2xla/kernels/gather_op_kernel_float_int64.cc +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" -#include "tensorflow/core/framework/tensor_types.h" -#include "tensorflow/core/kernels/gather_functor.h" -#include "tensorflow/core/platform/dynamic_annotations.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { - -EIGEN_STRONG_INLINE void gather_float_int64_xla_impl(float* out, void** data) { - // data is managed by the JIT code so msan can't tell it's initialized. - TF_ANNOTATE_MEMORY_IS_INITIALIZED(data, 7 * sizeof(void*)); - - int64 indices_size = *static_cast(data[1]); - int64 params_x = *static_cast(data[2]); - int64 params_y = *static_cast(data[3]); - int64 params_z = *static_cast(data[4]); - - float* in = static_cast(data[5]); - - int64* indices = static_cast(data[6]); - Eigen::DSizes in_eig_sizes; - in_eig_sizes[0] = params_x; - in_eig_sizes[1] = params_y; - in_eig_sizes[2] = params_z; - tensorflow::TTypes::ConstTensor in_eig(in, in_eig_sizes); - - Eigen::DSizes indices_eig_sizes; - indices_eig_sizes[0] = indices_size; - tensorflow::TTypes::ConstFlat indices_eig(indices, indices_eig_sizes); - - Eigen::DSizes out_eig_sizes; - out_eig_sizes[0] = params_x; - out_eig_sizes[1] = indices_size; - out_eig_sizes[2] = params_z; - tensorflow::TTypes::Tensor out_eig(out, out_eig_sizes); - - tensorflow::functor::GatherFunctorCPU f; - const int64 bad_i = f(in_eig, indices_eig, out_eig); - if (bad_i != -1) { - tensorflow::XlaLocalRuntimeContext* runtime_context = - static_cast(data[0]); - runtime_context->error = true; - runtime_context->error_msg = "Invalid index for gather"; - for (int i = 0; i < out_eig.size(); ++i) out[i] = 0; - } -} - -} // namespace tensorflow - -// Implements gather on CPU. This is called by an XLA custom call, set up by -// gather_op.cc. -extern "C" void TF_EXPORT gather_float_int64_xla_impl(float* out, void** data) { - tensorflow::gather_float_int64_xla_impl(out, data); -} diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc index 87d3d64a4e9c07b8effce7583c4189b8c737d433..d2b1f7913ecc9113284827b53de8fb0e5b711322 100644 --- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc @@ -24,7 +24,9 @@ class IdentityOp : public XlaOpKernel { explicit IdentityOp(OpKernelConstruction* context) : XlaOpKernel(context) {} void Compile(XlaOpKernelContext* ctx) override { - ctx->SetOutput(0, ctx->Input(0)); + for (int i = 0; i < ctx->num_inputs(); ++i) { + ctx->SetOutput(i, ctx->Input(i)); + } } private: @@ -35,6 +37,7 @@ class IdentityOp : public XlaOpKernel { // dummy operator using CompilationOnly(). REGISTER_XLA_OP(Name("Identity").CompilationOnly(), IdentityOp); +REGISTER_XLA_OP(Name("IdentityN").CompilationOnly(), IdentityOp); REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp); REGISTER_XLA_OP(Name("StopGradient"), IdentityOp); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.cc b/tensorflow/compiler/tf2xla/kernels/index_ops.cc index 6be66cf66ec19cad33858f36a3239048efce9de3..e0dc1870f2a4934c35163f0cc10196e8fcbed9be 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.cc @@ -15,10 +15,14 @@ limitations under the License. // Native XLA implementations of indexing ops. +#include "tensorflow/compiler/tf2xla/kernels/index_ops.h" + #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/core/framework/kernel_def_builder.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" @@ -27,115 +31,66 @@ limitations under the License. #include "tensorflow/core/kernels/bounds_check.h" namespace tensorflow { +XlaArgMinMaxOp::XlaArgMinMaxOp(OpKernelConstruction* ctx, bool is_min) + : XlaOpKernel(ctx), is_min_(is_min) {} + +void XlaArgMinMaxOp::Compile(XlaOpKernelContext* ctx) { + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape dimension_shape = ctx->InputShape(1); + + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(dimension_shape), + errors::InvalidArgument( + "dim must be a scalar, but received tensor of shape: ", + dimension_shape.DebugString())); + + int64 dim; + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &dim)); + + const int input_dims = input_shape.dims(); + const int axis = dim < 0 ? dim + input_dims : dim; + + OP_REQUIRES( + ctx, axis >= 0 && axis < input_dims, + errors::InvalidArgument("Expected dimension in the range [", -input_dims, + ", ", input_dims, "), but got ", dim)); + const int64 axis_size = input_shape.dim_size(axis); + OP_REQUIRES( + ctx, axis_size > 0, + errors::InvalidArgument("Reduction axis ", dim, " is empty in shape ", + input_shape.DebugString())); + + DataType index_type = output_type(0); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle input = ctx->Input(0); + + xla::ComputationDataHandle output; + if (is_min_) { + OP_REQUIRES_OK(ctx, + XlaHelpers::ArgMin(b, ctx, input, input_shape, input_type(0), + index_type, axis, &output)); + } else { + OP_REQUIRES_OK(ctx, + XlaHelpers::ArgMax(b, ctx, input, input_shape, input_type(0), + index_type, axis, &output)); + } + + ctx->SetOutput(0, output); +} + +XlaArgMaxOp::XlaArgMaxOp(OpKernelConstruction* ctx) + : XlaArgMinMaxOp(ctx, /*is_min=*/false) {} +REGISTER_XLA_OP(Name("ArgMax").Device(DEVICE_GPU_XLA_JIT), XlaArgMaxOp); + namespace { -// The logic below uses a custom-call to implement argmax. -// -// TODO(toddw): We can implement argmax using existing XLA ops. The idea is -// to use SelectAndScatter to create a tensor initialized to 0, where the max -// value along dim is set to 1. Then take the dot-product of that against a -// vector of indices [0,dim_size), which yields the result. As a detail, we -// might need to reshape before and afterwards, since the XLA Dot operator -// only performs the sum of products over dimension 0. -// -// rs = Reshape(input, ...) // reshape so dim is inner-most -// one_max = SelectAndScatter(rs, greater_than, -// {1,1,...,dim_size}, {1,1,...,dim_size}, -// VALID, [1], 0, add) -// indices = [0,1,2,...,dim_size-1] -// max_index = Dot(one_max, indices) -// result = Reshape(max_index, ...) // reshape back to original -// -// Also see b/29507024 for first-class XLA support for indexing ops. - -class ArgMaxOp : public XlaOpKernel { +class XlaArgMinOp : public XlaArgMinMaxOp { public: - explicit ArgMaxOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} - - void Compile(XlaOpKernelContext* ctx) override { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape dimension_shape = ctx->InputShape(1); - OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(dimension_shape), - errors::InvalidArgument( - "dim must be a scalar, but received tensor of shape: ", - dimension_shape.DebugString())); - - // We require that the dimension argument is a constant, since it lets us - // dispatch to a specialized custom-call function without any run-time - // overhead, when compiling ahead-of-time. - // - // TODO(toddw): We could remove this requirement if necessary; we'd also - // need to update const_analysis.cc. However it seems likely that a native - // XLA op would have the same requirement. - xla::Literal literal; - OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); - const int32 dim = literal.Get({}); - OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0")); - OP_REQUIRES( - ctx, dim < input_shape.dims(), - errors::InvalidArgument("dim must be < input rank (", - input_shape.dims(), "), but got: ", dim)); - const int64 dim_size = input_shape.dim_size(dim); - OP_REQUIRES( - ctx, dim_size > 0, - errors::InvalidArgument("Reduction axis ", dim, " is empty in shape: ", - input_shape.DebugString())); - - // The output shape is the input shape contracted along dim. - TensorShape output_shape; - for (int d = 0; d < input_shape.dims() - 1; ++d) { - output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1)); - } - - // For now we use a custom-call, only for the 1d and 2d cases. - OP_REQUIRES(ctx, XlaContext::Get(ctx).allow_cpu_custom_calls(), - errors::InvalidArgument( - "ArgMax implementation requires a CustomCall on CPU")); - xla::ComputationBuilder& b = *ctx->builder(); - - // XLA passes to the function, so it is not included here. - std::vector args; - args.push_back(ctx->Input(0)); - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR1(input_shape.dim_sizes()))); - if (input_shape.dims() > 1) { - // Don't bother passing the output shape and dim for the 1d case, since - // the shape is always a scalar and the dim is always 0. - args.push_back(b.ConstantLiteral( - *xla::Literal::CreateR1(output_shape.dim_sizes()))); - args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0(dim))); - } - - xla::Shape xla_shape = - xla::ShapeUtil::MakeShape(xla::S64, output_shape.dim_sizes()); - - // Tell XLA to call the custom code, defined in - // index_ops_kernel_argmax_float_1d.cc. - xla::ComputationDataHandle output; - switch (input_shape.dims()) { - case 1: - output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape); - break; - case 2: - output = b.CustomCall("argmax_float_2d_xla_impl", args, xla_shape); - break; - default: - OP_REQUIRES(ctx, false, - errors::Unimplemented( - "Argmax is only implemented for 1d and 2d tensors" - ", but got shape: ", - input_shape.DebugString())); - } - ctx->SetOutput(0, output); - } - - private: - TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxOp); + explicit XlaArgMinOp(OpKernelConstruction* ctx); }; - -REGISTER_XLA_OP( - Name("ArgMax").TypeConstraint("T", DT_FLOAT).Device(DEVICE_CPU_XLA_JIT), - ArgMaxOp); +XlaArgMinOp::XlaArgMinOp(OpKernelConstruction* ctx) + : XlaArgMinMaxOp(ctx, /*is_min=*/true) {} +REGISTER_XLA_OP(Name("ArgMin"), XlaArgMinOp); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops.h b/tensorflow/compiler/tf2xla/kernels/index_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..ef2b9e6b6ebda921764de768fda0d20c20a765e2 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/index_ops.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Declarations of the ArgMax/ArgMin ops using a pure XLA implementation. + +#ifndef TENSORFLOW_COMPILER_TF2XLA_KERNELS_INDEX_OPS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_KERNELS_INDEX_OPS_H_ + +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { + +class XlaArgMinMaxOp : public XlaOpKernel { + public: + explicit XlaArgMinMaxOp(OpKernelConstruction* ctx, bool is_min); + void Compile(XlaOpKernelContext* ctx) override; + + private: + const bool is_min_; // Are we computing ArgMin (true) or ArgMax (false)? +}; + +class XlaArgMaxOp : public XlaArgMinMaxOp { + public: + explicit XlaArgMaxOp(OpKernelConstruction* ctx); +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_KERNELS_INDEX_OPS_H_ diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc new file mode 100644 index 0000000000000000000000000000000000000000..20946e247a9459d7c8a0d8a666fef24bd32838f2 --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc @@ -0,0 +1,121 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Native XLA implementations of indexing ops. + +#include "tensorflow/compiler/tf2xla/type_util.h" +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" +#include "tensorflow/core/framework/kernel_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/bounds_check.h" + +namespace tensorflow { +namespace { + +// The logic below uses a custom-call to implement argmax. +// +// Also see b/29507024 for first-class XLA support for indexing ops. +class ArgMaxCustomCallOp : public XlaOpKernel { + public: + explicit ArgMaxCustomCallOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape dimension_shape = ctx->InputShape(1); + OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(dimension_shape), + errors::InvalidArgument( + "dim must be a scalar, but received tensor of shape: ", + dimension_shape.DebugString())); + + // We require that the dimension argument is a constant, since it lets us + // dispatch to a specialized custom-call function without any run-time + // overhead, when compiling ahead-of-time. + xla::Literal literal; + OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &literal)); + const int32 dim = literal.Get({}); + OP_REQUIRES(ctx, dim >= 0, errors::InvalidArgument("dim must be >= 0")); + OP_REQUIRES( + ctx, dim < input_shape.dims(), + errors::InvalidArgument("dim must be < input rank (", + input_shape.dims(), "), but got: ", dim)); + const int64 dim_size = input_shape.dim_size(dim); + OP_REQUIRES( + ctx, dim_size > 0, + errors::InvalidArgument("Reduction axis ", dim, " is empty in shape: ", + input_shape.DebugString())); + + // The output shape is the input shape contracted along dim. + TensorShape output_shape; + for (int d = 0; d < input_shape.dims() - 1; ++d) { + output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1)); + } + + // For now we use a custom-call, only for the 1d and 2d cases. + OP_REQUIRES(ctx, XlaContext::Get(ctx).allow_cpu_custom_calls(), + errors::InvalidArgument( + "ArgMax implementation requires a CustomCall on CPU")); + xla::ComputationBuilder& b = *ctx->builder(); + + // XLA passes to the function, so it is not included here. + std::vector args; + args.push_back(ctx->Input(0)); + args.push_back(b.ConstantLiteral( + *xla::Literal::CreateR1(input_shape.dim_sizes()))); + if (input_shape.dims() > 1) { + // Don't bother passing the output shape and dim for the 1d case, since + // the shape is always a scalar and the dim is always 0. + args.push_back(b.ConstantLiteral( + *xla::Literal::CreateR1(output_shape.dim_sizes()))); + args.push_back(b.ConstantLiteral(*xla::Literal::CreateR0(dim))); + } + + xla::Shape xla_shape = + xla::ShapeUtil::MakeShape(xla::S64, output_shape.dim_sizes()); + + // Tell XLA to call the custom code, defined in + // index_ops_kernel_argmax_float_1d.cc. + xla::ComputationDataHandle output; + switch (input_shape.dims()) { + case 1: + output = b.CustomCall("argmax_float_1d_xla_impl", args, xla_shape); + break; + case 2: + output = b.CustomCall("argmax_float_2d_xla_impl", args, xla_shape); + break; + default: + OP_REQUIRES(ctx, false, + errors::Unimplemented( + "Argmax is only implemented for 1d and 2d tensors" + ", but got shape: ", + input_shape.DebugString())); + } + ctx->SetOutput(0, output); + } + + private: + TF_DISALLOW_COPY_AND_ASSIGN(ArgMaxCustomCallOp); +}; + +REGISTER_XLA_OP( + Name("ArgMax").TypeConstraint("T", DT_FLOAT).Device(DEVICE_CPU_XLA_JIT), + ArgMaxCustomCallOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc index afbd64ca5038378d48744d6d773e0dfb1376e1f9..47cf8c6675bc120653c2a5ab6d4b07376dc382ee 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_1d.cc @@ -16,6 +16,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/macros.h" @@ -47,3 +48,5 @@ EIGEN_STRONG_INLINE void argmax_float_1d_xla_impl(void* out, void** data) { extern "C" void TF_EXPORT argmax_float_1d_xla_impl(void* out, void** data) { tensorflow::argmax_float_1d_xla_impl(out, data); } + +REGISTER_CUSTOM_CALL_TARGET(argmax_float_1d_xla_impl); diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc index 841ff2f4df79fdd790ee3aace9e38aaeb01a3080..9b83392d8fbe461970603fbadee76e8d71b1ebd0 100644 --- a/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc +++ b/tensorflow/compiler/tf2xla/kernels/index_ops_kernel_argmax_float_2d.cc @@ -16,6 +16,7 @@ limitations under the License. #define EIGEN_USE_THREADS #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/platform/dynamic_annotations.h" #include "tensorflow/core/platform/macros.h" @@ -49,3 +50,5 @@ EIGEN_STRONG_INLINE void argmax_float_2d_xla_impl(void* out, void** data) { extern "C" void TF_EXPORT argmax_float_2d_xla_impl(void* out, void** data) { tensorflow::argmax_float_2d_xla_impl(out, data); } + +REGISTER_CUSTOM_CALL_TARGET(argmax_float_2d_xla_impl); diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc index 5c799a0e4f86db04dc966411e0c917387186ce59..fcef497e5845d9080bc83b54e92dcf2fdecf5f12 100644 --- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc @@ -23,6 +23,9 @@ limitations under the License. namespace tensorflow { namespace { +constexpr std::array kMatmulTypes = { + {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}}; + class MatMulOp : public XlaOpKernel { public: explicit MatMulOp(OpKernelConstruction* ctx, bool is_sparse = false) @@ -73,7 +76,7 @@ class MatMulOp : public XlaOpKernel { bool transpose_b_; }; -REGISTER_XLA_OP(Name("MatMul").TypeConstraint("T", kFloatTypes), MatMulOp); +REGISTER_XLA_OP(Name("MatMul").TypeConstraint("T", kMatmulTypes), MatMulOp); class SparseMatMulOp : public MatMulOp { public: diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 66b99665cbefd9ffd2acabe6eb296f485ca6a59d..2421825ead17a3acee9f145f00904d382fb656f4 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -140,7 +140,7 @@ class TruncatedNormalOp : public XlaOpKernel { xla::ComputationBuilder* b) { xla::ComputationDataHandle too_large = b->Gt(candidate, two_sd(false, b)); xla::ComputationDataHandle too_small = b->Lt(candidate, two_sd(true, b)); - return b->LogicalOr(too_large, too_small); + return b->Or(too_large, too_small); }; // The algorithm we're using is roughly: diff --git a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc index dae2eb9d2a92ef8d4eabb8d6f9a79758c42d446d..647b6274083cf8886af6c451b746416445a4a2b2 100644 --- a/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/reduction_ops.cc @@ -129,7 +129,7 @@ class AllOp : public XlaReductionOp { void BuildReducer(xla::ComputationBuilder* builder, const xla::ComputationDataHandle& scalar_lhs, const xla::ComputationDataHandle& scalar_rhs) override { - builder->LogicalAnd(scalar_lhs, scalar_rhs); + builder->And(scalar_lhs, scalar_rhs); } }; @@ -147,7 +147,7 @@ class AnyOp : public XlaReductionOp { void BuildReducer(xla::ComputationBuilder* builder, const xla::ComputationDataHandle& scalar_lhs, const xla::ComputationDataHandle& scalar_rhs) override { - builder->LogicalOr(scalar_lhs, scalar_rhs); + builder->Or(scalar_lhs, scalar_rhs); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/relu_op.cc b/tensorflow/compiler/tf2xla/kernels/relu_op.cc index a137d28118e6b4c66c70253817be9b3f0b75088a..12a35529992e6160566046dd28f9321c88afec91 100644 --- a/tensorflow/compiler/tf2xla/kernels/relu_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/relu_op.cc @@ -77,9 +77,9 @@ class Relu6GradOp : public XlaOpKernel { b->Broadcast(XlaHelpers::Zero(b, input_type(0)), shape.dim_sizes()); const auto six = b->Broadcast( XlaHelpers::IntegerLiteral(b, input_type(0), 6), shape.dim_sizes()); - auto out = b->Select( - b->LogicalAnd(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)), - ctx->Input(0), zero); + auto out = + b->Select(b->And(b->Lt(ctx->Input(1), six), b->Gt(ctx->Input(1), zero)), + ctx->Input(0), zero); ctx->SetOutput(0, out); } }; diff --git a/tensorflow/compiler/tf2xla/kernels/retval_op.cc b/tensorflow/compiler/tf2xla/kernels/retval_op.cc index 462267d1504f16a5fc1f34f5804649416699005a..c283e3b02c2676785952e3e17bffa671b0dabc1e 100644 --- a/tensorflow/compiler/tf2xla/kernels/retval_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/retval_op.cc @@ -60,7 +60,13 @@ class RetvalOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal)); OP_REQUIRES_OK(ctx, tc.AddConstRetval(index_, dtype_, literal)); } else { - tc.AddRetval(index_, dtype_, input); + // The core from which a return value is returned depends on the core + // assignment of the input to the retval .Since we can't change the core + // assignment of as this point, create a tuple/get-tuple-element + // combination so that the core will be set on them. + auto tuple_elem = + ctx->builder()->GetTupleElement(ctx->builder()->Tuple({input}), 0); + tc.AddRetval(index_, dtype_, tuple_elem); } } } diff --git a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc index ed818c56ed0e6fa41374234d6f6712a2bbda94e2..5172781c0d05b6682fe92086654e3b86961949ee 100644 --- a/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/sendrecv_ops.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" diff --git a/tensorflow/compiler/tf2xla/kernels/slice_op.cc b/tensorflow/compiler/tf2xla/kernels/slice_op.cc index 482c54a40cfe2f600b36344dff091481a93417a0..fbe8c78d8fb5f800967942555531a50937cad0ca 100644 --- a/tensorflow/compiler/tf2xla/kernels/slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/slice_op.cc @@ -35,88 +35,82 @@ class SliceOp : public XlaOpKernel { explicit SliceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { - bool is_identity = true; + const TensorShape input_shape = ctx->InputShape(0); + const TensorShape begin_tensor_shape = ctx->InputShape(1); + const TensorShape size_tensor_shape = ctx->InputShape(2); + + OP_REQUIRES( + ctx, + IsLegacyVector(begin_tensor_shape) && + IsLegacyVector(size_tensor_shape) && + begin_tensor_shape.num_elements() == input_shape.dims() && + size_tensor_shape.num_elements() == input_shape.dims(), + errors::InvalidArgument( + "Expected begin and size arguments to be 1-D tensors of size ", + input_shape.dims(), ", but got shapes ", + begin_tensor_shape.DebugString(), " and ", + size_tensor_shape.DebugString(), " instead.")); + + const int input_dims = input_shape.dims(); + std::vector begin; std::vector size; - SharedValidation(ctx, &is_identity, &begin, &size); - if (!ctx->status().ok()) return; - - if (is_identity) { - VLOG(1) << "Slice identity"; - ctx->SetOutput(0, ctx->Input(0)); - return; - } - - // slice will be an empty handle if the output has no elements. - CHECK_EQ(begin.size(), size.size()); - std::vector limits; - limits.reserve(begin.size()); - for (int i = 0; i < begin.size(); ++i) { - limits.push_back(begin[i] + size[i]); - } - std::vector strides(begin.size(), 1); - ctx->SetOutput(0, ctx->builder()->Slice(ctx->Input(0), begin, limits, - strides)); - } - - private: - void SharedValidation(XlaOpKernelContext* ctx, bool* is_identity, - std::vector* begin, std::vector* size); -}; - -void SliceOp::SharedValidation(XlaOpKernelContext* ctx, bool* is_identity, - std::vector* begin, - std::vector* size) { - const TensorShape input_shape = ctx->InputShape(0); - const TensorShape begin_tensor_shape = ctx->InputShape(1); - const TensorShape size_tensor_shape = ctx->InputShape(2); - - OP_REQUIRES( - ctx, - IsLegacyVector(begin_tensor_shape) && IsLegacyVector(size_tensor_shape) && - begin_tensor_shape.num_elements() == input_shape.dims() && - size_tensor_shape.num_elements() == input_shape.dims(), - errors::InvalidArgument( - "Expected begin and size arguments to be 1-D tensors of size ", - input_shape.dims(), ", but got shapes ", - begin_tensor_shape.DebugString(), " and ", - size_tensor_shape.DebugString(), " instead.")); - - const int input_dims = input_shape.dims(); - - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, begin)); - OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, size)); - for (int i = 0; i < input_dims; ++i) { - if ((*size)[i] == -1) { - // A size[i] of -1 means "all elements from begin[i] to dim_size(i)". - (*size)[i] = input_shape.dim_size(i) - (*begin)[i]; - } - } - - *is_identity = true; - for (int i = 0; i < input_dims; ++i) { - int64 b = (*begin)[i]; - int64 s = (*size)[i]; - if (input_shape.dim_size(i) == 0) { - OP_REQUIRES(ctx, b == 0 && s == 0, - errors::InvalidArgument( - "Expected begin[", i, "] == 0 (got ", b, ") and size[", i, - "] == 0 ", "(got ", s, ") when ", "input_shape.dim_size(", - i, ") == 0")); + OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(2, &size)); + if (ctx->ConstantInputAsIntVector(1, &begin).ok()) { + // `begin` is a compile-time constant. + for (int i = 0; i < input_dims; ++i) { + if (size[i] == -1) { + // A size[i] of -1 means "all elements from begin[i] to dim_size(i)". + size[i] = input_shape.dim_size(i) - begin[i]; + } + } + + for (int i = 0; i < input_dims; ++i) { + int64 b = begin[i]; + int64 s = size[i]; + if (input_shape.dim_size(i) == 0) { + OP_REQUIRES(ctx, b == 0 && s == 0, + errors::InvalidArgument( + "Expected begin[", i, "] == 0 (got ", b, + ") and size[", i, "] == 0 ", "(got ", s, ") when ", + "input_shape.dim_size(", i, ") == 0")); + } else { + OP_REQUIRES(ctx, 0 <= b && b <= input_shape.dim_size(i), + errors::InvalidArgument("Expected begin[", i, "] in [0, ", + input_shape.dim_size(i), + "], but got ", b)); + OP_REQUIRES(ctx, 0 <= s && b + s <= input_shape.dim_size(i), + errors::InvalidArgument("Expected size[", i, "] in [0, ", + input_shape.dim_size(i) - b, + "], but ", "got ", s)); + } + } + + std::vector limits; + limits.reserve(begin.size()); + for (int i = 0; i < begin.size(); ++i) { + limits.push_back(begin[i] + size[i]); + } + std::vector strides(begin.size(), 1); + ctx->SetOutput( + 0, ctx->builder()->Slice(ctx->Input(0), begin, limits, strides)); } else { - OP_REQUIRES( - ctx, 0 <= b && b <= input_shape.dim_size(i), - errors::InvalidArgument("Expected begin[", i, "] in [0, ", - input_shape.dim_size(i), "], but got ", b)); - OP_REQUIRES(ctx, 0 <= s && b + s <= input_shape.dim_size(i), - errors::InvalidArgument("Expected size[", i, "] in [0, ", - input_shape.dim_size(i) - b, - "], but ", "got ", s)); + // `begin` is not a compile-time constant. + for (int i = 0; i < input_dims; ++i) { + OP_REQUIRES(ctx, 0 <= size[i], + errors::InvalidArgument( + "XLA compilation of Slice operator with negative sizes " + "requires that 'begin' is a compile-time constant.")); + OP_REQUIRES(ctx, size[i] <= input_shape.dim_size(i), + errors::InvalidArgument("Expected size[", i, "] in [0, ", + input_shape.dim_size(i), "], but ", + "got ", size[i])); + } + ctx->SetOutput( + 0, ctx->builder()->DynamicSlice(ctx->Input(0), ctx->Input(1), size)); } - const bool take_all = (b == 0) && (s == input_shape.dim_size(i)); - (*is_identity) &= take_all; } -} +}; REGISTER_XLA_OP(Name("Slice"), SliceOp); diff --git a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc index a0d8ab4d73f7491fe96299c6cdc918f00a3d7a97..750a4c2dec8154f97f307978b3d8884271292279 100644 --- a/tensorflow/compiler/tf2xla/kernels/softmax_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/softmax_op.cc @@ -202,7 +202,7 @@ class SparseSoftmaxXentWithLogitsOp : public XlaOpKernel { // NaN otherwise; then add that vector to the labels to force out-of-range // values to NaNs. xla::ComputationDataHandle nan_or_zero = builder->Select( - builder->LogicalAnd( + builder->And( builder->Le(XlaHelpers::Zero(builder, indices_type), indices), builder->Lt(indices, XlaHelpers::IntegerLiteral( builder, indices_type, depth))), diff --git a/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..89befda346ec06fec23ab1d1c9d910ded8cd806d --- /dev/null +++ b/tensorflow/compiler/tf2xla/kernels/spacetodepth_op.cc @@ -0,0 +1,96 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_helpers.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" +#include "tensorflow/compiler/tf2xla/xla_op_registry.h" + +namespace tensorflow { +namespace { + +class SpaceToDepthOp : public XlaOpKernel { + public: + explicit SpaceToDepthOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("block_size", &block_size_)); + OP_REQUIRES( + ctx, block_size_ > 1, + errors::InvalidArgument("Block size should be > 1: ", block_size_)); + } + + void Compile(XlaOpKernelContext* ctx) override { + const TensorShape input_tensor_shape = ctx->InputShape(0); + // The input is presumed to be [batch, height, width, depth] + int input_rank = input_tensor_shape.dims(); + static const int kRequiredDims = 4; + OP_REQUIRES(ctx, kRequiredDims == input_rank, + errors::InvalidArgument("Input rank should be: ", kRequiredDims, + " instead of: ", input_rank)); + const gtl::InlinedVector input_shape = + input_tensor_shape.dim_sizes(); + + xla::ComputationBuilder* b = ctx->builder(); + xla::ComputationDataHandle input = ctx->Input(0); + + // 1. Reshape `input` to `reshaped` of shape: + // + // [batch, + // input_shape[1] / block_size_, block_size_, + // input_shape[2] / block_size_, block_size_, + // depth] + const int block_rank = 2; + for (int i = 0; i < block_rank; ++i) { + OP_REQUIRES(ctx, input_shape[1 + i] % block_size_ == 0, + errors::InvalidArgument( + "input shape[", 1 + i, "]=", input_shape[1 + i], + " is not divisible by block_size=", block_size_)); + } + xla::ComputationDataHandle reshaped = b->Reshape( + input, {input_shape[0], input_shape[1] / block_size_, block_size_, + input_shape[2] / block_size_, block_size_, input_shape[3]}); + + // 2. Permute dimensions of `reshaped` to produce + // `permuted_reshaped` of shape: + // + // [batch, + // input_shape[1] / block_size_, + // input_shape[2] / block_size_, + // block_size_, block_size_, + // depth] + xla::ComputationDataHandle permuted_reshaped = + b->Transpose(reshaped, {0, 1, 3, 2, 4, 5}); + + // 3. Reshape `permuted_reshaped` to flatten `block_shape` into the + // batch dimension, producing an output tensor of shape: + // + // [batch, + // input_shape[1] / block_size_, + // input_shape[2] / block_size_, + // block_size_ * block_size_ * depth] + // + xla::ComputationDataHandle output = b->Reshape( + permuted_reshaped, {input_shape[0], input_shape[1] / block_size_, + input_shape[2] / block_size_, + block_size_ * block_size_ * input_shape[3]}); + + ctx->SetOutput(0, output); + } + + private: + int block_size_; +}; +REGISTER_XLA_OP(Name("SpaceToDepth"), SpaceToDepthOp); + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/split_op.cc b/tensorflow/compiler/tf2xla/kernels/split_op.cc index 44ee81461e5b31f15594c0dfb86f7219f9875768..795eb1794f577e0f7fd2a2068878e540ff0c1a1d 100644 --- a/tensorflow/compiler/tf2xla/kernels/split_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/split_op.cc @@ -33,13 +33,16 @@ class SplitOp : public XlaOpKernel { explicit SplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} void Compile(XlaOpKernelContext* ctx) override { + const int32 num_split = num_outputs(); const TensorShape index_shape = ctx->InputShape(0); + const TensorShape input_shape = ctx->InputShape(1); + xla::Literal literal_index; OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &literal_index)); - int32 split_dim; + int32 split_dim_orig; if (index_shape.dims() == 0) { - split_dim = literal_index.Get({}); + split_dim_orig = literal_index.Get({}); } else { OP_REQUIRES( ctx, index_shape.dims() == 1, @@ -49,27 +52,28 @@ class SplitOp : public XlaOpKernel { ctx, index_shape.dim_size(0) == 1, errors::InvalidArgument("split_index input to Split Op must be a " "scalar or a vector with 1 element")); - split_dim = literal_index.Get({0}); + split_dim_orig = literal_index.Get({0}); } - const int32 num_split = num_outputs(); - const TensorShape input_shape = ctx->InputShape(1); - - OP_REQUIRES( - ctx, 0 <= split_dim && split_dim < input_shape.dims(), - errors::InvalidArgument("0 <= split_dim < number of input dimensions (", - input_shape.dims(), "), but got ", split_dim)); + int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims() + : split_dim_orig; + OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(), + errors::InvalidArgument("-input rank(-", input_shape.dims(), + ") <= split_dim < input rank (", + input_shape.dims(), "), but got ", + split_dim_orig)); OP_REQUIRES( ctx, num_split > 0, errors::InvalidArgument( "Number of ways to split should be > 0, but got ", num_split)); - OP_REQUIRES(ctx, input_shape.dim_size(split_dim) % num_split == 0, - errors::InvalidArgument( - "Number of ways to split should evenly divide the split " - "dimension, but got split_dim ", - split_dim, " (size = ", input_shape.dim_size(split_dim), - ") ", "and num_split ", num_split)); + OP_REQUIRES( + ctx, input_shape.dim_size(split_dim) % num_split == 0, + errors::InvalidArgument( + "Number of ways to split should evenly divide the split " + "dimension, but got split_dim ", + split_dim_orig, " (size = ", input_shape.dim_size(split_dim), ") ", + "and num_split ", num_split)); // All the slices are the same size: this is the size along the // split dimension. diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc index c42d8b97eafff6270ff3531c342439a173a1d501..351fda251798e43b607fb445f2c98abd57b3d86b 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_array_ops.cc @@ -307,11 +307,12 @@ class TensorArrayGatherOp : public XlaOpKernel { OP_REQUIRES(ctx, indices_shape.dims() == 1, errors::InvalidArgument("indices must be rank 1")); auto indices = ctx->Input(1); + DataType index_type = ctx->input_type(1); xla::ComputationDataHandle ta = resource->value; xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( - ctx, ta, ta_shape, indices, indices_shape, dtype_, b); + ctx, ta, ta_shape, indices, indices_shape, 0, dtype_, index_type, b); ctx->SetOutput(0, gather); } diff --git a/tensorflow/compiler/tf2xla/kernels/training_ops.cc b/tensorflow/compiler/tf2xla/kernels/training_ops.cc index 82ae0df5cc501cf1b51c2b25b9330d582fbdc44c..5534d1bfa1338c7fe3647cd6aa281c4907dfdf8c 100644 --- a/tensorflow/compiler/tf2xla/kernels/training_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/training_ops.cc @@ -37,8 +37,9 @@ class ResourceApplyGradientDescent : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle)); } }; -REGISTER_XLA_OP(Name("ResourceApplyGradientDescent"), - ResourceApplyGradientDescent); +REGISTER_XLA_OP( + Name("ResourceApplyGradientDescent").TypeConstraint("T", kFloatTypes), + ResourceApplyGradientDescent); class ResourceApplyMomentum : public XlaOpKernel { public: @@ -109,7 +110,8 @@ class ResourceApplyMomentum : public XlaOpKernel { private: bool use_nesterov_; }; -REGISTER_XLA_OP(Name("ResourceApplyMomentum"), ResourceApplyMomentum); +REGISTER_XLA_OP(Name("ResourceApplyMomentum").TypeConstraint("T", kFloatTypes), + ResourceApplyMomentum); class ResourceApplyAdagrad : public XlaOpKernel { public: @@ -163,7 +165,8 @@ class ResourceApplyAdagrad : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum)); } }; -REGISTER_XLA_OP(Name("ResourceApplyAdagrad"), ResourceApplyAdagrad); +REGISTER_XLA_OP(Name("ResourceApplyAdagrad").TypeConstraint("T", kFloatTypes), + ResourceApplyAdagrad); class ResourceApplyAdam : public XlaOpKernel { public: @@ -263,7 +266,8 @@ class ResourceApplyAdam : public XlaOpKernel { private: DataType dtype_; }; -REGISTER_XLA_OP(Name("ResourceApplyAdam"), ResourceApplyAdam); +REGISTER_XLA_OP(Name("ResourceApplyAdam").TypeConstraint("T", kFloatTypes), + ResourceApplyAdam); class ResourceApplyRMSProp : public XlaOpKernel { public: @@ -362,7 +366,8 @@ class ResourceApplyRMSProp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->AssignVariable(2, type, new_mom)); } }; -REGISTER_XLA_OP(Name("ResourceApplyRMSProp"), ResourceApplyRMSProp); +REGISTER_XLA_OP(Name("ResourceApplyRMSProp").TypeConstraint("T", kFloatTypes), + ResourceApplyRMSProp); void CompileFtrl(XlaOpKernelContext* ctx, DataType dtype, bool has_l2_shrinkage) { @@ -500,7 +505,8 @@ class ResourceApplyFtrl : public XlaOpKernel { private: DataType dtype_; }; -REGISTER_XLA_OP(Name("ResourceApplyFtrl"), ResourceApplyFtrl); +REGISTER_XLA_OP(Name("ResourceApplyFtrl").TypeConstraint("T", kFloatTypes), + ResourceApplyFtrl); class ResourceApplyFtrlV2 : public XlaOpKernel { public: @@ -515,7 +521,8 @@ class ResourceApplyFtrlV2 : public XlaOpKernel { private: DataType dtype_; }; -REGISTER_XLA_OP(Name("ResourceApplyFtrlV2"), ResourceApplyFtrlV2); +REGISTER_XLA_OP(Name("ResourceApplyFtrlV2").TypeConstraint("T", kFloatTypes), + ResourceApplyFtrlV2); } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc index 6b8f5ec7b33cd448a7b06c5dfe4aac288e53e9c9..a266e9013c41b88788dbc99849f01c09f3d61348 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops.cc @@ -41,6 +41,12 @@ namespace { }; \ REGISTER_XLA_OP(Name(#NAME), NAME##Op); +XLAJIT_MAKE_UNARY(ComplexAbs, b->Abs(x)); + +XLAJIT_MAKE_UNARY(Angle, b->Atan2(b->Imag(x), b->Real(x))); + +XLAJIT_MAKE_UNARY(Conj, b->Conj(x)); + // Return x if x>0, otherwise -x. XLAJIT_MAKE_UNARY(Abs, b->Abs(x)); @@ -87,7 +93,8 @@ XLAJIT_MAKE_UNARY(Log, b->Log(x)); // TODO(b/34703906): use a more accurate implementation of log1p. XLAJIT_MAKE_UNARY(Log1p, b->Log(b->Add(XlaHelpers::One(b, input_type(0)), x))); -XLAJIT_MAKE_UNARY(LogicalNot, b->LogicalNot(x)); +XLAJIT_MAKE_UNARY(Invert, b->Not(x)); +XLAJIT_MAKE_UNARY(LogicalNot, b->Not(x)); XLAJIT_MAKE_UNARY(Neg, b->Neg(x)); // Implements Banker's rounding: numbers that are equidistant between two @@ -104,9 +111,9 @@ static xla::ComputationDataHandle Round(xla::ComputationBuilder* b, auto nearest_even_int = b->Sub(round_val, b->Mul(two, b->Floor(b->Mul(half, x)))); auto is_odd = b->Eq(nearest_even_int, one); - return b->Select(b->LogicalOr(b->Gt(fraction, half), - b->LogicalAnd(b->Eq(fraction, half), is_odd)), - b->Add(round_val, one), round_val); + return b->Select( + b->Or(b->Gt(fraction, half), b->And(b->Eq(fraction, half), is_odd)), + b->Add(round_val, one), round_val); } XLAJIT_MAKE_UNARY(Rint, Round(b, input_type(0), x)); @@ -129,8 +136,28 @@ XLAJIT_MAKE_UNARY(Sign, b->Sign(x)); XLAJIT_MAKE_UNARY(Sinh, b->Mul(b->Sub(b->Exp(x), b->Exp(b->Neg(x))), XlaHelpers::FloatLiteral(b, input_type(0), 0.5))); -XLAJIT_MAKE_UNARY(Softplus, - b->Log(b->Add(b->Exp(x), XlaHelpers::One(b, input_type(0))))); + +static xla::ComputationDataHandle Softplus( + xla::ComputationBuilder* b, DataType dtype, + const xla::ComputationDataHandle& features) { + xla::ComputationDataHandle threshold = + b->Add(b->Log(XlaHelpers::Epsilon(b, dtype)), + XlaHelpers::FloatLiteral(b, dtype, 2.0)); + // Value above which exp(x) may overflow, but softplus(x) == x + // is within machine epsilon. + xla::ComputationDataHandle too_large = b->Gt(features, b->Neg(threshold)); + // Value below which exp(x) may underflow, but softplus(x) == exp(x) + // is within machine epsilon. + xla::ComputationDataHandle too_small = b->Lt(features, threshold); + xla::ComputationDataHandle features_exp = b->Exp(features); + xla::ComputationDataHandle output = b->Select( + too_large, features, + b->Select(too_small, features_exp, + b->Log(b->Add(features_exp, XlaHelpers::One(b, dtype))))); + return output; +} +XLAJIT_MAKE_UNARY(Softplus, Softplus(b, input_type(0), x)); + // softsign(x) = x / (abs(x) + 1) XLAJIT_MAKE_UNARY(Softsign, b->Div(x, @@ -141,6 +168,9 @@ XLAJIT_MAKE_UNARY(Square, b->Mul(x, x)); XLAJIT_MAKE_UNARY(Tan, b->Div(b->Sin(x), b->Cos(x))); XLAJIT_MAKE_UNARY(Tanh, b->Tanh(x)); +XLAJIT_MAKE_UNARY(Real, b->Real(x)); +XLAJIT_MAKE_UNARY(Imag, b->Imag(x)); + #undef XLAJIT_MAKE_UNARY } // namespace diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc index ecf8e6009dff9c28721c45fd3e95033a46cf37e5..b19ea22f50d2dd44e8d1d81f5930263f364030e1 100644 --- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc @@ -111,9 +111,10 @@ class ResourceGatherOp : public XlaOpKernel { auto indices = ctx->Input(1); auto indices_shape = ctx->InputShape(1); + DataType index_type = ctx->input_type(1); xla::ComputationDataHandle gather = XlaComputeGatherDynamicSlice( - ctx, resource_handle, resource_shape, indices, indices_shape, - resource_dtype, builder); + ctx, resource_handle, resource_shape, indices, indices_shape, 0, + resource_dtype, index_type, builder); ctx->SetOutput(0, gather); } }; diff --git a/tensorflow/compiler/tf2xla/ops/functional_ops.cc b/tensorflow/compiler/tf2xla/ops/functional_ops.cc index c1005405f9a9b09e4a6480332861d0cce2c52291..4a669f8e6eaf644f119f3c0a66f29d9f2c9a9d16 100644 --- a/tensorflow/compiler/tf2xla/ops/functional_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/functional_ops.cc @@ -34,14 +34,41 @@ output = input; While (Cond(output)) { output = Body(output) } input: A list of input tensors whose types are T. output: A list of output tensors whose types are T. cond: A function takes 'input' and returns a tensor. If the tensor is - a scalar of non-boolean, the scalar is converted to a boolean - according to the following rule: if the scalar is a numerical - value, non-zero means True and zero means False; if the scalar is - a string, non-empty means True and empty means False. If the - tensor is not a scalar, non-emptiness means True and False - otherwise. + a scalar of non-boolean, the scalar is converted to a boolean + according to the following rule: if the scalar is a numerical + value, non-zero means True and zero means False; if the scalar is + a string, non-empty means True and empty means False. If the + tensor is not a scalar, non-emptiness means True and False + otherwise. body: A function that takes a list of tensors and returns another list of tensors. Both lists have the same types as specified by T. )doc"); +// TODO(b/37549631) setting the If Op to always be stateful is too +// conservative. +REGISTER_OP("XlaIf") + .Input("cond: Tcond") + .Input("inputs: Tin") + .Output("output: Tout") + .Attr("Tcond: type") + .Attr("then_branch: func") + .Attr("else_branch: func") + .Attr("Tin: list(type) >= 0") + .Attr("Tout: list(type) >= 0") + .SetIsStateful() + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +output = cond ? then_branch(inputs) : else_branch(inputs). + +cond: A boolean scalar. +inputs: A list of input tensors. +output: A list of tensors returned by either then_branch(inputs) or + else_branch(inputs). The input shapes of the then_branch and + else_branch must match. +then_branch: A function takes 'inputs' and returns a list of tensors, + whose types are the same as what else_branch returns. +else_branch: A function takes 'inputs' and returns a list of tensors. + whose types are the same as what then_branch returns. +)doc"); + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc index b6947bfe570c75dd0c7c6301b972e2012bae26bd..4b41c16a8b3fdc0c3412c76d29d3ec2b7bdfd0aa 100644 --- a/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/sendrecv_ops.cc @@ -37,7 +37,14 @@ REGISTER_OP("_XLARecv") .Attr("tensor_name: string") .Attr("shape: shape") .SetIsStateful() - .SetShapeFn(shape_inference::UnknownShape) + .SetShapeFn([](shape_inference::InferenceContext* c) { + TensorShape shape_attr; + TF_RETURN_IF_ERROR(c->GetAttr("shape", &shape_attr)); + shape_inference::ShapeHandle s; + TF_RETURN_IF_ERROR(c->MakeShapeFromTensorShape(shape_attr, &s)); + c->set_output(0, s); + return Status::OK(); + }) .Doc(R"doc( Receives the named tensor from another XLA computation. diff --git a/tensorflow/compiler/tf2xla/sharding_util.cc b/tensorflow/compiler/tf2xla/sharding_util.cc new file mode 100644 index 0000000000000000000000000000000000000000..d9c839b61019b92b6de3a77a7bec610ae848a9a4 --- /dev/null +++ b/tensorflow/compiler/tf2xla/sharding_util.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/tf2xla/sharding_util.h" + +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +static const char DEVICE_SUFFIX_REPLICATED_CORE[] = "REPLICATED_CORE"; + +static Status CoreOutOfRangeError(int core, int num_cores_per_replica) { + return errors::InvalidArgument( + "Invalid replicated core id: ", core, + "; num_cores_per_replica=", num_cores_per_replica); +} + +xla::StatusOr> +ParseShardingFromDevice(const string& device_name, int num_cores_per_replica) { + if (device_name.empty()) { + return tensorflow::gtl::optional(); + } + + DeviceNameUtils::ParsedName parsed_device; + if (!DeviceNameUtils::ParseFullName(device_name, &parsed_device)) { + return errors::InvalidArgument("Malformed assigned device '", device_name, + "'"); + } + if (!parsed_device.has_type || + !StringPiece(parsed_device.type) + .ends_with(DEVICE_SUFFIX_REPLICATED_CORE)) { + return tensorflow::gtl::optional(); + } else { + const int core = parsed_device.id; + if (core < 0 || core >= num_cores_per_replica) { + return CoreOutOfRangeError(core, num_cores_per_replica); + } + return tensorflow::gtl::optional( + xla::ShardingBuilder::AssignDevice(core)); + } +} + +xla::StatusOr> +ParseShardingFromDevice(const Node& node, int num_cores_per_replica) { + string device_name = node.assigned_device_name(); + if (device_name.empty()) { + device_name = node.requested_device(); + } + return ParseShardingFromDevice(device_name, num_cores_per_replica); +} +void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst) { + string device_name = src.assigned_device_name(); + if (device_name.empty()) { + device_name = src.requested_device(); + } + dst->set_assigned_device_name(device_name); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/sharding_util.h b/tensorflow/compiler/tf2xla/sharding_util.h new file mode 100644 index 0000000000000000000000000000000000000000..f6468bba9f950fec88dcc6b3ec760f014d3a0ef3 --- /dev/null +++ b/tensorflow/compiler/tf2xla/sharding_util.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_ +#define TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_ + +#include + +#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/lib/core/status.h" + +namespace tensorflow { + +// Parses the op sharding from the 'replicated core' device_name . +// Returns an error: +// - if the device name is invalid. +// - the core is parsed and is out of the range [0, num_cores_per_replica). +// +// Otherwise, returns either a non-value or a sharding set as per +// xla:ShardingBuilder::AssignDevice. +xla::StatusOr> +ParseShardingFromDevice(const string& device_name, int num_cores_per_replica); + +xla::StatusOr> +ParseShardingFromDevice(const Node& node, int num_cores_per_replica); + +void SetShardingDeviceAssignmentFromNode(const Node& src, Node* dst); + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_TPU_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/sharding_util_test.cc b/tensorflow/compiler/tf2xla/sharding_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..bff5978237a827cb9650541f2cf6984d9e846796 --- /dev/null +++ b/tensorflow/compiler/tf2xla/sharding_util_test.cc @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/tf2xla/sharding_util.h" + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { + +TEST(CoreUtilTest, ParseShardingFromDevice) { + Graph graph(OpRegistry::Global()); + + auto core_from_sharding = + [](tensorflow::gtl::optional sharding) -> int64 { + if (sharding.has_value() && + sharding.value().type() == + xla::OpSharding::Type::OpSharding_Type_MAXIMAL) { + return sharding.value().tile_assignment_devices(0); + } else { + return -1; + } + }; + + auto parse_status = ParseShardingFromDevice("", 1); + TF_EXPECT_OK(parse_status.status()); + EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie())); + parse_status = ParseShardingFromDevice("", 100); + TF_EXPECT_OK(parse_status.status()); + EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie())); + + parse_status = ParseShardingFromDevice("/device:A_REPLICATED_CORE:-1", 100); + EXPECT_FALSE(parse_status.ok()); + + parse_status = ParseShardingFromDevice("/device:A_REPLICATED_CORE:55", 100); + TF_EXPECT_OK(parse_status.status()); + EXPECT_EQ(55, core_from_sharding(parse_status.ValueOrDie())); + + parse_status = ParseShardingFromDevice("/device:A_REPLICATED_CORE:100", 100); + EXPECT_FALSE(parse_status.ok()); + + parse_status = ParseShardingFromDevice("/cpu:0", 100); + TF_EXPECT_OK(parse_status.status()); + EXPECT_EQ(-1, core_from_sharding(parse_status.ValueOrDie())); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc index b7213a6cc1e4066f98523ec57681f4c0651f71b5..a14c93a2b9494b89f579bc20ee0510c136f8f01b 100644 --- a/tensorflow/compiler/tf2xla/tf2xla.cc +++ b/tensorflow/compiler/tf2xla/tf2xla.cc @@ -255,11 +255,10 @@ Status CreateXlaArgs(const Graph& graph, Status ConvertGraphToXla(std::unique_ptr graph, xla::Client* client, xla::Computation* computation, bool* requires_runtime_context) { - // Create a device and context to convert the graph into an XLA computation. XlaOpRegistry::RegisterCompilationKernels(); - // Populate the context with args from the graph. for (Node* node : graph->nodes()) { - node->set_assigned_device_name(DEVICE_CPU_XLA_JIT); + node->set_assigned_device_name( + strings::StrCat("/device:", DEVICE_CPU_XLA_JIT)); } std::vector xla_args; TF_RETURN_IF_ERROR(CreateXlaArgs(*graph, &xla_args)); diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 51ce17deb62117ff8c1075160d0bebe6cf1438f1..ecd15652fe84b0c19d2f7fc18f877236547f9be9 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -92,7 +92,7 @@ TEST(ConvertGraphDefToXla, Sum) { client->ExecuteAndTransfer(computation, {x_global.get(), y_global.get()}); TF_EXPECT_OK(result_or.status()); std::unique_ptr result = std::move(result_or.ValueOrDie()); - EXPECT_EQ("(s32[]) (\n42,\n)", result->ToString()); + EXPECT_EQ("(s32[]) (\n42\n)", result->ToString()); } } // namespace diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc index 14e0910cab2c3aa329fe798d199454fd6c5ee6a5..55f2f3149c6ba7bfa18608f961c8a76103a50756 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc @@ -19,7 +19,9 @@ limitations under the License. #include #include +#include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/graph_def_util.h" #include "tensorflow/core/framework/node_def.pb.h" @@ -29,6 +31,7 @@ limitations under the License. #include "tensorflow/core/graph/tensor_id.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/strcat.h" namespace tensorflow { @@ -250,4 +253,32 @@ string TensorIdToString(const tf2xla::TensorId& id) { return strings::StrCat(id.node_name(), ":", id.output_index()); } +Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) { + int core = -1; + const Node* matching_node = nullptr; + for (const Edge* edge : (out_edges ? n->out_edges() : n->in_edges())) { + if (edge->IsControlEdge()) continue; + const Node* possible_match = out_edges ? edge->dst() : edge->src(); + TF_ASSIGN_OR_RETURN( + tensorflow::gtl::optional sharding, + ParseShardingFromDevice( + *possible_match, + /*num_cores_per_replica=*/std::numeric_limits::max())); + if (sharding.has_value()) { + TF_RET_CHECK(sharding.value().type() == + xla::OpSharding::Type::OpSharding_Type_MAXIMAL); + const int core_annotation = sharding.value().tile_assignment_devices(0); + if (core == -1 || core > core_annotation) { + core = core_annotation; + matching_node = possible_match; + } + } + } + if (matching_node != nullptr) { + n->set_assigned_device_name(matching_node->assigned_device_name()); + n->set_requested_device(matching_node->requested_device()); + } + return Status::OK(); +} + } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h index a29d0c16f9cfde3c97bfa9cf3165890f83939a43..e5fba8ede7745febbb42c572a7b52247213afc95 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util.h +++ b/tensorflow/compiler/tf2xla/tf2xla_util.h @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" namespace tensorflow { @@ -45,6 +46,11 @@ Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in, // Returns node:port for the given . string TensorIdToString(const tf2xla::TensorId& id); +// Updates the sharding of based on the sharding of its neighbors. +// If is true, outgoing edges from are considered; else incoming +// edges are considered. +Status SetNodeShardingFromNeighbors(Node* n, bool out_edges); + } // namespace tensorflow #endif // TENSORFLOW_COMPILER_TF2XLA_TF2XLA_UTIL_H_ diff --git a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc index b98c89f284d6a2bfc6d043794a580e60da93617f..436039e154842443f779aba276bc571fc2ab7537 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_util_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_util_test.cc @@ -15,7 +15,13 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/tf2xla_util.h" +#include "tensorflow/cc/framework/ops.h" +#include "tensorflow/cc/ops/data_flow_ops.h" +#include "tensorflow/cc/ops/function_ops.h" +#include "tensorflow/cc/ops/standard_ops.h" +#include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/core/stringpiece.h" @@ -211,5 +217,52 @@ TEST(PruneGraphDefInto, Basic) { EXPECT_EQ(def.DebugString(), copy.DebugString()); } +TEST(SetNodeShardingFromNeighbors, Basic) { + // Builds a graph that adds two Tensors. + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); + auto c = ops::Add(scope.WithOpName("C"), a, b); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + Node* a_node = nullptr; + Node* b_node = nullptr; + Node* c_node = nullptr; + for (Node* n : graph->nodes()) { + if (n->name() == "A") a_node = n; + if (n->name() == "B") b_node = n; + if (n->name() == "C") c_node = n; + } + + const int num_cores_per_replica = 4; + + a_node->set_assigned_device_name("foo"); + EXPECT_FALSE(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false).ok()); + + // Test where one input to c_node has a device. + a_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:2"); + TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false)); + auto parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica); + TF_ASSERT_OK(parse_status.status()); + ASSERT_TRUE(parse_status.ValueOrDie().has_value()); + EXPECT_EQ(2, parse_status.ValueOrDie().value().tile_assignment_devices(0)); + + // Test where two inputs to c_node have a device. + b_node->set_assigned_device_name("/device:TPU_REPLICATED_CORE:1"); + TF_ASSERT_OK(SetNodeShardingFromNeighbors(c_node, /*out_edges=*/false)); + parse_status = ParseShardingFromDevice(*c_node, num_cores_per_replica); + TF_ASSERT_OK(parse_status.status()); + ASSERT_TRUE(parse_status.ValueOrDie().has_value()); + EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0)); + + // Test setting based on out edges. + TF_ASSERT_OK(SetNodeShardingFromNeighbors(a_node, /*out_edges=*/true)); + parse_status = ParseShardingFromDevice(*a_node, num_cores_per_replica); + TF_ASSERT_OK(parse_status.status()); + ASSERT_TRUE(parse_status.ValueOrDie().has_value()); + EXPECT_EQ(1, parse_status.ValueOrDie().value().tile_assignment_devices(0)); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index b54848f342406c9211c06664fd1f6c0783e0891f..1efbe0ffb17dad5332aa700b2e255d4a99fbef72 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -43,6 +43,12 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_UINT16: *type = xla::U16; return Status::OK(); + case tensorflow::DT_UINT32: + *type = xla::U32; + return Status::OK(); + case tensorflow::DT_UINT64: + *type = xla::U64; + return Status::OK(); case tensorflow::DT_HALF: *type = xla::F16; return Status::OK(); @@ -52,6 +58,9 @@ Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type) { case tensorflow::DT_DOUBLE: *type = xla::F64; return Status::OK(); + case tensorflow::DT_COMPLEX64: + *type = xla::C64; + return Status::OK(); case tensorflow::DT_QUINT8: *type = xla::U8; return Status::OK(); diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.cc b/tensorflow/compiler/tf2xla/xla_compilation_device.cc index 890a9ccb830c75afcb81d28685cc26e4a7ef35f9..4f32c29954b2d809d31ef8c584b6a6c3dcdf5cef 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.cc +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/sharding_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/core/common_runtime/local_device.h" @@ -77,7 +78,8 @@ XlaCompilationDevice::XlaCompilationDevice(const SessionOptions& options, : LocalDevice( options, Device::BuildDeviceAttributes( - "", type, Bytes(256 << 20), DeviceLocality(), + strings::StrCat("/device:", type.type(), ":0"), type, + Bytes(256 << 20), DeviceLocality(), strings::StrCat("device: XLA compilation device ", type.type()))), allocator_(new XlaCompilationAllocator()) {} @@ -97,26 +99,19 @@ void XlaCompilationDevice::Compute(OpKernel* op_kernel, metadata.set_op_name(op_kernel->name()); b->SetOpMetadata(metadata); - DeviceNameUtils::ParsedName parsed; - OP_REQUIRES( - context, - DeviceNameUtils::ParseFullName(op_kernel->requested_device(), &parsed), - errors::Internal("Unable to parse device name: ", - op_kernel->requested_device())); - xla::OpDeviceAssignment assignment; - // If no device ID assignment is found, XLA is free to use whatever device it - // wants. In practice this usually has the effect of placing things on - // device 0. - if (parsed.has_id) { - assignment.set_has_device(true); - assignment.set_device(parsed.id); - } - b->SetDeviceAssignment(assignment); + auto sharding_parse_result = ParseShardingFromDevice( + op_kernel->requested_device(), std::numeric_limits::max()); + OP_REQUIRES_OK(context, sharding_parse_result.status()); + tensorflow::gtl::optional op_sharding = + sharding_parse_result.ValueOrDie(); + // If no sharding metadata is found, XLA is free to use whatever device it + // wants. In practice this usually has the effect of placing things on device + // 0. + xla::ScopedShardingAssignment assign_sharding(b, op_sharding); op_kernel->Compute(context); b->ClearOpMetadata(); - b->ClearDeviceAssignment(); VLOG(4) << "Done"; } diff --git a/tensorflow/compiler/tf2xla/xla_compilation_device.h b/tensorflow/compiler/tf2xla/xla_compilation_device.h index 765683cf1dc64ace2289340846014582faa051aa..6230acd718bc330f178007b575b5119de5b3d4f4 100644 --- a/tensorflow/compiler/tf2xla/xla_compilation_device.h +++ b/tensorflow/compiler/tf2xla/xla_compilation_device.h @@ -34,17 +34,18 @@ namespace tensorflow { // declared. class XlaCompilationAllocator; -// Deliberately don't register the device factory because we *never* -// want soft placement to put Ops on an JIT device. Tests can include -// the tla_jit_test_deps target which registers the factory, and when -// using JIT in practice, the device is created manually not using a -// factory. - // This is a 'dummy' TensorFlow device that is only used to execute a // subgraph of XLA compilation Ops to construct a compiled version // of the subgraph's computation. It has a 'dummy' allocator that // backs each Tensor with metadata indicating the computation the // Tensor represents. +// +// We deliberately don't register a device factory because we *never* +// want placement to put Ops on a compilation device. The device is created +// manually, not using a factory. +// +// XLA compilation is not thread-safe. OpKernels registered on the +// XlaCompilationDevice must not use threads or concurrency. class XlaCompilationDevice : public LocalDevice { public: XlaCompilationDevice(const SessionOptions& options, DeviceType type); diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc new file mode 100644 index 0000000000000000000000000000000000000000..b5c17c5273bb15e20184b2fefd93880d4828105e --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -0,0 +1,88 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" + +#include +#include "tensorflow/compiler/aot/runtime.h" + +namespace tensorflow { + +XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, + AllocMode alloc_mode) + : raw_function_(static_data.raw_function), + result_index_(static_data.result_index), + args_(new void*[static_data.num_args]), + temps_(new void*[static_data.num_temps]), + arg_names_(static_data.arg_names), + result_names_(static_data.result_names), + program_shape_(static_data.program_shape) { + // Allocate arg and temp buffers. + if (alloc_mode == AllocMode::ARGS_RESULTS_AND_TEMPS) { + alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( + static_data.arg_sizes, static_data.num_args, args_, + /*annotate_initialized=*/false); + } + alloc_temps_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers( + static_data.temp_sizes, static_data.num_temps, temps_, + /*annotate_initialized=*/true); + + // The runtime context is always the last arg, if it is required. + if (static_data.requires_runtime_context) { + args_[static_data.num_args - 1] = &context_; + } +} + +XlaCompiledCpuFunction::~XlaCompiledCpuFunction() { + tensorflow::tfcompile::runtime::FreeContiguous(alloc_args_); + tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_); + delete[] args_; + delete[] temps_; +} + +namespace { + +// Linear search through `names` looking for a match with `name`. Returns -1 if +// the name isn't found, or is empty. +// +// REQUIRES: `names` is a nullptr-terminated array. +int LookupNameIndex(const string& name, const char** names) { + // Hitting this assert means that there is no name-to-index data available; + // for AOT try the setting the tfcompile --gen_name_to_index flag. + assert(names != nullptr); + + constexpr int kNotFound = -1; + if (name.empty()) { + return kNotFound; + } + for (int index = 0; names[index] != nullptr; ++index) { + if (name == names[index]) { + return index; + } + } + return kNotFound; +} + +} // namespace + +int XlaCompiledCpuFunction::LookupArgIndex(const string& name) const { + return LookupNameIndex(name, arg_names_); +} + +int XlaCompiledCpuFunction::LookupResultIndex(const string& name) const { + return LookupNameIndex(name, result_names_); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h new file mode 100644 index 0000000000000000000000000000000000000000..f49a7889222ff989144217ab10b27595f89e4311 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -0,0 +1,223 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ + +#include +#include + +#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/core/platform/types.h" + +// Forward-declare, rather than include, to reduce code size for users that +// never use this functionality. +namespace xla { +class ProgramShape; +} + +namespace tensorflow { + +// Represents a function compiled by XLA, produced via either JIT or AOT. +// +// The Run method invokes the actual computation, with inputs read from arg +// buffers, and outputs written to result buffers. Each Run call may also use a +// set of temporary buffers for the computation. +// +// By default each instance of this class manages its own arg, result and temp +// buffers. The AllocMode constructor parameter may be used to modify the buffer +// allocation strategy. +// +// Under the default allocation strategy, this class is thread-compatible: +// o Calls to non-const methods require exclusive access to the object. +// o Concurrent calls to const methods are OK, if those calls are made while it +// is guaranteed that no thread may call a non-const method. +class XlaCompiledCpuFunction { + public: + // Type of the raw function, produced by either JIT or AOT. + // + // TODO(toddw): Add support for hlo profiling, and replace std::function with + // a raw function pointer, for some codesize savings. + using RawFunction = std::function; + + // StaticData represents the state necessary to run an XLA-compiled + // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for + // AOT this is backed by data compiled into the object file. + struct StaticData { + // The raw function to call. + RawFunction raw_function; + + // Cardinality and sizes of arg and temp buffers. + const intptr_t* arg_sizes = nullptr; + size_t num_args = 0; + const intptr_t* temp_sizes = nullptr; + size_t num_temps = 0; + + // The 0-based index of the result tuple, in the temp buffers. + size_t result_index = 0; + + // Is the final arg XlaLocalRuntimeContext? + bool requires_runtime_context = false; + + // [Optional] Arrays of arg and result names. These are arrays of C-style + // strings, where the array is terminated by nullptr. + const char** arg_names = nullptr; + const char** result_names = nullptr; + + // [Optional] Arg and result shapes. + const xla::ProgramShape* program_shape = nullptr; + }; + + // AllocMode controls the buffer allocation mode. + enum class AllocMode { + // Allocate all buffers - args, results and temps. + ARGS_RESULTS_AND_TEMPS, + + // Only allocate result and temp buffers. + // Use set_arg_data to set argument buffers before Run is called. + RESULTS_AND_TEMPS_ONLY, + }; + + XlaCompiledCpuFunction( + const StaticData& static_data, + AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS); + virtual ~XlaCompiledCpuFunction(); + + XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete; + XlaCompiledCpuFunction& operator=(const XlaCompiledCpuFunction&) = delete; + + // Sets the intra-op thread pool used to run individual ops concurrently. + void set_thread_pool(const Eigen::ThreadPoolDevice* pool) { + run_options_.set_intra_op_thread_pool(pool); + context_.thread_pool = pool; + } + + // Runs the computation, with inputs read from arg buffers, and outputs + // written to result buffers. Returns true on success and false on failure. + bool Run() { + context_.error = false; + context_.error_msg.clear(); + raw_function_(temps_[result_index_], &run_options_, + const_cast(args_), temps_); + return !context_.error; + } + + // Returns the error message from the previous failed Run call. + const string& error_msg() const { return context_.error_msg; } + + // ------------------------------ + // Arg methods for managing input buffers. Buffers are in row-major order. + + // Returns the underlying array of argument buffers, where args()[I] is the + // buffer for the positional argument at index I. + void** args() { return args_; } + const void* const* args() const { return args_; } + + // Returns the buffer for the positional argument at the given `index`. + void* arg_data(size_t index) { return args_[index]; } + const void* arg_data(size_t index) const { return args_[index]; } + + // Sets the buffer for the positional argument at the given `index` to `data`. + // Must be called before Run to have an effect. May be called under any + // AllocMode; if the AllocMode is RESULTS_AND_TEMPS_ONLY, this method must be + // called for each positional argument, in order to set the argument buffers. + // + // Allocated memory must be aligned to the size specified by + // tensorflow::tfcompile::runtime::kAlign. If possible, use the functions in + // tensorflow/compiler/aot/runtime.h to ensure correct alignment. + // + // If StaticData.requires_runtime_context==true, the final argument is an + // XlaLocalRuntimeContext, which is managed internally by this class, and + // should not be changed. + // + // Aliasing of argument and result buffers is not allowed, and results in + // undefined behavior. + void set_arg_data(size_t index, void* data) { args_[index] = data; } + + // ------------------------------ + // Result methods for managing output buffers. Buffers are in row-major order. + // Must only be called after a successful Run call. Unlike the arg methods, + // there is no set_resultN_data method. The result buffers are managed + // internally, and may change after each call to Run. + + // Returns the underlying array of result buffers, where results()[I] is the + // buffer for the positional result at index I. + void** results() { return static_cast(temps_[result_index_]); } + const void* const* results() const { + return static_cast(temps_[result_index_]); + } + + // Returns the buffer for the positional result at the given `index`. + void* result_data(size_t index) { return results()[index]; } + const void* result_data(size_t index) const { return results()[index]; } + + // ------------------------------ + // Methods for extracting optional metadata. + + // Returns true iff data is available for the Lookup{Arg,Result}Index methods. + // E.g. the data might not be compiled into the binary for AOT. + bool HasNameIndices() const { + return arg_names_ != nullptr && result_names_ != nullptr; + } + + // Returns the 0-based index for the argument with the given `name`. + // Returns -1 if the name wasn't found, or data isn't available. + // + // The index remains constant for every instance of XlaCompiledCpuFunction + // generated from the same static data, and might not be cheap to determine. + // Recommended usage is to capture this in a variable for re-use. + int LookupArgIndex(const string& name) const; + + // Returns the 0-based index for the result with the given `name`. + // Returns -1 if the name wasn't found, or data isn't available. + // + // The index remains constant for every instance of XlaCompiledCpuFunction + // generated from the same static data, and might not be cheap to determine. + // Recommended usage is to capture this in a variable for re-use. + int LookupResultIndex(const string& name) const; + + // Returns the shape of the args and results. May return nullptr if the + // program shape isn't available. + const xla::ProgramShape* ProgramShape() const { return program_shape_; } + + private: + const RawFunction raw_function_; + const size_t result_index_; + + // Arrays of argument and temp buffers; entries in args_ may be overwritten by + // the user. + void** args_ = nullptr; + void** temps_ = nullptr; + + // Backing memory for individual arg and temp buffers. + void* alloc_args_ = nullptr; + void* alloc_temps_ = nullptr; + + // Options and context passed to the compiled function. + xla::ExecutableRunOptions run_options_; + tensorflow::XlaLocalRuntimeContext context_; + + // Optional metadata. + const char** arg_names_ = nullptr; + const char** result_names_ = nullptr; + const xla::ProgramShape* program_shape_ = nullptr; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index 0b583b54bf0f80763cee3e563215bf5679583709..48cebdf74c71f974bf075e0255626ec57eb9a149 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -15,14 +15,20 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiler.h" +#include #include +#include "tensorflow/compiler/tf2xla/const_analysis.h" #include "tensorflow/compiler/tf2xla/dump_graph.h" #include "tensorflow/compiler/tf2xla/functionalize_control_flow.h" +#include "tensorflow/compiler/tf2xla/graph_compiler.h" #include "tensorflow/compiler/tf2xla/shape_util.h" +#include "tensorflow/compiler/tf2xla/sharding_util.h" +#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_compilation_device.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/common_runtime/executor.h" @@ -91,7 +97,6 @@ XlaCompiler::XlaCompiler(XlaCompiler::Options options) } local_flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), - FunctionDefLibrary{})); local_pflr_.reset(new ProcessFunctionLibraryRuntime( &device_mgr_, Env::Default(), options.graph_def_version, @@ -126,10 +131,41 @@ static Status GetFunctionBody(const NameAttrList& function, return Status::OK(); } -Status XlaCompiler::CompileFunction( - const XlaCompiler::CompileOptions& options, const NameAttrList& function, - const std::vector& args, - XlaCompiler::CompilationResult* result) { +Status XlaCompiler::FindFunctionBody(const NameAttrList& function, + const FunctionBody** fbody) { + // The function may be in either the local_flib_runtime_ or flib_runtime_. + // Look up the function in local first and if it is not found then look up the + // function in flib_runtime_. + auto status = GetFunctionBody(function, local_flib_runtime_, fbody); + if (!status.ok()) { + if (!errors::IsNotFound(status)) { + return status; + } + TF_RETURN_WITH_CONTEXT_IF_ERROR( + GetFunctionBody(function, flib_runtime_, fbody), + "Local lookup failed with: ", status.error_message()); + } + return Status::OK(); +} + +std::unique_ptr XlaCompiler::GetGraph(const FunctionBody* fbody) { + std::unique_ptr graph(new Graph(options_.flib_def)); + CopyGraph(*fbody->graph, graph.get()); + OptimizerOptions opts; + opts.set_do_common_subexpression_elimination(true); + opts.set_do_function_inlining(true); + opts.set_do_constant_folding(true); + GraphOptimizer optimizer(opts); + optimizer.Optimize(flib_runtime_, flib_runtime_->env(), + /*device=*/nullptr, &graph, /*shape_map=*/nullptr); + + return graph; +} + +Status XlaCompiler::CompileFunction(const XlaCompiler::CompileOptions& options, + const NameAttrList& function, + std::vector args, + XlaCompiler::CompilationResult* result) { const string function_id = Canonicalize(function.name(), AttrSlice(&function.attr())); VLOG(1) << "XlaCompiler::CompileFunction " << function_id; @@ -141,15 +177,34 @@ Status XlaCompiler::CompileFunction( } const FunctionBody* fbody; - if (!GetFunctionBody(function, local_flib_runtime_, &fbody).ok()) { - TF_RETURN_IF_ERROR(GetFunctionBody(function, flib_runtime_, &fbody)); - } + TF_RETURN_IF_ERROR(FindFunctionBody(function, &fbody)); - TF_RETURN_IF_ERROR(CheckSignature(fbody->arg_types, args)); + TF_RETURN_WITH_CONTEXT_IF_ERROR( + CheckSignature(fbody->arg_types, args), + "Signature check failure while compiling: ", function.name()); std::unique_ptr graph(new Graph(options_.flib_def)); CopyGraph(*fbody->graph, graph.get()); + // _Arg and _Retval nodes don't exist in the stored subgraph for the function; + // they are added by the function body looked up. Therefore, they don't have + // core assignments here. + // Attempt to assign a core to each _Retval and _Arg. Chooses the + // lowest-numbered core that consumes the argument. We choose the + // lowest-numbered core so the assignment is deterministic. + for (Node* n : graph->nodes()) { + if (StringPiece(n->type_string()) == "_Arg") { + TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/true)); + } + } + // Do _Retval as a second loop, in case the retval's input is an _Arg (which + // may have gotten a device assignment from the first loop). + for (Node* n : graph->nodes()) { + if (StringPiece(n->type_string()) == "_Retval") { + TF_RETURN_IF_ERROR(SetNodeShardingFromNeighbors(n, /*out_edges=*/false)); + } + } + if (VLOG_IS_ON(2)) { VLOG(2) << "XlaCompiler::CompileFunction: " << dump_graph::DumpGraphToFile( @@ -180,7 +235,7 @@ namespace { Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, XlaCompilationDevice* device, FunctionLibraryRuntime* flib, int64 step_id) { - // Resource cleanup is a bit messy. XlaContext is a ref-counted resource; the + // Resource cleanup is a bit messy. XlaContext is a ref-countd resource; the // resource manager takes ownership via Create, and unrefs via Cleanup. We // explicitly add a reference to ensure the refcount at entry is maintained at // all exit points; Create and Cleanup are always called in this function. @@ -188,52 +243,34 @@ Status ExecuteGraph(XlaContext* xla_context, std::unique_ptr graph, // The Executor requires us to use ScopedStepContainer. We wrap it in a // unique_ptr so we can capture the cleanup status in the end. xla_context->Ref(); - Status cleanup_status; + Status status; auto step_container = xla::MakeUnique( - step_id, [&cleanup_status, device](const string& name) { - cleanup_status = device->resource_manager()->Cleanup(name); + step_id, [&status, device](const string& name) { + status = device->resource_manager()->Cleanup(name); }); TF_RETURN_IF_ERROR(device->resource_manager()->Create( step_container->name(), XlaContext::kXlaContextResourceName, xla_context)); - // Create a LocalExecutor that will own and run the graph. - LocalExecutorParams exec_params; - exec_params.device = device; - exec_params.function_library = flib; - exec_params.create_kernel = [flib](const NodeDef& ndef, OpKernel** kernel) { - return flib->CreateKernel(ndef, kernel); - }; - exec_params.delete_kernel = [](OpKernel* kernel) { delete kernel; }; - Executor* exec_ptr = nullptr; - TF_RETURN_IF_ERROR(NewLocalExecutor(exec_params, graph.release(), &exec_ptr)); - std::unique_ptr exec(exec_ptr); - // At this point ownership of the graph has been transferred to exec. - - // Run the graph symbolically, turning the graph into an XLA computation. - Executor::Args exec_args; - exec_args.step_id = step_id; - exec_args.step_container = step_container.get(); - // Run all compilation kernels on the main thread. - exec_args.runner = [](Executor::Args::Closure c) { c(); }; - TF_RETURN_WITH_CONTEXT_IF_ERROR( - exec->Run(exec_args), - "Conversion from TensorFlow graph to XLA computation failed."); - + GraphCompiler graph_compiler(xla_context, device, graph.get(), flib, + step_container.get()); + TF_RETURN_IF_ERROR(graph_compiler.Compile()); // Explicitly clean up the step container, to capture the cleanup status. step_container.reset(); - return cleanup_status; + return Status::OK(); } // Builds XLA computations for each of the arguments to the computation. // `args` are the arguments to the computation. -Status BuildArguments(const std::vector& args, +Status BuildArguments(const Graph& graph, + const std::vector& args, bool use_tuple_arg, xla::ComputationBuilder* builder, - XlaContext* context, + XlaContext* context, std::vector* arg_cores, std::vector* arg_expressions, std::vector* input_mapping, std::vector* input_shapes) { arg_expressions->resize(args.size()); + *arg_cores = std::vector(args.size(), -1); // Argument numbers of arguments and resources that are to be passed to the // XLA computation as runtime parameters. @@ -288,6 +325,26 @@ Status BuildArguments(const std::vector& args, (*input_mapping)[i] = parameters[i]; } + // Use the _Arg nodes in the graph to resolve core assignments. + for (const Node* n : graph.nodes()) { + if (StringPiece(n->type_string()) != "_Arg") continue; + int index; + TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index)); + TF_RET_CHECK(index >= 0 && index < args.size()) + << "_Arg out of bounds: " << index << " vs " << args.size(); + TF_ASSIGN_OR_RETURN( + auto sharding, + ParseShardingFromDevice(*n, std::numeric_limits::max())); + if (sharding.has_value()) { + TF_RET_CHECK(sharding.value().type() == + xla::OpSharding::Type::OpSharding_Type_MAXIMAL); + const int core = sharding.value().tile_assignment_devices(0); + if ((*arg_cores)[index] == -1 || core < (*arg_cores)[index]) { + (*arg_cores)[index] = core; + } + } + } + // Build parameter handles for non-constant arguments. std::vector arg_handles(parameters.size()); if (use_tuple_arg) { @@ -295,10 +352,18 @@ Status BuildArguments(const std::vector& args, xla::ComputationDataHandle tuple = builder->Parameter(0, tuple_shape, "arg_tuple"); for (std::vector::size_type i = 0; i < parameters.size(); ++i) { + const int core = (*arg_cores)[parameters[i]]; + xla::ScopedShardingAssignment assign_sharding( + builder, core == -1 ? tensorflow::gtl::optional() + : xla::ShardingBuilder::AssignDevice(core)); arg_handles[i] = builder->GetTupleElement(tuple, i); } } else { for (std::vector::size_type i = 0; i < parameters.size(); ++i) { + const int core = (*arg_cores)[parameters[i]]; + xla::ScopedShardingAssignment assign_sharding( + builder, core == -1 ? tensorflow::gtl::optional() + : xla::ShardingBuilder::AssignDevice(core)); arg_handles[i] = builder->Parameter(i, (*input_shapes)[i], strings::StrCat("arg", i)); } @@ -354,6 +419,7 @@ Status BuildArguments(const std::vector& args, // type of the final output. Status BuildComputation( const std::vector& args, + const std::vector& arg_cores, const std::vector& retvals, const std::vector>& resources, bool return_updated_values_for_all_resources, @@ -384,6 +450,8 @@ Status BuildComputation( for (const XlaResource* resource : arg_resources) { const XlaCompiler::Argument& arg = args[resource->arg_num]; + const int core = arg_cores[resource->arg_num]; + DCHECK_LT(resource->arg_num, arg_cores.size()); bool modified = resource->value.handle() != resource->initial_value.handle(); // TensorArray gradients were modified if their values changed or there are @@ -403,8 +471,21 @@ Status BuildComputation( for (const auto& grad : resource->tensor_array_gradients) { update.tensor_array_gradients_accessed.insert(grad.first); } + + // Request that the value be returned on a specific core. + xla::ScopedShardingAssignment assign_sharding( + builder, core == -1 ? tensorflow::gtl::optional() + : xla::ShardingBuilder::AssignDevice(core)); + xla::ComputationDataHandle handle; TF_RETURN_IF_ERROR(resource->Pack(&handle, builder)); + + // Since we can't change the sharding metadata of as this point, + // create a tuple/get-tuple-element combination so that sharding + // assignment will be placed on this value, which will cause the resource + // update to be returned from the same device that provided the resource. + handle = builder->GetTupleElement(builder->Tuple({handle}), 0); + elems.push_back(handle); } } @@ -465,9 +546,10 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->tuple_arg = options.use_tuple_arg; std::vector arg_expressions; + std::vector arg_cores; TF_RETURN_IF_ERROR(BuildArguments( - args, options.use_tuple_arg, &builder, context, &arg_expressions, - &result->input_mapping, &result->xla_input_shapes)); + *graph, args, options.use_tuple_arg, &builder, context, &arg_cores, + &arg_expressions, &result->input_mapping, &result->xla_input_shapes)); context->set_args(std::move(arg_expressions)); TF_RETURN_IF_ERROR(ExecuteGraph(context, std::move(graph), device_, @@ -477,7 +559,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, int num_computation_outputs; result->computation = std::make_shared(); TF_RETURN_IF_ERROR(BuildComputation( - args, context->retvals(), context->resources(), + args, arg_cores, context->retvals(), context->resources(), options.return_updated_values_for_all_resources, &builder, result->computation.get(), &num_computation_outputs, &num_nonconst_outputs, &result->resource_updates)); @@ -485,7 +567,7 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, result->requires_runtime_context = context->has_context_parameter(); // Tuple arguments and runtime context parameters are incompatible. - CHECK(!(options.use_tuple_arg && result->requires_runtime_context)); + TF_RET_CHECK(!(options.use_tuple_arg && result->requires_runtime_context)); VLOG(2) << "Outputs: total: " << context->retvals().size() << " nonconstant: " << num_nonconst_outputs; @@ -522,7 +604,8 @@ Status XlaCompiler::CompileGraph(const XlaCompiler::CompileOptions& options, i < context->retvals().size(); ++i) { const XlaExpression& retval = context->retvals()[i]; if (!retval.has_constant_value()) { - CHECK_LT(computation_output, num_computation_outputs); + TF_RET_CHECK(computation_output < num_computation_outputs) + << "Computation has more outputs than expected"; OutputDescription& output = result->outputs[i]; output.is_constant = false; TF_RETURN_IF_ERROR(XLAShapeToTensorShape( diff --git a/tensorflow/compiler/tf2xla/xla_compiler.h b/tensorflow/compiler/tf2xla/xla_compiler.h index 35159dbad4117895908584ad48878e2a989b9f40..4d40ca5825a0c864c63826c901169607d5080c09 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.h +++ b/tensorflow/compiler/tf2xla/xla_compiler.h @@ -29,7 +29,6 @@ limitations under the License. #include "tensorflow/core/public/version.h" namespace tensorflow { - // The XlaCompiler class is responsible for compilation of a self-contained // subgraph of a TensorFlow computation using the XLA linear algebra runtime. // It does a symbolic execution of the graph starting from specific input @@ -136,6 +135,27 @@ class XlaCompiler { bool operator==(const Argument& other) const; }; + // Options pertaining to an individual call to CompileGraph() or + // CompileFunction(). + struct CompileOptions { + // If `use_tuple_arg` is true, a single tuple parameter will be used for all + // arguments; if false, each argument gets its own parameter. + bool use_tuple_arg = false; + + // If 'return_updated_values_for_all_resources' is true, then updated + // values of all resource arguments will be included in the + // 'resource_updates' of the computation, even if the resource was not + // modified by the computation. Used when compiling loop bodies to ensure + // the input and output signatures match. + bool return_updated_values_for_all_resources = false; + + // If 'resolve_compile_time_constants' is true, then outputs of a + // computation that are known to be compile-time constants will be returned + // as Tensors at compile-time, rather than as run-time outputs of the + // computation. + bool resolve_compile_time_constants = true; + }; + struct OutputDescription { // Type and shape of the output. DataType type; @@ -230,43 +250,12 @@ class XlaCompiler { }; explicit XlaCompiler(Options options); - ~XlaCompiler(); - - // Options pertaining to an individual call to CompileGraph() or - // CompileFunction(). - struct CompileOptions { - // If `use_tuple_arg` is true, a single tuple parameter will be used for all - // arguments; if false, each argument gets its own parameter. - bool use_tuple_arg = false; - - // If 'return_updated_values_for_all_resources' is true, then updated - // values of all resource resources arguments will be included in the - // 'resource_updates' of the computation, even if the resource was not - // modified by the computation. Used when compiling loop bodies to ensure - // the input and output signatures match. - bool return_updated_values_for_all_resources = false; - // If 'resolve_compile_time_constants' is true, then outputs of a - // computation that are known to be compile-time constants will be returned - // as Tensors at compile-time, rather than as run-time outputs of the - // computation. - bool resolve_compile_time_constants = true; - }; + ~XlaCompiler(); - // Compiles a Tensorflow function `fn_name_attrs` into an XLA computation. - // `args` describes the arguments to the function, each of which must either - // be a runtime-parameter to the XLA computation, a compile-time constant, or - // a resource variable. Writes the compiled output to `result`. - // - // The generated XLA computation returns a tuple containing only the - // non-constant outputs as a function of the input arguments. Constant - // arguments are returned as host memory tensors in the output list and are - // not included in the XLA computation's outputs. The XLA computation is - // null if there are no data-dependent outputs and no side effects. Status CompileFunction(const CompileOptions& options, const NameAttrList& fn_name_attrs, - const std::vector& args, - CompilationResult* result); + std::vector args, CompilationResult* result); // Compiles a tensorflow::Graph into an xla::Computation. // Similar to CompileFunction, but takes a Graph as input rather than a @@ -276,10 +265,17 @@ class XlaCompiler { const std::vector& args, CompilationResult* result); + Status PrepareArguments(xla::ComputationBuilder* builder, NameAttrList func, + const std::vector& types, + const std::vector& shapes, + const std::vector& expressions, + std::vector* args); + // Retrieves the channel handle associated with `key`. Allocates // a new channel handle if none exists. - // Channel handles can be used to communicate between different computations. - // Computations that communicate should be compiled with the same XlaCompiler. + // Channel handles can be used to communicate between different + // computations. Computations that communicate should be compiled with the + // same XlaCompiler. Status GetChannelHandle(const string& key, xla::ChannelHandle* channel); const Options& options() const { return options_; } @@ -287,6 +283,18 @@ class XlaCompiler { FunctionLibraryRuntime* flib_runtime() const { return flib_runtime_; } private: + // Sets the function body `fbody` to the one registered as `function`. + Status FindFunctionBody(const NameAttrList& function, + const FunctionBody** fbody); + + // Returns the optimized graph object in this function body. + std::unique_ptr GetGraph(const FunctionBody* fbody); + + // Graph compiler needs to know how to get an optimized graph from a function + // body. + friend class GraphCompiler; + friend class XlaCompilerTest; + Options options_; // Status set to non-OK in the constructor if initialization fails. diff --git a/tensorflow/compiler/tf2xla/xla_compiler_test.cc b/tensorflow/compiler/tf2xla/xla_compiler_test.cc index 531725a62335fc30086de2fe381591eb7d0976d0..93aae8485d157cd4afbf804d695d5c0ab8d7946c 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler_test.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler_test.cc @@ -27,6 +27,8 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/core/common_runtime/function.h" #include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function_testlib.h" #include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/tensor_testutil.h" #include "tensorflow/core/graph/graph.h" @@ -36,6 +38,37 @@ limitations under the License. #include "tensorflow/core/public/version.h" namespace tensorflow { + +class XlaCompilerTest : public ::testing::Test { + protected: + XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} + + void SetUp() override { + client_ = xla::ClientLibrary::LocalClientOrDie(); + + XlaOpRegistry::RegisterCompilationKernels(); + + FunctionDefLibrary flib; + flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); + } + + XlaCompiler::Options DefaultOptions() { + XlaCompiler::Options options; + options.device_type = &cpu_device_type_; + options.client = client_; + options.flib_def = flib_def_.get(); + return options; + } + + FunctionLibraryDefinition* LocalFlibDef(XlaCompiler* compiler) { + return compiler->local_flib_def_.get(); + } + + DeviceType cpu_device_type_; + xla::Client* client_; + std::unique_ptr flib_def_; +}; + namespace { // Helper class to test the ability to pass resources through to XLA @@ -63,6 +96,7 @@ class DummyReadResourceOp : public XlaOpKernel { dummy->Unref(); ctx->SetOutput(0, ctx->Input(0)); + ctx->SetOutput(1, ctx->Input(0)); } }; @@ -80,22 +114,25 @@ class DummyReadResourceCC { if (!scope.ok()) return; scope.UpdateStatus(scope.DoShapeInference(ret)); if (!scope.ok()) return; - this->output_ = Output(ret, 0); + this->output1_ = Output(ret, 0); + this->output2_ = Output(ret, 1); } - Node* node() const { return output_.node(); } - Output output_; + Output output1_; + Output output2_; }; REGISTER_OP("DummyReadResource") .Input("input: int32") - .Output("output: int32") + .Output("output1: int32") + .Output("output2: int32") .SetShapeFn(shape_inference::UnknownShape) .Doc(R"doc( A dummy Op. input: dummy input. -output: dummy output. +output1: dummy output. +output2: dummy output. )doc"); REGISTER_XLA_OP(Name("DummyReadResource"), DummyReadResourceOp); @@ -125,31 +162,6 @@ REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_CPU_XLA_JIT), REGISTER_XLA_OP(Name("DummyDuplicateOp").Device(DEVICE_GPU_XLA_JIT), DummyDuplicateOp); -class XlaCompilerTest : public ::testing::Test { - protected: - XlaCompilerTest() : cpu_device_type_(DEVICE_CPU_XLA_JIT) {} - - void SetUp() override { - client_ = xla::ClientLibrary::LocalClientOrDie(); - - XlaOpRegistry::RegisterCompilationKernels(); - - FunctionDefLibrary flib; - flib_def_.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); - } - - XlaCompiler::Options DefaultOptions() { - XlaCompiler::Options options; - options.device_type = &cpu_device_type_; - options.client = client_; - options.flib_def = flib_def_.get(); - return options; - } - - DeviceType cpu_device_type_; - xla::Client* client_; - std::unique_ptr flib_def_; -}; // Tests compilation and execution of an empty graph. TEST_F(XlaCompilerTest, EmptyReturnValues) { @@ -316,7 +328,8 @@ TEST_F(XlaCompilerTest, ResourceManager) { Scope scope = Scope::NewRootScope().ExitOnError(); auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); auto b = DummyReadResourceCC(scope.WithOpName("B"), a); - auto c = ops::_Retval(scope.WithOpName("C"), b.output_, 0); + auto c = ops::Add(scope.WithOpName("C"), b.output2_, b.output1_); + auto d = ops::_Retval(scope.WithOpName("D"), c, 0); std::unique_ptr graph(new Graph(OpRegistry::Global())); TF_ASSERT_OK(scope.ToGraph(graph.get())); @@ -349,6 +362,58 @@ TEST_F(XlaCompilerTest, ResourceManager) { resource->Unref(); } +// Tests compilation and execution of a graph that adds two tensors. +TEST_F(XlaCompilerTest, DeterministicCompilation) { + // Builds a graph that contains a node with two output edges. The compiler + // should always traverse them in the same order. + const int64 test_count = 2; + + std::vector results(test_count); + + for (int64 i = 0; i < test_count; ++i) { + Scope scope = Scope::NewRootScope().ExitOnError(); + auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); + auto b = ops::Neg(scope.WithOpName("B"), a); + auto c = ops::Neg(scope.WithOpName("C"), a); + auto d = ops::Add(scope.WithOpName("D"), b, c); + auto e = ops::_Retval(scope.WithOpName("E"), d, 0); + std::unique_ptr graph(new Graph(OpRegistry::Global())); + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the argument. + std::vector args(1); + args[0].kind = XlaCompiler::Argument::kParameter; + args[0].type = DT_INT32; + args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); + + // Compiles the graph. + auto options = DefaultOptions(); + XlaCompiler compiler(options); + + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "dummy", + std::move(graph), args, &results[i])); + } + + for (int64 i = 1; i < test_count; ++i) { + auto m1 = + results[i - 1].computation->Snapshot().ValueOrDie()->entry().requests(); + auto m2 = + results[i].computation->Snapshot().ValueOrDie()->entry().requests(); + // Check if every entry is the same. + for (auto& entry1 : m1) { + int64 key = entry1.first; + auto value1 = entry1.second; + auto entry2 = m2.find(key); + auto value2 = entry2->second; + EXPECT_TRUE(entry2 != m2.end()); + string str1, str2; + value1.AppendToString(&str1); + value2.AppendToString(&str2); + EXPECT_EQ(str1, str2); + } + } +} + // Tests a computation that receives a TensorArray resource as input and // updates it. TEST_F(XlaCompilerTest, CanPassTensorArraysToAndFromComputation) { @@ -489,5 +554,104 @@ TEST_F(XlaCompilerTest, NewTensorArrayGradientsAreComputationOutputs) { EXPECT_EQ(1, result.resource_updates.size()); } +// Tests CompileFunction with undefined function fails. +TEST_F(XlaCompilerTest, UndefinedFunctionFails) { + XlaCompiler compiler(DefaultOptions()); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + XlaCompiler::CompilationResult result; + NameAttrList name_attr; + name_attr.set_name("Function_NotDefined_"); + Status status = + compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, + /*args=*/{}, &result); + EXPECT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined.")) + << status.error_message(); +} + +FunctionDef FillFn() { + return FunctionDefHelper::Define( + // Name + "FillFn", + // Args + {"x: T", "dims: int32"}, + // Return values + {"y: T"}, + // Attr def + {"T: {float, double, int32, int64}"}, + // Nodes + {{{"y"}, "Fill", {"dims", "x"}, {{"T", "$T"}}}}); +} + +TEST_F(XlaCompilerTest, FunctionCallWithConstants) { + // Certain operations in a function, "Fill" for example, requires the + // operator's argument to be a compile-time constant instead of a parameter. + // This testcase tests if XlaCompiler can handle such operators inside + // function calls. + XlaCompiler compiler(DefaultOptions()); + + FunctionDefLibrary flib; + *flib.add_function() = FillFn(); + + TF_ASSERT_OK(flib_def_->AddFunctionDef(FillFn())); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + + Scope scope = Scope::NewRootScope().ExitOnError(); + auto value = ops::Const(scope.WithOpName("value"), 1, {}); + auto shape = ops::Const(scope.WithOpName("shape"), {5}, {1}); + TF_EXPECT_OK(scope.graph()->AddFunctionLibrary(flib)); + + NodeDef def; + TF_ASSERT_OK(NodeDefBuilder("fill", "FillFn", flib_def_.get()) + .Input(value.name(), 0, DT_INT32) + .Input(shape.name(), 1, DT_INT32) + .Finalize(&def)); + Status status; + Node* fill = scope.graph()->AddNode(def, &status); + TF_ASSERT_OK(status); + TF_ASSERT_OK(scope.DoShapeInference(fill)); + scope.graph()->AddEdge(value.node(), 0, fill, 0); + scope.graph()->AddEdge(shape.node(), 0, fill, 1); + + auto retval = ops::_Retval(scope.WithOpName("retval"), Output(fill), 0); + + TF_ASSERT_OK(scope.ToGraph(graph.get())); + + // Builds a description of the argument. + std::vector args; + + XlaCompiler::CompilationResult result; + TF_ASSERT_OK(compiler.CompileGraph(XlaCompiler::CompileOptions(), "fill", + std::move(graph), args, &result)); +} + +// Tests CompileFunction with a local function lookup failing, fails with +// informative error about both lookups. +TEST_F(XlaCompilerTest, LocalFunctionWithWrongArgumentsFail) { + XlaCompiler compiler(DefaultOptions()); + + auto local_flib_def = LocalFlibDef(&compiler); + TF_ASSERT_OK(local_flib_def->AddFunctionDef(test::function::XTimesTwo())); + + std::unique_ptr graph(new Graph(OpRegistry::Global())); + XlaCompiler::CompilationResult result; + NameAttrList name_attr; + name_attr.set_name("XTimesTwo"); + Status status = + compiler.CompileFunction(XlaCompiler::CompileOptions(), name_attr, + /*args=*/{}, &result); + + ASSERT_FALSE(status.ok()); + // Flib lookup failure. + EXPECT_TRUE(StringPiece(status.error_message()).contains("is not defined.")) + << status.error_message(); + // Local flib lookup failure. + EXPECT_TRUE( + StringPiece(status.error_message()).contains("Attr T is not found")) + << status.error_message(); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc index 2366c02dd2b0f22d3cbee929f31bdb0185bfabbc..1df6173275a95bca66f64b3f6df2db9c7a03580b 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.cc +++ b/tensorflow/compiler/tf2xla/xla_helpers.cc @@ -16,9 +16,11 @@ limitations under the License. // This file defines helper routines for Tla JIT compilation. #include "tensorflow/compiler/tf2xla/xla_helpers.h" + #include "tensorflow/compiler/tf2xla/literal_util.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_context.h" +#include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/framework/tensor.h" @@ -26,6 +28,67 @@ limitations under the License. namespace tensorflow { +namespace { + +Status ArgMinMax(xla::ComputationBuilder* builder, XlaOpKernelContext* ctx, + const xla::ComputationDataHandle& input, + const TensorShape& input_shape, DataType input_type, + DataType output_type, int axis, bool is_min, + xla::ComputationDataHandle* argminmax) { + xla::ComputationDataHandle init_value; + const xla::Computation* reducer; + if (is_min) { + init_value = XlaHelpers::MaxValue(builder, input_type); + reducer = ctx->GetOrCreateMin(input_type); + } else { + init_value = XlaHelpers::MinValue(builder, input_type); + reducer = ctx->GetOrCreateMax(input_type); + } + + xla::PrimitiveType xla_output_type; + TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(output_type, &xla_output_type)); + + xla::ComputationDataHandle input_max = builder->Reduce( + input, init_value, *reducer, /*dimensions_to_reduce=*/{axis}); + std::vector broadcast_dims(input_shape.dims() - 1); + std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0); + std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1); + // Compute a mask that has 1s for elements equal to the maximum. + xla::ComputationDataHandle partial_mask = builder->ConvertElementType( + builder->Eq(input, input_max, broadcast_dims), xla_output_type); + + // In order to make identity elements for a bitwise And, we: + // Left shift the 1 to the leftmost bit, yielding 0x10...0 + // Arithmetic right shift the 1 back to the rightmost bit, yielding + // 0xFF...F + int32 bits_in_type = + xla::ShapeUtil::ByteSizeOfPrimitiveType(xla_output_type) * 8 - 1; + xla::ComputationDataHandle shift_amount = + XlaHelpers::IntegerLiteral(builder, output_type, bits_in_type); + xla::ComputationDataHandle full_mask = builder->ShiftRightArithmetic( + builder->ShiftLeft(partial_mask, shift_amount), shift_amount); + + // And with the vector [0, 1, 2, ...] to convert each 0xFF...F into its + // index. + xla::ComputationDataHandle iota; + + const int64 axis_size = input_shape.dim_size(axis); + TF_RETURN_IF_ERROR(XlaHelpers::Iota(builder, output_type, axis_size, &iota)); + xla::ComputationDataHandle product = + builder->And(full_mask, iota, /*broadcast_dimensions=*/{axis}); + + // If there are multiple maximum elements, choose the one with the highest + // index. + xla::ComputationDataHandle output = + builder->Reduce(product, XlaHelpers::MinValue(builder, output_type), + *ctx->GetOrCreateMax(output_type), + /*dimensions_to_reduce=*/{axis}); + *argminmax = output; + return Status::OK(); +} + +} // namespace + xla::ComputationDataHandle XlaHelpers::MinValue(xla::ComputationBuilder* b, DataType data_type) { xla::PrimitiveType type; @@ -54,6 +117,19 @@ xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b, return b->ConstantLiteral(xla::Literal::One(type)); } +xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b, + DataType data_type) { + switch (data_type) { + case DT_FLOAT: + return b->ConstantR0(std::numeric_limits::epsilon()); + case DT_DOUBLE: + return b->ConstantR0(std::numeric_limits::epsilon()); + default: + LOG(FATAL) << "Unsupported type in XlaHelpers::Epsilon: " + << DataTypeString(data_type); + } +} + xla::ComputationDataHandle XlaHelpers::IntegerLiteral( xla::ComputationBuilder* b, DataType data_type, int64 value) { xla::Literal literal; @@ -84,6 +160,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral( case xla::F64: literal = *xla::Literal::CreateR0(value); break; + case xla::C64: + literal = *xla::Literal::CreateR0(value); + break; case xla::PRED: LOG(FATAL) << "pred element type is not integral"; case xla::S16: @@ -119,6 +198,9 @@ xla::ComputationDataHandle XlaHelpers::FloatLiteral(xla::ComputationBuilder* b, case xla::F64: return b->ConstantR0(value); break; + case xla::C64: + return b->ConstantR0(value); + break; default: LOG(FATAL) << "unhandled element type " << type; } @@ -155,6 +237,50 @@ static Tensor MakeLinspaceTensor(const TensorShape& shape, int64 depth) { return linspace; } +Status XlaHelpers::ArgMax(xla::ComputationBuilder* builder, + XlaOpKernelContext* ctx, + const xla::ComputationDataHandle& input, + const TensorShape& input_shape, DataType input_type, + DataType output_type, int axis, + xla::ComputationDataHandle* argmax) { + return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, + axis, /*is_min=*/false, argmax); +} + +Status XlaHelpers::ArgMin(xla::ComputationBuilder* builder, + XlaOpKernelContext* ctx, + const xla::ComputationDataHandle& input, + const TensorShape& input_shape, DataType input_type, + DataType output_type, int axis, + xla::ComputationDataHandle* argmin) { + return ArgMinMax(builder, ctx, input, input_shape, input_type, output_type, + axis, /*is_min=*/true, argmin); +} + +Status XlaHelpers::Iota(xla::ComputationBuilder* builder, DataType dtype, + int64 size, xla::ComputationDataHandle* iota) { + TensorShape linspace_shape({size}); + Tensor linspace; + switch (dtype) { + case DT_UINT8: + linspace = MakeLinspaceTensor(linspace_shape, size); + break; + case DT_INT32: + linspace = MakeLinspaceTensor(linspace_shape, size); + break; + case DT_INT64: + linspace = MakeLinspaceTensor(linspace_shape, size); + break; + default: + return errors::InvalidArgument("Invalid argument type ", + DataTypeString(dtype)); + } + xla::Literal linspace_literal; + TF_RETURN_IF_ERROR(HostTensorToLiteral(linspace, &linspace_literal)); + *iota = builder->ConstantLiteral(linspace_literal); + return Status::OK(); +} + Status XlaHelpers::OneHot(xla::ComputationBuilder* builder, int64 depth, int axis, DataType index_type, const TensorShape& indices_shape, diff --git a/tensorflow/compiler/tf2xla/xla_helpers.h b/tensorflow/compiler/tf2xla/xla_helpers.h index f79a12cf28f7745e316a8b6c06eb72a0a10bef75..2a027db4c839c917f3a7acd27184792d157356bf 100644 --- a/tensorflow/compiler/tf2xla/xla_helpers.h +++ b/tensorflow/compiler/tf2xla/xla_helpers.h @@ -48,6 +48,11 @@ class XlaHelpers { static xla::ComputationDataHandle One(xla::ComputationBuilder* b, DataType data_type); + // Returns the machine epsilon for floating-point type `data_type`, i.e., + // the difference between 1.0 and the next representable value. + static xla::ComputationDataHandle Epsilon(xla::ComputationBuilder* b, + DataType data_type); + // Returns a handle representing the given value of an integer scalar // element of data_type. // Note that unlike One and Zero, does not work on boolean types. @@ -67,10 +72,35 @@ class XlaHelpers { gtl::ArraySlice shape, xla::Literal* output); + // Sets `argmax` to the argmax of `input` along `axis`. `input_shape` and + // `input_dtype` are the shape and dtype of `input` respectively, and + // `output_type` is the dtype to use for `argmax`. + static Status ArgMax(xla::ComputationBuilder* builder, + XlaOpKernelContext* ctx, + const xla::ComputationDataHandle& input, + const TensorShape& input_shape, DataType input_type, + DataType output_type, int axis, + xla::ComputationDataHandle* argmax); + + // Sets `argmin` to the argmin of `input` along `axis`. `input_shape` and + // `input_dtype` are the shape and dtype of `input` respectively, and + // `output_type` is the dtype to use for `argmin`. + static Status ArgMin(xla::ComputationBuilder* builder, + XlaOpKernelContext* ctx, + const xla::ComputationDataHandle& input, + const TensorShape& input_shape, DataType input_type, + DataType output_type, int axis, + xla::ComputationDataHandle* argmin); + + // Sets *iota to a rank 1 tensor with values [0, 1, 2, ...] of `dtype`. + static Status Iota(xla::ComputationBuilder* builder, DataType dtype, + int64 size, xla::ComputationDataHandle* iota); + // Converts `indices` into a one-hot representation. `depth` is the size // of the new axis to add. `axis` is the position at which to add the new - // axis. `indices_shape` is the shape of `indices`. `on_value` and `off_value` - // represent the values to use for the on and off positions, respectively. + // axis. `indices_shape` is the shape of `indices`. `on_value` and + // `off_value` represent the values to use for the on and off positions, + // respectively. static Status OneHot(xla::ComputationBuilder* builder, int64 depth, int axis, DataType index_type, const TensorShape& indices_shape, const xla::ComputationDataHandle& indices, diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc new file mode 100644 index 0000000000000000000000000000000000000000..1dd454ea8d57e21526e5bcde0c8efc5514983b93 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -0,0 +1,217 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h" + +#include +#include + +#include "tensorflow/compiler/tf2xla/tf2xla.h" +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "tensorflow/compiler/xla/client/client_library.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +namespace { + +// Returns a vector of positional argument buffer sizes. +xla::StatusOr> ComputeArgSizes( + const xla::ProgramShape& program_shape, bool requires_runtime_context) { + std::vector arg_sizes; + const size_t num_args = program_shape.parameters_size(); + arg_sizes.reserve(num_args); + for (int i = 0; i < num_args; ++i) { + const xla::Shape& arg_shape = program_shape.parameters(i); + if (i == num_args - 1 && requires_runtime_context) { + // If the compiled function needs an XlaLocalRuntimeContext* arg, it's + // always last, and must be represented as an opaque type. + const xla::PrimitiveType type = arg_shape.element_type(); + if (type != xla::OPAQUE) { + return errors::InvalidArgument( + "expected final context arg to be opaque, but got type: ", + xla::PrimitiveType_Name(type), ", from program shape: ", + xla::ShapeUtil::HumanString(program_shape)); + } + arg_sizes.push_back(-1); + } else { + constexpr size_t kPointerSize = sizeof(void*); + arg_sizes.push_back(xla::ShapeUtil::ByteSizeOf(arg_shape, kPointerSize)); + } + } + return std::move(arg_sizes); +} + +// Returns a vector of positional temporary buffer sizes. +xla::StatusOr> ComputeTempSizes( + const xla::BufferAssignment& buffer_assignment) { + const std::vector& allocations = + buffer_assignment.Allocations(); + std::vector temp_sizes; + temp_sizes.reserve(allocations.size()); + for (const xla::BufferAllocation& allocation : allocations) { + // Callers don't allocate temporary buffers for parameters. Nor for + // thread-local buffers, which are lowered to alloca. + if (allocation.is_entry_computation_parameter() || + allocation.is_thread_local()) { + temp_sizes.push_back(-1); + } else { + temp_sizes.push_back(allocation.size()); + } + } + return std::move(temp_sizes); +} + +// Returns the index of the result in the temp buffers. +xla::StatusOr ComputeResultIndex( + const xla::BufferAssignment& buffer_assignment) { + TF_ASSIGN_OR_RETURN(const xla::BufferAllocation::Slice result_slice, + buffer_assignment.GetUniqueTopLevelOutputSlice()); + return result_slice.index(); +} + +// Adapt ComputeFunctionType, which includes a final profile_counters arg, to +// RawFunction, which doesn't include that final arg. +// +// TODO(toddw): Change RawFunction and AOT to also pass the final +// profile_counters arg, and remove this adapter. +XlaCompiledCpuFunction::RawFunction RawFunctionAdapter( + xla::cpu::CpuExecutable::ComputeFunctionType compute_function) { + return [compute_function](void* result, + const xla::ExecutableRunOptions* run_options, + const void** args, void** temps) { + return compute_function(result, run_options, args, temps, + /*profile_counters=*/nullptr); + }; +} + +// Collect names from `entries`, where T is one of tf2xla::{Feed,Fetch}. We hold +// the actual strings in nonempty_names, and hold arrays of pointers in +// name_ptrs, terminated by a nullptr entry. +template +void CollectNames(const T& entries, std::vector* nonempty_names, + std::vector* name_ptrs) { + // First collect `nonempty_names`, to ensure the underlying strings won't + // change out from under us. + for (const auto& entry : entries) { + const string& name = entry.name(); + if (!name.empty()) { + nonempty_names->push_back(name); + } + } + // Now set `name_ptrs` pointing to the strings in `nonempty_names`. + name_ptrs->reserve(entries.size() + 1); // +1 for nullptr array terminator + size_t nonempty_index = 0; + for (const auto& entry : entries) { + const string& name = entry.name(); + if (!name.empty()) { + name_ptrs->push_back(nonempty_names->at(nonempty_index).c_str()); + ++nonempty_index; + } else { + name_ptrs->push_back(""); + } + } + name_ptrs->push_back(nullptr); // array terminator +} + +} // namespace + +/*static*/ xla::StatusOr> +XlaJitCompiledCpuFunction::Compile( + const GraphDef& graph_def, const tf2xla::Config& config, + const xla::ExecutableBuildOptions& build_options) { + // Convert the graph_def into an xla::Computation. + TF_ASSIGN_OR_RETURN(xla::LocalClient * client, + xla::ClientLibrary::GetOrCreateLocalClient()); + xla::Computation computation; + bool requires_runtime_context; + TF_RETURN_IF_ERROR(tensorflow::ConvertGraphDefToXla( + graph_def, config, client, &computation, &requires_runtime_context)); + + // Get and verify the program shape. + TF_ASSIGN_OR_RETURN(std::unique_ptr program_shape, + client->GetComputationShape(computation)); + if (program_shape->result().element_type() != xla::TUPLE) { + // The XlaCompiler we use to build the xla computation always generates a + // tuple result, and XlaCompiledCpuFunction relies on this for simpler + // calling semantics. + return errors::Internal( + "XlaJitCompiledCpuFunction requires the XLA result to be a tuple"); + } + // The parameter names are currently meaningless, and redundant with the rest + // of our metadata, so clear them out to avoid confusion and save space. + program_shape->clear_parameter_names(); + + // Compute arg shapes, needed to compile the executable. + std::vector arg_shapes; + arg_shapes.reserve(program_shape->parameters_size()); + for (int i = 0; i < program_shape->parameters_size(); ++i) { + arg_shapes.push_back(&program_shape->parameters(i)); + } + + // Compile the executable. The static_cast to the CpuExecutable subclass is + // necessary since the raw function and buffer assignments are only available + // there. + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + client->Compile(computation, arg_shapes, build_options)); + const xla::cpu::CpuExecutable* cpu_executable = + static_cast(executable->executable()); + XlaCompiledCpuFunction::RawFunction raw_function = + RawFunctionAdapter(cpu_executable->compute_function()); + const xla::BufferAssignment& buffer_assignment = + cpu_executable->buffer_assignment(); + + // Compute buffer sizes and the result index, needed to run the raw function. + TF_ASSIGN_OR_RETURN( + std::vector arg_sizes, + ComputeArgSizes(*program_shape, requires_runtime_context)); + TF_ASSIGN_OR_RETURN(std::vector temp_sizes, + ComputeTempSizes(buffer_assignment)); + TF_ASSIGN_OR_RETURN(size_t result_index, + ComputeResultIndex(buffer_assignment)); + + std::unique_ptr jit_unique_ptr( + new XlaJitCompiledCpuFunction); + XlaJitCompiledCpuFunction* jit = jit_unique_ptr.get(); + jit->executable_ = std::move(executable); + jit->arg_sizes_ = std::move(arg_sizes); + jit->temp_sizes_ = std::move(temp_sizes); + jit->program_shape_ = std::move(program_shape); + jit->static_data_.raw_function = std::move(raw_function); + jit->static_data_.arg_sizes = jit->arg_sizes_.data(); + jit->static_data_.num_args = jit->arg_sizes_.size(); + jit->static_data_.temp_sizes = jit->temp_sizes_.data(); + jit->static_data_.num_temps = jit->temp_sizes_.size(); + jit->static_data_.result_index = result_index; + jit->static_data_.requires_runtime_context = requires_runtime_context; + // Optional metadata is collected and set below. + CollectNames(config.feed(), &jit->nonempty_arg_names_, &jit->arg_names_); + CollectNames(config.fetch(), &jit->nonempty_result_names_, + &jit->result_names_); + jit->static_data_.arg_names = jit->arg_names_.data(); + jit->static_data_.result_names = jit->result_names_.data(); + jit->static_data_.program_shape = jit->program_shape_.get(); + return std::move(jit_unique_ptr); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h new file mode 100644 index 0000000000000000000000000000000000000000..af307ae4eff74927242c4650d8a43710e991cc52 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h @@ -0,0 +1,87 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_ + +#include +#include + +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +// Represents the result of JIT compilation by XLA down to a function. This +// class holds the state necessary to create XlaCompiledCpuFunction instances, +// which are used to actually invoke the compiled computation. +// +// XlaJitCompiledCpuFunction must outlive the XlaCompiledCpuFunctions that are +// created from it. It holds state shared by all of the functions, including the +// JIT-compiled function itself, along with buffer sizes and other metadata +// necessary for execution. +class XlaJitCompiledCpuFunction { + public: + // Compile a tensorflow::GraphDef into an XlaJitCompiledCpuFunction. The given + // `config` specifies the portion of the graph to compile, via feeds and + // fetches. Each feed is a positional input argument for the compiled + // function, while each fetch is a positional output argument. + static xla::StatusOr> Compile( + const GraphDef& graph_def, const tf2xla::Config& config, + const xla::ExecutableBuildOptions& build_options); + + XlaJitCompiledCpuFunction(const XlaJitCompiledCpuFunction&) = delete; + XlaJitCompiledCpuFunction& operator=(const XlaJitCompiledCpuFunction&) = + delete; + + // Returns static data used to create an XlaCompiledCpuFunction instance, + // which represents the JIT-compiled function. The static data is unchanging + // across each instance. + const XlaCompiledCpuFunction::StaticData& StaticData() const { + return static_data_; + } + + private: + XlaJitCompiledCpuFunction() {} + + // The executable holds the underlying function. + std::unique_ptr executable_; + + // The static data is backed by the rest of the state in this class. + XlaCompiledCpuFunction::StaticData static_data_; + + // The backing arrays of arg and temp buffer sizes. + std::vector arg_sizes_; + std::vector temp_sizes_; + + // The backing arrays of arg and result names. We hold the actual strings in + // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static + // data to refer to. + std::vector nonempty_arg_names_; + std::vector nonempty_result_names_; + std::vector arg_names_; + std::vector result_names_; + + // The backing data for the program shape. + std::unique_ptr program_shape_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_ diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6d49298a6f3e8a726695fafc42f3c5341fe98b5f --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -0,0 +1,147 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h" + +#include "tensorflow/compiler/tf2xla/tf2xla.pb.h" +#include "tensorflow/compiler/xla/client/local_client.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +AttrValue TypeAttrValue(DataType type) { + AttrValue attr_value; + SetAttrValue(type, &attr_value); + return attr_value; +} + +GraphDef SumGraph() { + GraphDef graph_def; + NodeDef* x = graph_def.add_node(); + x->set_name("x"); + x->set_op("Placeholder"); + (*x->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32); + NodeDef* y = graph_def.add_node(); + y->set_name("y"); + y->set_op("Placeholder"); + (*y->mutable_attr())["dtype"] = TypeAttrValue(DT_INT32); + NodeDef* sum = graph_def.add_node(); + sum->set_name("sum"); + sum->set_op("Add"); + sum->add_input("x"); + sum->add_input("y"); + (*sum->mutable_attr())["T"] = TypeAttrValue(DT_INT32); + return graph_def; +} + +tf2xla::Config SumConfig() { + tf2xla::Config config; + tf2xla::Feed* x = config.add_feed(); + x->mutable_id()->set_node_name("x"); + x->set_name("x_name"); + tf2xla::Feed* y = config.add_feed(); + y->mutable_id()->set_node_name("y"); + y->set_name("y_name"); + tf2xla::Fetch* sum = config.add_fetch(); + sum->mutable_id()->set_node_name("sum"); + sum->set_name("sum_name"); + return config; +} + +TEST(XlaJitCompiledCpuFunction, Sum) { + GraphDef graph_def = SumGraph(); + tf2xla::Config config = SumConfig(); + + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr jit, + XlaJitCompiledCpuFunction::Compile(graph_def, config, + xla::ExecutableBuildOptions())); + XlaCompiledCpuFunction function(jit->StaticData()); + + // Run the function and check results. + *static_cast(function.arg_data(0)) = 10; + *static_cast(function.arg_data(1)) = 32; + EXPECT_TRUE(function.Run()); + EXPECT_EQ(function.error_msg(), ""); + EXPECT_EQ(*static_cast(function.result_data(0)), 42); + + // Run the function again. + *static_cast(function.arg_data(0)) = 100; + *static_cast(function.arg_data(1)) = 320; + EXPECT_TRUE(function.Run()); + EXPECT_EQ(function.error_msg(), ""); + EXPECT_EQ(*static_cast(function.result_data(0)), 420); + + // Check name to index lookups. + EXPECT_TRUE(function.HasNameIndices()); + + EXPECT_EQ(function.LookupArgIndex("x_name"), 0); + EXPECT_EQ(function.LookupArgIndex("y_name"), 1); + EXPECT_EQ(function.LookupArgIndex(""), -1); + EXPECT_EQ(function.LookupArgIndex("x"), -1); + EXPECT_EQ(function.LookupArgIndex("y"), -1); + EXPECT_EQ(function.LookupArgIndex("sum"), -1); + EXPECT_EQ(function.LookupArgIndex("sum_name"), -1); + + EXPECT_EQ(function.LookupResultIndex("sum_name"), 0); + EXPECT_EQ(function.LookupResultIndex(""), -1); + EXPECT_EQ(function.LookupResultIndex("x"), -1); + EXPECT_EQ(function.LookupResultIndex("y"), -1); + EXPECT_EQ(function.LookupResultIndex("sum"), -1); + EXPECT_EQ(function.LookupResultIndex("x_name"), -1); + EXPECT_EQ(function.LookupResultIndex("y_name"), -1); + + // Check program shape. + using xla::ShapeUtil; + const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); + const xla::ProgramShape* program_shape = function.ProgramShape(); + ASSERT_TRUE(program_shape != nullptr); + ASSERT_EQ(program_shape->parameters_size(), 2); + EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(0), s32)); + EXPECT_TRUE(ShapeUtil::Compatible(program_shape->parameters(1), s32)); + + const xla::Shape& result = program_shape->result(); + ASSERT_EQ(result.element_type(), xla::TUPLE); + ASSERT_EQ(ShapeUtil::TupleElementCount(result), 1); + const xla::Shape& result0 = ShapeUtil::GetTupleElementShape(result, 0); + EXPECT_TRUE(ShapeUtil::Compatible(result0, s32)); +} + +// Test when a graph compilation terminates early, resources are properly +// reclaimed. +TEST(XlaJitCompiledCpuFunction, SumWithJunkAttr) { + GraphDef graph_def = SumGraph(); + + (*graph_def.mutable_node(2)->mutable_attr())["junk"] = + TypeAttrValue(DT_INT32); + + tf2xla::Config config = SumConfig(); + EXPECT_FALSE(XlaJitCompiledCpuFunction::Compile(graph_def, config, + xla::ExecutableBuildOptions()) + .ok()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index 2cf3d4c1f2563b995d5cd84dc380928552b20f00..02318cf7fa1d4edc12507f6b4d66a8e897cbe100 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -223,7 +223,8 @@ void XlaOpRegistry::RegisterCompilationKernels() { } std::vector XlaOpRegistry::DeviceKernels( - const string& compilation_device_name) { + const string& compilation_device_name, + bool include_compilation_only_kernels) { std::vector kernels; XlaOpRegistry& registry = Instance(); mutex_lock lock(registry.mutex_); @@ -236,7 +237,8 @@ std::vector XlaOpRegistry::DeviceKernels( // The test in IsCompatible ensures that if there are multiple matching // registrations for this op name, they all have the same value of // compilation_only, so only the first match needs to be tested. - if (!op_iter->second->compilation_only) { + if (include_compilation_only_kernels || + !op_iter->second->compilation_only) { kernels.push_back(k.get()); } } diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index d74203c82a1932f0b064fea5e2451a10bf222def..6aee8c91cc01b4382ef867fa8e438eede008ac73 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/core/common_runtime/device_factory.h" #include "tensorflow/core/common_runtime/local_device.h" @@ -44,17 +45,19 @@ extern const char* const DEVICE_GPU_XLA_JIT; // "GPU_XLA_JIT" extern const char* const DEVICE_XLA_CPU; extern const char* const DEVICE_XLA_GPU; -constexpr std::array kIntTypes = {{DT_INT32, DT_INT64}}; constexpr std::array kFloatTypes = { {DT_HALF, DT_FLOAT, DT_DOUBLE}}; -constexpr std::array kNumericTypes = { - {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE}}; +constexpr std::array kNumericTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64}}; -constexpr std::array kCpuAllTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; +constexpr std::array kCpuAllTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64, DT_BOOL}}; -constexpr std::array kGpuAllTypes = { - {DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, DT_BOOL}}; +constexpr std::array kGpuAllTypes = { + {DT_UINT32, DT_UINT64, DT_INT32, DT_INT64, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64, DT_BOOL}}; // Class that manages registrations of operators and devices for the XLA JIT. // Not thread-safe. @@ -116,7 +119,8 @@ class XlaOpRegistry { // 'compilation_device_name'. // Does not include kernels registered as CompilationOnly. static std::vector DeviceKernels( - const string& compilation_device_name); + const string& compilation_device_name, + bool include_compilation_only_kernels); private: friend class XlaBackendRegistrar; diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 25787ececc6505f4e0dab1eace4b3e51285cf932..660f419e464936b01a3644e69c2f056f998140f5 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -7,7 +7,6 @@ package_group( packages = [ "//tensorflow/compiler/...", "//tensorflow/contrib/tpu/...", - "//tensorflow/contrib/xla_tf_graph/...", ], ) @@ -163,6 +162,7 @@ cc_library( name = "util", srcs = ["util.cc"], hdrs = [ + "iterator_util.h", "map_util.h", "ptr_util.h", "util.h", @@ -170,6 +170,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":status", + ":status_macros", ":types", ":xla_data_proto", "//tensorflow/core:lib", @@ -203,6 +204,16 @@ tf_cc_test( ], ) +tf_cc_test( + name = "iterator_util_test", + srcs = ["iterator_util_test.cc"], + deps = [ + ":test", + ":util", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "shape_util", srcs = [ @@ -324,12 +335,32 @@ cc_library( ], ) +cc_library( + name = "array", + hdrs = ["array.h"], + deps = [ + ":types", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "array_test", + srcs = ["array_test.cc"], + deps = [ + ":array", + ":test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "array2d", srcs = ["array2d.cc"], hdrs = ["array2d.h"], visibility = ["//visibility:public"], deps = [ + ":array", ":types", ":util", "//tensorflow/core:lib", @@ -351,6 +382,7 @@ cc_library( hdrs = ["array3d.h"], visibility = [":friends"], deps = [ + ":array", ":types", "//tensorflow/core:lib", ], @@ -372,6 +404,7 @@ cc_library( hdrs = ["array4d.h"], visibility = [":friends"], deps = [ + ":array", ":array2d", ":types", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/array.h b/tensorflow/compiler/xla/array.h new file mode 100644 index 0000000000000000000000000000000000000000..ba898d1f4e9100df59c6e4b28824895c5ae6c08a --- /dev/null +++ b/tensorflow/compiler/xla/array.h @@ -0,0 +1,342 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_ARRAY_H_ +#define TENSORFLOW_COMPILER_XLA_ARRAY_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/bits.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// General N dimensional array class with arbitrary value type. +template +class Array { + public: + // Creates a new array with the specified dimensions. + explicit Array(tensorflow::gtl::ArraySlice sizes) + : Array(sizes, T()) {} + + // Creates a new array with the specified dimensions and specified value for + // every cell. + Array(tensorflow::gtl::ArraySlice sizes, T value) + : sizes_(sizes.begin(), sizes.end()), values_(new T[num_elements()]) { + Fill(value); + } + + // Creates a 2D array from the given nested initializer list. The outer + // initializer list is the first dimension, the inner is the second dimension. + // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3. + Array(std::initializer_list> values) + : Array(ToInt64Vector({values.size(), values.begin()->size()})) { + int64 idx = 0; + for (const auto& it1 : values) { + for (const auto& it2 : it1) { + values_[idx] = it2; + ++idx; + } + } + CHECK(idx == num_elements()); + } + + // Creates a 3D array from the given nested initializer list. The outer + // initializer list is the first dimension, and so on. + Array(std::initializer_list>> + values) + : Array(ToInt64Vector({values.size(), values.begin()->size(), + values.begin()->begin()->size()})) { + int64 idx = 0; + for (const auto& it1 : values) { + for (const auto& it2 : it1) { + for (const auto& it3 : it2) { + values_[idx] = it3; + ++idx; + } + } + } + CHECK(idx == num_elements()); + } + + // Creates a 4D array from the given nested initializer list. The outer + // initializer list is the first dimension, and so on. + Array(std::initializer_list< + std::initializer_list>>> + values) + : Array(ToInt64Vector({values.size(), values.begin()->size(), + values.begin()->begin()->size(), + values.begin()->begin()->begin()->size()})) { + int64 idx = 0; + for (const auto& it1 : values) { + for (const auto& it2 : it1) { + for (const auto& it3 : it2) { + for (const auto& it4 : it3) { + values_[idx] = it4; + ++idx; + } + } + } + } + CHECK(idx == num_elements()); + } + + Array(const Array& other) + : sizes_(other.sizes_), values_(new T[num_elements()]) { + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); + } + + Array& operator=(const Array& other) { + sizes_ = other.sizes_; + values_.reset(new T[num_elements()]); + std::copy(&other.values_[0], &other.values_[0] + num_elements(), + &values_[0]); + return *this; + } + + // Fills the array with the specified value. + void Fill(const T& value) { + std::fill(&values_[0], &values_[0] + num_elements(), value); + } + + // Fills the array with sequentially increasing values. + void FillIota(const T& value) { + std::iota(&values_[0], &values_[0] + num_elements(), value); + } + + // Fills the array with the sequence i*multiplier for i=0,1,... + void FillWithMultiples(const T& multiplier) { + for (int64 i = 0; i < num_elements(); ++i) { + values_[i] = i * multiplier; + } + } + + // Fills the array with random normal variables with the specified mean. + void FillRandom(const T& value, const double mean = 0.0, + const int seed = 12345) { + std::mt19937 g(seed); + std::normal_distribution distribution(mean, + static_cast(value)); + for (int64 i = 0; i < num_elements(); ++i) { + values_[i] = static_cast(distribution(g)); + } + } + + // Sets all the values in the array to values specified in the container. + template > + void SetValues(const Container& container) { + CHECK_EQ(std::distance(std::begin(container), std::end(container)), + num_elements()); + std::copy(std::begin(container), std::end(container), &values_[0]); + } + + // Invokes a callback with the (indices, value_ptr) for each cell in the + // array. + void Each(std::function, T*)> f) { + std::vector index(sizes_.size()); + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + f(index, &values_[i]); + } + } + + // Invokes a callback with the (indices, value) for each cell in the array. + void Each( + std::function, T)> f) const { + std::vector index(sizes_.size()); + for (int64 i = 0; i < num_elements(); ++i, next_index(&index)) { + f(index, values_[i]); + } + } + + // Returns the value at the cell specified by the indexes. The number of + // arguments have to match with the number of dimensions for the array. + template + const T& operator()(Dims... dims) const { + // We are using a std::array to avoid having to allocate memory in this + // function for performance reasons. + std::array indexes{{static_cast(dims)...}}; + return values_[calculate_index(indexes)]; + } + + // Returns the value at the cell specified by the indexes. The number of + // arguments have to match with the number of dimensions for the array. + template + T& operator()(Dims... dims) { + // We are using a std::array to avoid having to allocate memory in this + // function for performance reasons. + std::array indexes{{static_cast(dims)...}}; + return values_[calculate_index(indexes)]; + } + + // Returns the value at the cell specified by the indexes. The number of + // arguments have to match with the number of dimensions for the array. + const T& operator()(tensorflow::gtl::ArraySlice indexes) const { + return values_[calculate_index(indexes)]; + } + + // Returns the value at the cell specified by the indexes. The number of + // arguments have to match with the number of dimensions for the array. + T& operator()(tensorflow::gtl::ArraySlice indexes) { + return values_[calculate_index(indexes)]; + } + + // Low-level accessor for stuff like memcmp, handle with care. Returns pointer + // to the underlying storage of the array (similarly to std::vector::data()). + T* data() const { + // TODO(tberghammer): Get rid of the const_cast. Currently it is needed + // because the Eigen backend needs a non-const pointers even for reading + // from the array. + return const_cast(this)->values_.get(); + } + + // Returns the size of the dimension at the given index. + int64 dim(int64 n) const { + CHECK(n < sizes_.size()); + return sizes_[n]; + } + + // Returns a vector containing the dimensions of the array. + const std::vector& dimensions() const { return sizes_; } + + int64 num_dimensions() const { return sizes_.size(); } + + // Returns the total number of elements in the array. + int64 num_elements() const { + return std::accumulate(sizes_.begin(), sizes_.end(), 1, + std::multiplies()); + } + + const T* begin() const { return &values_[0]; } + T* begin() { return &values_[0]; } + const T* end() const { return &values_[num_elements()]; } + T* end() { return &values_[num_elements()]; } + + bool operator==(const Array& other) const { + if (sizes_.size() != other.sizes_.size()) { + return false; + } + for (int64 i = 0; i < sizes_.size(); ++i) { + if (sizes_[i] != other.sizes_[i]) { + return false; + } + } + for (int64 i = 0; i < num_elements(); ++i) { + if (values_[i] != other.values_[i]) { + return false; + } + } + return true; + } + + bool operator!=(const Array& other) const { return !(*this == other); } + + // Returns a string representation of the array suitable for debugging. + string ToString() const { + std::vector pieces; + std::vector index(sizes_.size()); + do { + // Emit leading spaces and opening square brackets + if (index.back() == 0) { + for (int64 i = sizes_.size() - 1; i >= 0; --i) { + if (i == 0 || index[i - 1] != 0) { + for (int64 j = 0; j < sizes_.size(); ++j) { + pieces.push_back(j < i ? " " : "["); + } + break; + } + } + } + + pieces.push_back( + tensorflow::strings::AlphaNum(values_[calculate_index(index)]) + .data()); + + // Emit comma if it isn't the last element + if (index.back() != sizes_.back() - 1) { + pieces.push_back(", "); + } + + // Emit closing square brackets + for (int64 i = sizes_.size() - 1; i >= 0; --i) { + if (index[i] != sizes_[i] - 1) { + break; + } + pieces.push_back("]"); + if (i != 0 && index[i - 1] != sizes_[i - 1] - 1) { + pieces.push_back(",\n"); + } + } + } while (next_index(&index)); + return tensorflow::str_util::Join(pieces, ""); + } + + private: + // Converts an initializer_list of type U to a vector of type int64. Used by + // the initializer list based constructors to convert the size type into int64 + // to be passed to the size based constructor. + template + static std::vector ToInt64Vector( + const std::initializer_list& data) { + return std::vector(data.begin(), data.end()); + } + + // Returns the linear index from the list of per-dimension indexes. Function + // is templated so can be used with an std::array from operator() to avoid + // memory allocation. + template + int64 calculate_index(const U& indexes) const { + CHECK_EQ(sizes_.size(), indexes.size()); + int64 index = 0; + for (int64 i = 0; i < sizes_.size(); ++i) { + index *= sizes_[i]; + index += indexes[i]; + } + return index; + } + + // Advances the specified set of indexes and returns true if we haven't + // wrapped around (i.e. result isn't {0, 0, ...}). + bool next_index(std::vector* index) const { + CHECK_EQ(index->size(), sizes_.size()); + for (int64 i = sizes_.size() - 1; i >= 0; --i) { + (*index)[i]++; + if ((*index)[i] < sizes_[i]) { + return true; + } + (*index)[i] = 0; + } + return false; + } + + std::vector sizes_; + std::unique_ptr values_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_ARRAY_H_ diff --git a/tensorflow/compiler/xla/array2d.h b/tensorflow/compiler/xla/array2d.h index 2737764cbda87298599d7005c237a2093cbaba4a..bb85fbee9b97fd6b9b0bf7223a9b820989dcbfa7 100644 --- a/tensorflow/compiler/xla/array2d.h +++ b/tensorflow/compiler/xla/array2d.h @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/core/bits.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -34,93 +35,30 @@ limitations under the License. namespace xla { -// Simple 2D array structure. -// -// The data layout in major-to-minor order is: n1, n2. template -class Array2D { +class Array2D : public Array { public: - // Creates an empty array. - Array2D() : n1_(0), n2_(0) {} + Array2D() : Array(std::vector{0, 0}) {} - // Creates an array of dimensions n1 x n2, uninitialized values. Array2D(const int64 n1, const int64 n2) - : n1_(n1), n2_(n2), values_(new T[n1 * n2]()) { - Fill(T()); - } + : Array(std::vector{n1, n2}) {} - // Creates an array of dimensions n1 x n2, initialized to value. Array2D(const int64 n1, const int64 n2, const T value) - : n1_(n1), n2_(n2), values_(new T[n1 * n2]()) { - Fill(value); - } + : Array({n1, n2}, value) {} // Creates an array from the given nested initializer list. The outer // initializer list is the first dimension; the inner is the second dimension. // For example, {{1, 2, 3}, {4, 5, 6}} results in an array with n1=2 and n2=3. Array2D(std::initializer_list> values) - : Array2D(values.size(), values.begin()->size()) { - int64 n1 = 0; - for (auto n1_it = values.begin(); n1_it != values.end(); ++n1_it, ++n1) { - int64 n2 = 0; - for (auto n2_it = n1_it->begin(); n2_it != n1_it->end(); ++n2_it, ++n2) { - (*this)(n1, n2) = *n2_it; - } - } - } + : Array(values) {} - Array2D(const Array2D& other) : Array2D(other.n1(), other.n2()) { - std::copy(&other.values_[0], &other.values_[0] + num_elements(), - &values_[0]); - } - - Array2D& operator=(const Array2D& other) { - n1_ = other.n1(); - n2_ = other.n2(); - values_.reset(new T[num_elements()]); - std::copy(&other.values_[0], &other.values_[0] + num_elements(), - &values_[0]); - return *this; - } + Array2D(const Array2D& other) : Array(other) {} - T& operator()(const int64 i1, const int64 i2) { - CHECK_LT(i1, n1_); - CHECK_LT(i2, n2_); - return values_[i1 * n2_ + i2]; - } + int64 n1() const { return this->dim(0); } + int64 n2() const { return this->dim(1); } - const T& operator()(const int64 i1, const int64 i2) const { - CHECK_LT(i1, n1_); - CHECK_LT(i2, n2_); - return values_[i1 * n2_ + i2]; - } - - // Access to the array's dimensions. height() and width() provide the - // canonical interpretation of the array n1 x n2 having n1 rows of n2 columns - // each (height is number of rows; width is number of columns). - int64 n1() const { return n1_; } - int64 n2() const { return n2_; } - int64 height() const { return n1_; } - int64 width() const { return n2_; } - int64 num_elements() const { return n1_ * n2_; } - - // Low-level accessor for stuff like memcmp, handle with care. Returns pointer - // to the underlying storage of the array (similarly to std::vector::data()). - T* data() const { return const_cast(this)->values_.get(); } - - // Fills the array with the given value. - void Fill(const T& value) { - std::fill(&values_[0], &values_[0] + num_elements(), value); - } - - // Applies f to all cells in this array, in row-major order. - void Each(std::function f) { - for (int64 i0 = 0; i0 < n1(); ++i0) { - for (int64 i1 = 0; i1 < n2(); ++i1) { - f(i0, i1, &(*this)(i0, i1)); - } - } - } + int64 height() const { return this->dim(0); } + int64 width() const { return this->dim(1); } // Fills the array with a pattern of values of the form: // @@ -136,55 +74,14 @@ class Array2D { } } - // Fills the array with random normal variables of deviation value. - void FillRandom(const T& value, const double mean = 0.0, - const int seed = 12345) { - std::mt19937 g(seed); - std::normal_distribution distribution(mean, - static_cast(value)); - for (int64 i = 0; i < num_elements(); ++i) { - values_[i] = static_cast(distribution(g)); - } - } - - // Returns a readable string representation of the array. - string ToString() const { - std::vector pieces = {"["}; - for (int64 row = 0; row < height(); ++row) { - pieces.push_back("["); - for (int64 col = 0; col < width(); ++col) { - pieces.push_back(tensorflow::strings::StrCat((*this)(row, col))); - pieces.push_back(", "); - } - pieces.pop_back(); - pieces.push_back("]"); - pieces.push_back(",\n "); - } - pieces.pop_back(); - pieces.push_back("]"); - return tensorflow::str_util::Join(pieces, ""); - } - - bool operator==(const Array2D& other) const { - if (n1() != other.n1() || n2() != other.n2()) { - return false; - } + // Applies f to all cells in this array, in row-major order. + void Each(std::function f) { for (int64 i0 = 0; i0 < n1(); ++i0) { for (int64 i1 = 0; i1 < n2(); ++i1) { - if ((*this)(i0, i1) != other(i0, i1)) { - return false; - } + f(i0, i1, &(*this)(i0, i1)); } } - return true; } - - bool operator!=(const Array2D& other) const { return !(*this == other); } - - private: - int64 n1_; - int64 n2_; - std::unique_ptr values_; }; // Returns a linspace-populated Array2D in the range [from, to] (inclusive) diff --git a/tensorflow/compiler/xla/array3d.h b/tensorflow/compiler/xla/array3d.h index 124ccd1975b3a9ab047e9bbbfb38921fe7386fe4..e9449f01ad69a5722f53cce09e2884e20a0def5a 100644 --- a/tensorflow/compiler/xla/array3d.h +++ b/tensorflow/compiler/xla/array3d.h @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" @@ -32,22 +33,16 @@ limitations under the License. namespace xla { // Simple 3D array structure. -// -// The data layout in major-to-minor order is: n1, n2, n3. template -class Array3D { +class Array3D : public Array { public: // Creates an array of dimensions n1 x n2 x n3, uninitialized values. Array3D(const int64 n1, const int64 n2, const int64 n3) - : n1_(n1), n2_(n2), n3_(n3), values_(new T[n1 * n2 * n3]) { - Fill(T()); - } + : Array(std::vector{n1, n2, n3}) {} // Creates an array of dimensions n1 x n2 x n3, initialized to value. Array3D(const int64 n1, const int64 n2, const int64 n3, const T value) - : n1_(n1), n2_(n2), n3_(n3), values_(new T[n1 * n2 * n3]) { - Fill(value); - } + : Array(std::vector{n1, n2, n3}, value) {} // Creates an array from the given nested initializer list. The outer // initializer list is the first dimension, and so on. @@ -58,84 +53,11 @@ class Array3D { // results in an array with n1=3, n2=4, n3=2. Array3D(std::initializer_list>> values) - : Array3D(values.size(), values.begin()->size(), - values.begin()->begin()->size()) { - int64 n1 = 0; - for (auto n1_it = values.begin(); n1_it != values.end(); ++n1_it, ++n1) { - int64 n2 = 0; - for (auto n2_it = n1_it->begin(); n2_it != n1_it->end(); ++n2_it, ++n2) { - int64 n3 = 0; - for (auto n3_it = n2_it->begin(); n3_it != n2_it->end(); - ++n3_it, ++n3) { - (*this)(n1, n2, n3) = *n3_it; - } - } - } - } + : Array(values) {} - Array3D(const Array3D& other) - : Array3D(other.n1(), other.n2(), other.n3()) { - std::copy(&other.values_[0], &other.values_[0] + num_elements(), - &values_[0]); - } - - Array3D& operator=(const Array3D& other) { - n1_ = other.n1(); - n2_ = other.n2(); - n3_ = other.n3(); - values_.reset(new T[num_elements()]); - std::copy(&other.values_[0], &other.values_[0] + num_elements(), - &values_[0]); - return *this; - } - - T& operator()(const int64 i1, const int64 i2, const int64 i3) { - CHECK_LT(i1, n1_); - CHECK_LT(i2, n2_); - CHECK_LT(i3, n3_); - return values_[i1 * n2_ * n3_ + i2 * n3_ + i3]; - } - - const T& operator()(const int64 i1, const int64 i2, const int64 i3) const { - CHECK_LT(i1, n1_); - CHECK_LT(i2, n2_); - CHECK_LT(i3, n3_); - return values_[i1 * n2_ * n3_ + i2 * n3_ + i3]; - } - - // Access to the array's dimensions. - int64 n1() const { return n1_; } - int64 n2() const { return n2_; } - int64 n3() const { return n3_; } - int64 num_elements() const { return n1_ * n2_ * n3_; } - - // Fills the array with the given value. - void Fill(const T& value) { - std::fill(&values_[0], &values_[0] + num_elements(), value); - } - - // Fills the array with sequentially increasing values. - void FillIota(const T& value) { - std::iota(&values_[0], &values_[0] + num_elements(), value); - } - - // Fills the array with random normal values with a mean of 0 and standard - // deviation of value. - void FillRandom(const T& value, const double mean = 0.0, - const int seed = 12345) { - std::mt19937 g(seed); - std::normal_distribution distribution(mean, - static_cast(value)); - for (int64 i = 0; i < num_elements(); ++i) { - values_[i] = static_cast(distribution(g)); - } - } - - private: - int64 n1_; - int64 n2_; - int64 n3_; - std::unique_ptr values_; + int64 n1() const { return this->dim(0); } + int64 n2() const { return this->dim(1); } + int64 n3() const { return this->dim(2); } }; } // namespace xla diff --git a/tensorflow/compiler/xla/array4d.h b/tensorflow/compiler/xla/array4d.h index 4c7fce1aaf1faf4bd08bca38bc8eb2b47303b575..f8b2b2afe5fed9c465c2a1f39308b7f44311b16a 100644 --- a/tensorflow/compiler/xla/array4d.h +++ b/tensorflow/compiler/xla/array4d.h @@ -26,6 +26,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -53,23 +54,15 @@ namespace xla { // more than one name is given above. See operator() for the exact // calculation of 1d indices from 4d indices. template -class Array4D { +class Array4D : public Array { public: // Creates a 4D array, uninitialized values. Array4D(int64 planes, int64 depth, int64 height, int64 width) - : planes_(planes), - depth_(depth), - height_(height), - width_(width), - values_(new T[planes * depth * height * width]) { - Fill(T()); - } + : Array(std::vector{planes, depth, height, width}) {} // Creates a 4D array, initialized to value. Array4D(int64 planes, int64 depth, int64 height, int64 width, T value) - : Array4D(planes, depth, height, width) { - Fill(value); - } + : Array(std::vector{planes, depth, height, width}, value) {} // Creates a 4D array, filled with values. // @@ -80,144 +73,26 @@ class Array4D { Array4D(int64 planes, int64 depth, int64 height, int64 width, const Container& values) : Array4D(planes, depth, height, width) { - SetValues(values); + this->SetValues(values); } // Construct an Array4D with the given nested initializer list. Array4D(std::initializer_list>>> values) - : Array4D(values.size(), values.begin()->size(), - values.begin()->begin()->size(), - values.begin()->begin()->begin()->size()) { - int64 plane = 0; - for (const auto values_in_plane : values) { - DCHECK_EQ(values_in_plane.size(), depth_); - int64 depth = 0; - for (const auto values_in_depth : values_in_plane) { - DCHECK_EQ(values_in_depth.size(), height_); - int64 height = 0; - for (const auto values_in_height : values_in_depth) { - DCHECK_EQ(values_in_height.size(), width_); - int64 width = 0; - for (const auto element_value : values_in_height) { - (*this)(plane, depth, height, width) = element_value; - ++width; - } - ++height; - } - ++depth; - } - ++plane; - } - } - - Array4D(const Array4D& other) - : Array4D(other.planes(), other.depth(), other.height(), other.width()) { - std::copy(&other.values_[0], &other.values_[0] + num_elements(), - &values_[0]); - } - - Array4D& operator=(const Array4D& other) { - planes_ = other.planes(); - depth_ = other.depth(); - height_ = other.height(); - width_ = other.width(); - values_.reset(new T[num_elements()]); - std::copy(&other.values_[0], &other.values_[0] + num_elements(), - &values_[0]); - return *this; - } - - T& operator()(int64 plane, int64 depth, int64 height, int64 width) { - CHECK_LT(plane, planes_); - CHECK_LT(depth, depth_); - CHECK_LT(height, height_); - CHECK_LT(width, width_); - return values_[plane * (depth_ * height_ * width_) + - depth * (height_ * width_) + height * (width_) + width]; - } - const T& operator()(int64 plane, int64 depth, int64 height, - int64 width) const { - return const_cast(this)->operator()(plane, depth, height, width); - } - - int64 width() const { return width_; } - int64 height() const { return height_; } - int64 depth() const { return depth_; } - int64 planes() const { return planes_; } + : Array(values) {} // Numerically-named aliases for the various dimensions. This matches the // dimension names used in array3d. - int64 n4() const { return width_; } - int64 n3() const { return height_; } - int64 n2() const { return depth_; } - int64 n1() const { return planes_; } - int64 num_elements() const { return width_ * height_ * depth_ * planes_; } - - // Sets all the values in the array to values. - template > - void SetValues(const Container& container) { - CHECK_EQ(std::distance(std::begin(container), std::end(container)), - num_elements()); - std::copy(std::begin(container), std::end(container), &values_[0]); - } - - // Fills the array with the given value. - void Fill(const T& value) { - std::fill(&values_[0], &values_[0] + num_elements(), value); - } + int64 n4() const { return this->dim(3); } + int64 n3() const { return this->dim(2); } + int64 n2() const { return this->dim(1); } + int64 n1() const { return this->dim(0); } - // Fills the array with iota. - void FillIota(const T& value) { - std::iota(&values_[0], &values_[0] + num_elements(), value); - } - - // Fills the array with random variable with a deviation of value and a mean - // of mean. - void FillRandom(const T& value, const double mean = 0.0, - const int seed = 12345) { - std::mt19937 g(seed); - std::normal_distribution distribution(mean, - static_cast(value)); - for (int64 i = 0; i < num_elements(); ++i) { - values_[i] = static_cast(distribution(g)); - } - } - - // Fills values with the sequence i*multiplier for i=0,1,... - void FillWithMultiples(float multiplier) { - for (int64 i = 0; i < num_elements(); ++i) { - values_[i] = i * multiplier; - } - } - - // Invokes a callback with the (indices, value_ptr) for each cell in the 4D - // array. - void Each(std::function, T*)> f) { - for (int64 plane = 0; plane < planes(); ++plane) { - for (int64 depth = 0; depth < this->depth(); ++depth) { - for (int64 height = 0; height < this->height(); ++height) { - for (int64 width = 0; width < this->width(); ++width) { - auto& value = (*this)(plane, depth, height, width); - f({plane, depth, height, width}, &value); - } - } - } - } - } - - // Invokes a callback with the (indices, value) for each cell in the 4D array. - void Each( - std::function, T)> f) const { - // We const_cast to be able to use the common non-const implementation, - // but prevent modification of the data by passing it by-value to the - // caller. - const_cast(this)->Each( - [&f](tensorflow::gtl::ArraySlice indices, T* value) { - f(indices, *value); - }); - } + int64 width() const { return this->dim(3); } + int64 height() const { return this->dim(2); } + int64 depth() const { return this->dim(1); } + int64 planes() const { return this->dim(0); } // Fills all of the {p,z} with the array provided, which specifies {y,x}. void FillWithYX(const Array2D& value) { @@ -267,38 +142,6 @@ class Array4D { } } } - - // Returns a string representation of the 4D array suitable for debugging. - string ToString() const { - std::vector pieces = { - tensorflow::strings::Printf("p=%lld,z=%lld,y=%lld,x=%lld {\n", planes(), - depth(), height(), width())}; - for (int64 plane = 0; plane < planes_; ++plane) { - pieces.push_back(" {\n"); - for (int64 depth = 0; depth < depth_; ++depth) { - pieces.push_back(" {\n"); - for (int64 height = 0; height < height_; ++height) { - pieces.push_back(" {"); - for (int64 width = 0; width < width_; ++width) { - pieces.push_back(tensorflow::strings::StrCat( - (*this)(plane, depth, height, width), ", ")); - } - pieces.push_back("},\n"); - } - pieces.push_back(" },\n"); - } - pieces.push_back(" },\n"); - } - pieces.push_back("}"); - return tensorflow::str_util::Join(pieces, ""); - } - - private: - int64 planes_; - int64 depth_; - int64 height_; - int64 width_; - std::unique_ptr values_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/array_test.cc b/tensorflow/compiler/xla/array_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..093784f541b3bd18f4a1fc1b665cd0d17a892f28 --- /dev/null +++ b/tensorflow/compiler/xla/array_test.cc @@ -0,0 +1,145 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/array.h" + +#include + +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace { + +TEST(ArrayTest, UninitializedDimsCtor) { + Array uninit({2, 3}); + EXPECT_EQ(uninit.num_dimensions(), 2); + EXPECT_EQ(uninit.dim(0), 2); + EXPECT_EQ(uninit.dim(1), 3); + EXPECT_EQ(uninit.num_elements(), 6); +} + +TEST(ArrayTest, FillCtor) { + Array fullof7({1, 2, 3}, 7); + + EXPECT_EQ(fullof7.dim(0), 1); + EXPECT_EQ(fullof7.dim(1), 2); + EXPECT_EQ(fullof7.dim(2), 3); + + for (int64 n0 = 0; n0 < fullof7.dim(0); ++n0) { + for (int64 n1 = 0; n1 < fullof7.dim(1); ++n1) { + for (int64 n2 = 0; n2 < fullof7.dim(2); ++n2) { + EXPECT_EQ(fullof7(n0, n1, n2), 7); + } + } + } +} + +TEST(ArrayTest, InitializerListCtor) { + Array arr({{1, 2, 3}, {4, 5, 6}}); + + EXPECT_EQ(arr.dim(0), 2); + EXPECT_EQ(arr.dim(1), 3); + + EXPECT_EQ(arr(0, 0), 1); + EXPECT_EQ(arr(0, 1), 2); + EXPECT_EQ(arr(0, 2), 3); + EXPECT_EQ(arr(1, 0), 4); + EXPECT_EQ(arr(1, 1), 5); + EXPECT_EQ(arr(1, 2), 6); +} + +TEST(ArrayTest, IndexingReadWrite) { + Array arr({2, 3}); + + EXPECT_EQ(arr(1, 1), 0); + EXPECT_EQ(arr(1, 2), 0); + arr(1, 1) = 51; + arr(1, 2) = 61; + EXPECT_EQ(arr(1, 1), 51); + EXPECT_EQ(arr(1, 2), 61); +} + +TEST(ArrayTest, IndexingReadWriteBool) { + Array arr{{false, true, false}, {false, true, false}}; + + EXPECT_EQ(arr(0, 1), true); + EXPECT_EQ(arr(0, 2), false); + arr(0, 1) = false; + arr(0, 2) = true; + EXPECT_EQ(arr(0, 1), false); + EXPECT_EQ(arr(0, 2), true); +} + +TEST(ArrayTest, Fill) { + Array fullof7({2, 3}, 7); + for (int64 n1 = 0; n1 < fullof7.dim(0); ++n1) { + for (int64 n2 = 0; n2 < fullof7.dim(1); ++n2) { + EXPECT_EQ(fullof7(n1, n2), 7); + } + } + + fullof7.Fill(11); + for (int64 n1 = 0; n1 < fullof7.dim(0); ++n1) { + for (int64 n2 = 0; n2 < fullof7.dim(1); ++n2) { + EXPECT_EQ(fullof7(n1, n2), 11); + } + } +} + +TEST(ArrayTest, DataPointer) { + Array arr{{1, 2, 3}, {4, 5, 6}}; + EXPECT_EQ(arr.data()[0], 1); +} + +TEST(ArrayTest, Stringification1D) { + Array arr({2}, 1); + const string expected = R"([1, 1])"; + EXPECT_EQ(expected, arr.ToString()); +} + +TEST(ArrayTest, Stringification2D) { + Array arr({2, 3}, 7); + const string expected = "[[7, 7, 7],\n [7, 7, 7]]"; + EXPECT_EQ(expected, arr.ToString()); +} + +TEST(ArrayTest, Stringification3D) { + Array arr({2, 3, 4}, 5); + const string expected = R"([[[5, 5, 5, 5], + [5, 5, 5, 5], + [5, 5, 5, 5]], + [[5, 5, 5, 5], + [5, 5, 5, 5], + [5, 5, 5, 5]]])"; + EXPECT_EQ(expected, arr.ToString()); +} + +TEST(ArrayTest, Each) { + Array arr({2, 3, 4}); + arr.FillWithMultiples(1); + + int64 each_count = 0, each_sum = 0; + arr.Each([&](tensorflow::gtl::ArraySlice idx, int cell) { + int64 lin_idx = idx[0] * 12 + idx[1] * 4 + idx[2]; + EXPECT_EQ(lin_idx, cell); + each_count++; + each_sum += cell; + }); + EXPECT_EQ(arr.num_elements(), each_count); + EXPECT_EQ(arr.num_elements() * (arr.num_elements() - 1) / 2, each_sum); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD index 2b142d933dbc8c5a7823f9426c423b59425a85bc..f953407a567b91fdf6ae727d6982a2a778c5873e 100644 --- a/tensorflow/compiler/xla/client/BUILD +++ b/tensorflow/compiler/xla/client/BUILD @@ -41,7 +41,9 @@ cc_library( srcs = ["padding.cc"], hdrs = ["padding.h"], deps = [ + "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/core:lib", ], ) @@ -168,6 +170,7 @@ cc_library( ":computation", ":global_data", ":padding", + "//tensorflow/compiler/xla:array", "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:array3d", "//tensorflow/compiler/xla:array4d", diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc index 387253617e4f37a1561d4659eb796a181f0b5bee..92cd8e729d659c4ff24c156d89f29275848c3cee 100644 --- a/tensorflow/compiler/xla/client/client.cc +++ b/tensorflow/compiler/xla/client/client.cc @@ -206,6 +206,7 @@ StatusOr> Client::Execute( *request.mutable_execution_options() = *execution_options; } for (GlobalData* argument : arguments) { + CHECK(argument != nullptr) << "Argument pointers must not be null."; *request.add_arguments() = argument->handle(); } @@ -241,9 +242,6 @@ StatusOr>> Client::ExecuteParallel( for (GlobalData* argument : computation.arguments) { *single_request.add_arguments() = argument->handle(); } - if (computation.device_handle != nullptr) { - *single_request.mutable_device_handle() = *computation.device_handle; - } *single_request.mutable_execution_options() = computation.execution_options; *request.add_requests() = single_request; } diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h index e72816a6217afd6a827642bbe3aa205409ef5718..a716159f9e74041c4823ad20b46fa94c2d7b9d8c 100644 --- a/tensorflow/compiler/xla/client/client.h +++ b/tensorflow/compiler/xla/client/client.h @@ -45,6 +45,10 @@ class Client { // * If execution_options is not nullptr, these options are passed to the // service to affect how it compiles our computation. (The pointer does not // need to live beyond this call.) + // * If execution_options.device_handles is not empty, the computation is + // executed on the devices associated with the handles by partitioning the + // computation based on the attached sharding attributes. Otherwise, a + // device is chosen by the service. // * If execution_profile is not nullptr then the pointed-to ExecutionProfile // will be filled with profile data from the execution. StatusOr> Execute( @@ -54,12 +58,13 @@ class Client { ExecutionProfile* execution_profile = nullptr); // A struct to represent a computation instance to be executed. - // * If device_handle is not nullptr, the computation is executed on a device - // associated with the handle. Otherwise, a device is chosen by the service. + // * If execution_options.device_handles is not empty, the computation is + // executed on the devices associated with the handles by partitioning the + // computation based on the attached sharding attributes. Otherwise, a + // device is chosen by the service. struct ComputationInstance { const Computation& computation; std::vector arguments; - const DeviceHandle* device_handle; ExecutionOptions execution_options; ExecutionProfile* execution_profile; }; diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index 210a4d95b944b1007db7fb72c4cfdc18066cd559..763d94e94c2167f47b3f0777a31815f02791aa9e 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -489,6 +489,16 @@ ComputationDataHandle ComputationBuilder::Collapse( } std::unique_ptr original_shape = shape_or_status.ConsumeValueOrDie(); + VLOG(3) << "original shape: " << ShapeUtil::HumanString(*original_shape); + VLOG(3) << "dims to collapse: " + << tensorflow::str_util::Join(dims_to_collapse, ","); + + if (dims_to_collapse.size() <= 1) { + // Not collapsing anything, trivially we can return the operand versus + // enqueueing a trivial reshape. + return operand; + } + std::vector new_sizes; for (int i = 0; i < ShapeUtil::Rank(*original_shape); ++i) { if (i <= dims_to_collapse.front() || i > dims_to_collapse.back()) { @@ -498,6 +508,9 @@ ComputationDataHandle ComputationBuilder::Collapse( } } + VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",") + << "]"; + return Reshape(operand, new_sizes); } @@ -650,7 +663,7 @@ bool ComputationBuilder::VerifyConvolution( return false; } int num_dims = ShapeUtil::Rank(lhs_shape); - if (num_dims < 3) { + if (num_dims < 2) { NoteError(InvalidArgument( "Convolution expects argument arrays with >= 3 dimensions. " "Got: %s and %s", @@ -900,6 +913,17 @@ ComputationDataHandle ComputationBuilder::CustomCall( return ParseOpResponse(s, &response); } +ComputationDataHandle ComputationBuilder::Complex( + const ComputationDataHandle& real, const ComputationDataHandle& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_COMPLEX, real, imag, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::Conj( + const ComputationDataHandle& operand) { + return Complex(Real(operand), Neg(Imag(operand))); +} + ComputationDataHandle ComputationBuilder::Add( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { @@ -942,21 +966,39 @@ ComputationDataHandle ComputationBuilder::Min( return BinaryOp(BINOP_MIN, lhs, rhs, broadcast_dimensions); } -ComputationDataHandle ComputationBuilder::LogicalAnd( +ComputationDataHandle ComputationBuilder::And( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_LOGICAL_AND, lhs, rhs, broadcast_dimensions); + return BinaryOp(BINOP_AND, lhs, rhs, broadcast_dimensions); } -ComputationDataHandle ComputationBuilder::LogicalOr( +ComputationDataHandle ComputationBuilder::Or( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions) { - return BinaryOp(BINOP_LOGICAL_OR, lhs, rhs, broadcast_dimensions); + return BinaryOp(BINOP_OR, lhs, rhs, broadcast_dimensions); } -ComputationDataHandle ComputationBuilder::LogicalNot( +ComputationDataHandle ComputationBuilder::Not( const ComputationDataHandle& operand) { - return UnaryOp(UNOP_LOGICAL_NOT, operand); + return UnaryOp(UNOP_NOT, operand); +} + +ComputationDataHandle ComputationBuilder::ShiftLeft( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_SHIFT_LEFT, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::ShiftRightArithmetic( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_SHIFT_RIGHT_ARITHMETIC, lhs, rhs, broadcast_dimensions); +} + +ComputationDataHandle ComputationBuilder::ShiftRightLogical( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_SHIFT_RIGHT_LOGICAL, lhs, rhs, broadcast_dimensions); } ComputationDataHandle ComputationBuilder::Abs( @@ -964,6 +1006,12 @@ ComputationDataHandle ComputationBuilder::Abs( return UnaryOp(UNOP_ABS, operand); } +ComputationDataHandle ComputationBuilder::Atan2( + const ComputationDataHandle& y, const ComputationDataHandle& x, + tensorflow::gtl::ArraySlice broadcast_dimensions) { + return BinaryOp(BINOP_ATAN2, y, x, broadcast_dimensions); +} + ComputationDataHandle ComputationBuilder::Exp( const ComputationDataHandle& operand) { return UnaryOp(UNOP_EXP, operand); @@ -1009,6 +1057,16 @@ ComputationDataHandle ComputationBuilder::Tanh( return UnaryOp(UNOP_TANH, operand); } +ComputationDataHandle ComputationBuilder::Real( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_REAL, operand); +} + +ComputationDataHandle ComputationBuilder::Imag( + const ComputationDataHandle& operand) { + return UnaryOp(UNOP_IMAG, operand); +} + ComputationDataHandle ComputationBuilder::IsFinite( const ComputationDataHandle& operand) { return UnaryOp(UNOP_IS_FINITE, operand); @@ -1251,7 +1309,7 @@ Status ComputationBuilder::SetReturnValue( } StatusOr ComputationBuilder::IsConstant( - const ComputationDataHandle& operand) { + const ComputationDataHandle& operand, int64 num_parameters) { if (!first_error_.ok()) { return first_error_; } @@ -1259,6 +1317,7 @@ StatusOr ComputationBuilder::IsConstant( IsConstantRequest request; *request.mutable_computation() = computation_.handle(); *request.mutable_operand() = operand; + request.set_num_parameters(num_parameters); IsConstantResponse response; VLOG(2) << "making IsConstant request"; @@ -1272,7 +1331,8 @@ StatusOr ComputationBuilder::IsConstant( } StatusOr> ComputationBuilder::ComputeConstant( - const ComputationDataHandle& operand, const Layout* output_layout) { + const ComputationDataHandle& operand, const Layout* output_layout, + tensorflow::gtl::ArraySlice parameters) { if (!first_error_.ok()) { return first_error_; } @@ -1283,6 +1343,9 @@ StatusOr> ComputationBuilder::ComputeConstant( if (output_layout != nullptr) { *request.mutable_output_layout() = *output_layout; } + for (const auto& param : parameters) { + *request.add_parameters() = param.ToProto(); + } ComputeConstantResponse response; @@ -1307,6 +1370,7 @@ StatusOr> ComputationBuilder::ComputeConstant( ComputationDataHandle ComputationBuilder::Map( tensorflow::gtl::ArraySlice operands, const Computation& computation, + tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice static_operands) { if (!first_error_.ok() || !PrepareComputation().ok()) { return ComputationDataHandle(); @@ -1317,6 +1381,9 @@ ComputationDataHandle ComputationBuilder::Map( *request.add_operands() = operand; } *request.mutable_to_apply() = computation.handle(); + for (int64 dimension : dimensions) { + request.add_dimensions(dimension); + } for (const ComputationDataHandle& sop : static_operands) { *request.add_static_operands() = sop; } @@ -1429,10 +1496,20 @@ ComputationDataHandle ComputationBuilder::ReduceWindow( return ComputationDataHandle(); } - return ReduceWindowWithGeneralPadding( - operand, init_value, computation, window_dimensions, window_strides, + Status padding_valid = + ValidatePaddingValues(AsInt64Slice(shape.ValueOrDie()->dimensions()), + window_dimensions, window_strides); + if (!padding_valid.ok()) { + first_error_ = padding_valid; + return ComputationDataHandle(); + } + + std::vector> padding_values = MakePadding(AsInt64Slice(shape.ValueOrDie()->dimensions()), - window_dimensions, window_strides, padding)); + window_dimensions, window_strides, padding); + return ReduceWindowWithGeneralPadding(operand, init_value, computation, + window_dimensions, window_strides, + padding_values); } ComputationDataHandle ComputationBuilder::ReduceWindowWithGeneralPadding( @@ -1722,21 +1799,18 @@ StatusOr ComputationBuilder::Build() { void ComputationBuilder::AddCommonFieldsToOpRequest(OpRequest* request) const { *request->mutable_metadata() = metadata_; - *request->mutable_device_assignment() = device_assignment_; -} - -void ComputationBuilder::ClearDeviceAssignment() { device_assignment_.Clear(); } - -void ComputationBuilder::SetDeviceAssignment( - const OpDeviceAssignment& assignment) { - device_assignment_ = assignment; + if (sharding_) { + *request->mutable_sharding() = *sharding_; + } } /* static */ ConvolutionDimensionNumbers ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_batch_dimension(kConvBatchDimension); - dimension_numbers.set_feature_dimension(kConvFeatureDimension); + dimension_numbers.set_input_batch_dimension(kConvBatchDimension); + dimension_numbers.set_input_feature_dimension(kConvFeatureDimension); + dimension_numbers.set_output_batch_dimension(kConvBatchDimension); + dimension_numbers.set_output_feature_dimension(kConvFeatureDimension); dimension_numbers.set_kernel_output_feature_dimension( kConvKernelOutputDimension); dimension_numbers.set_kernel_input_feature_dimension( @@ -1750,15 +1824,17 @@ ComputationBuilder::CreateDefaultConvDimensionNumbers(int num_spatial_dims) { /* static */ StatusOr ComputationBuilder::CreateConvDimensionNumbers( - int64 batch, int64 feature, int64 first_spatial, int64 second_spatial, + int64 input_batch, int64 input_feature, int64 output_batch, + int64 output_feature, int64 first_spatial, int64 second_spatial, int64 kernel_output_feature, int64 kernel_input_feature, int64 kernel_first_spatial, int64 kernel_second_spatial) { - if (std::set({batch, feature, first_spatial, second_spatial}).size() != - 4) { + if (std::set( + {input_batch, input_feature, first_spatial, second_spatial}) + .size() != 4) { return FailedPrecondition( "dimension numbers for the input are not unique: (%lld, %lld, %lld, " "%lld)", - batch, feature, first_spatial, second_spatial); + input_batch, input_feature, first_spatial, second_spatial); } if (std::set({kernel_output_feature, kernel_input_feature, kernel_first_spatial, kernel_second_spatial}) @@ -1769,9 +1845,19 @@ ComputationBuilder::CreateConvDimensionNumbers( kernel_output_feature, kernel_input_feature, kernel_first_spatial, kernel_second_spatial); } + if (std::set( + {output_batch, output_feature, first_spatial, second_spatial}) + .size() != 4) { + return FailedPrecondition( + "dimension numbers for the output are not unique: (%lld, %lld, %lld, " + "%lld)", + output_batch, output_feature, first_spatial, second_spatial); + } ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_batch_dimension(batch); - dimension_numbers.set_feature_dimension(feature); + dimension_numbers.set_input_batch_dimension(input_batch); + dimension_numbers.set_input_feature_dimension(input_feature); + dimension_numbers.set_output_batch_dimension(output_batch); + dimension_numbers.set_output_feature_dimension(output_feature); dimension_numbers.add_spatial_dimensions(first_spatial); dimension_numbers.add_spatial_dimensions(second_spatial); dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature); diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index b0e6720be2e542dd878e026b603e8570e1882b76..8e1b4be1f3ebf8e3f530b053447f86f7a2f56fa7 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -22,6 +22,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/array.h" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/array3d.h" #include "tensorflow/compiler/xla/array4d.h" @@ -42,6 +43,58 @@ limitations under the License. namespace xla { +class ShardingBuilder { + public: + // A shaped array used to describe the assignment of tiles to devices. + using TileAssignment = Array; + + // Creates a replicated sharding - replicate a tensor on every device. + static OpSharding Replicate() { + OpSharding result; + result.set_type(OpSharding::Type::OpSharding_Type_REPLICATED); + return result; + } + // Creates a sharding that assigns a tensor to just one device. + static OpSharding AssignDevice(int device) { + OpSharding result; + result.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); + result.add_tile_assignment_dimensions(1); + result.add_tile_assignment_devices(device); + return result; + } + // Creates a tiled sharding with the given tile shape and assignment of tiles + // to devices. + static OpSharding Tile(Shape tile_shape, + const TileAssignment& tile_assignment) { + OpSharding result; + result.set_type(OpSharding::Type::OpSharding_Type_OTHER); + for (int64 dim : tile_assignment.dimensions()) { + result.add_tile_assignment_dimensions(dim); + } + for (uint32 device : tile_assignment) { + result.add_tile_assignment_devices(device); + } + return result; + } + // Creates a sharding in one dimension, with the given tile shape which must + // be rank 1 and using devices 0..num_tiles. + static OpSharding Tile1D(Shape tile_shape, int64 num_tiles) { + OpSharding result; + result.set_type(OpSharding::Type::OpSharding_Type_OTHER); + + CHECK_EQ(ShapeUtil::Rank(tile_shape), 1); + std::vector dimensions(1, num_tiles); + auto& tile_dimension = (*tile_shape.mutable_dimensions())[0]; + tile_dimension = CeilOfRatio(static_cast(tile_dimension), num_tiles); + *result.mutable_tile_shape() = tile_shape; + result.add_tile_assignment_dimensions(num_tiles); + for (int64 i = 0; i < num_tiles; ++i) { + result.add_tile_assignment_devices(i); + } + return result; + } +}; + // Wraps an XLA client with a convenient interface for building up // computations. Any errors encountered in building up the computation are // deferred from being handled until Build() is called. @@ -76,13 +129,17 @@ class ComputationBuilder { metadata_.Clear(); } - // Sets an OpDeviceAssignment that will be attached to all instructions - // until cleared. - void SetDeviceAssignment(const OpDeviceAssignment& assignment); + // Sets an OpSharding that will be attached to all instructions until cleared. + void SetSharding(const OpSharding& sharding) { sharding_ = sharding; } - // Clears the device assignment. Ops will be placed according to the default - // placement policy. - void ClearDeviceAssignment(); + // Clears the sharding. Ops will be sharded according to the default placement + // policy. + void ClearSharding() { sharding_ = tensorflow::gtl::nullopt; } + + // Returns the OpSharding that will be attached to all instructions. + const tensorflow::gtl::optional& sharding() const { + return sharding_; + } // Sets the builder to a mode where it will die immediately when an error is // encountered, rather than producing it in a deferred fashion when Build() is @@ -138,6 +195,11 @@ class ComputationBuilder { ComputationDataHandle ConstantR2( std::initializer_list> values); template + ComputationDataHandle ConstantFromArrayWithLayout( + const Array& values, const Layout& layout); + template + ComputationDataHandle ConstantFromArray(const Array& values); + template ComputationDataHandle ConstantR2FromArray2DWithLayout( const Array2D& values, const Layout& layout); template @@ -201,6 +263,16 @@ class ComputationBuilder { // {x=1024, y=32} by collapsing dims {0, 1, 2}. Collapsing dimensions must // be a consecutive, in-order subsequence of the operand dimensions. // + // Note that collapsing a single dimension does nothing: + // + // {256} collapsing {0} => {256} + // {1} collapsing {0} => {1} + // + // Collapsing multiple dimensions produces a single result dimension: + // + // {256, 2} collapsing {0,1} => {512} + // {256, 2, 3} collapsing {0,1} => {512, 3} + // // This could potentially cause data to be moved -- it provides a more // structured form of reshaping than an arbitrary Reshape operation. ComputationDataHandle Collapse(const ComputationDataHandle& operand, @@ -344,7 +416,8 @@ class ComputationBuilder { // Creates a ConvolutionDimensionNumbers with the given arguments. Returns an // error if either the input or the weight dimension numbers have conflicts. static StatusOr CreateConvDimensionNumbers( - int64 batch, int64 feature, int64 first_spatial, int64 second_spatial, + int64 input_batch, int64 input_feature, int64 output_batch, + int64 output_feature, int64 first_spatial, int64 second_spatial, int64 kernel_output_feature, int64 kernel_input_feature, int64 kernel_first_spatial, int64 kernel_second_spatial); @@ -415,6 +488,14 @@ class ComputationBuilder { // of the operands is a scalar, or an explicit broadcast dimension is given // (see g3doc for more details). + // Enqueues a complex compose instruction onto the computation. + ComputationDataHandle Complex( + const ComputationDataHandle& real, const ComputationDataHandle& imag, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + + // Enqueues a complex conjugate instruction onto the computation. + ComputationDataHandle Conj(const ComputationDataHandle& operand); + // Enqueues an add instruction onto the computation. ComputationDataHandle Add( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, @@ -451,15 +532,25 @@ class ComputationBuilder { tensorflow::gtl::ArraySlice broadcast_dimensions = {}); // Element-wise logical operators - ComputationDataHandle LogicalAnd( + ComputationDataHandle And( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - ComputationDataHandle LogicalOr( + ComputationDataHandle Or( const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, tensorflow::gtl::ArraySlice broadcast_dimensions = {}); - ComputationDataHandle LogicalNot(const ComputationDataHandle& lhs); + ComputationDataHandle Not(const ComputationDataHandle& operand); + + ComputationDataHandle ShiftLeft( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + ComputationDataHandle ShiftRightArithmetic( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + ComputationDataHandle ShiftRightLogical( + const ComputationDataHandle& lhs, const ComputationDataHandle& rhs, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); // Reduces an array among the provided dimensions, given "computation" as a // reduction operator. @@ -516,6 +607,11 @@ class ComputationBuilder { // Enqueues an abs instruction onto the computation. ComputationDataHandle Abs(const ComputationDataHandle& operand); + // Enqueues a atan2 instruction onto the computation. + ComputationDataHandle Atan2( + const ComputationDataHandle& y, const ComputationDataHandle& x, + tensorflow::gtl::ArraySlice broadcast_dimensions = {}); + // Enqueues an exp instruction onto the computation. ComputationDataHandle Exp(const ComputationDataHandle& operand); @@ -544,6 +640,12 @@ class ComputationBuilder { // Enqueues a tanh instruction onto the computation. ComputationDataHandle Tanh(const ComputationDataHandle& operand); + // Enqueues a real-part instruction onto the computation. + ComputationDataHandle Real(const ComputationDataHandle& operand); + + // Enqueues an imaginary-part instruction onto the computation. + ComputationDataHandle Imag(const ComputationDataHandle& operand); + // Enqueues a float32 sqrt instruction onto the computation. // (float32 is specified as there is an implicit float32 0.5f constant // exponent). @@ -604,6 +706,7 @@ class ComputationBuilder { ComputationDataHandle Map( tensorflow::gtl::ArraySlice operands, const Computation& computation, + tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice static_operands = {}); // Enqueues a N(mu, sigma) random number generation instruction onto the @@ -643,11 +746,12 @@ class ComputationBuilder { ComputationDataHandle Recv(const Shape& shape, const ChannelHandle& handle); // Returns true if 'operand' is a compile-time constant. A compile-time - // constant does not depend on parameters, or on stateful operators such - // as `RngNormal` or `Infeed`. Unlike `ComputeConstant`, `IsConstant` tests - // whether a computation is a compile-time constant without evaluating the - // computation. - StatusOr IsConstant(const ComputationDataHandle& operand); + // constant does not depend on parameters with higher index then + // `num_parameters`, or on stateful operators such as `RngNormal` or `Infeed`. + // Unlike `ComputeConstant`, `IsConstant` tests whether a computation is a + // compile-time constant without evaluating the computation. + StatusOr IsConstant(const ComputationDataHandle& operand, + int64 num_parameters = 0); // Normalizes operand across spatial and batch dimensions for each feature. // @@ -692,7 +796,7 @@ class ComputationBuilder { float epsilon, int64 feature_index); // Computes the value of a constant indicated by a - // ComputationDataHandle. + // ComputationDataHandle using a non-optimized interpreter on the host. // // The operand must be from the computation currently being built - // i.e., returned from this builder with no intervening call to @@ -700,8 +804,11 @@ class ComputationBuilder { // that may stop working at any time. // // The operand must represent a constant value, which in this case - // means that it must not statically depend on a parameter to the - // computation that is being built. + // means that it must not statically depend on any parameter of the + // computation that is being built other then the ones specified on the + // paramtere list. The parameters in the list will be indexed by their + // parameter id property so the number of parameters specified should be at + // least as many as the largest used parameter index. // // `IsConstant` can be used to test whether a computation is a compile-time // constant without evaluation it. `ComputeConstant` only succeeds for @@ -719,7 +826,8 @@ class ComputationBuilder { // will be stored using that layout. StatusOr> ComputeConstant( const ComputationDataHandle& operand, - const Layout* output_layout = nullptr); + const Layout* output_layout = nullptr, + tensorflow::gtl::ArraySlice parameters = {}); // Returns a new ComputationBuilder whose resultant Computation is used only // by this ComputationBuilder. The sub-ComputationBuilder has the same @@ -848,8 +956,9 @@ class ComputationBuilder { // throughout the TensorFlow op kernel implementations). OpMetadata metadata_; - // Device assignment for the operator. - OpDeviceAssignment device_assignment_; + // Sharding for this operator. This is structured as a "model"-like operation, + // in order to simplify client code, similar to metadata_. + tensorflow::gtl::optional sharding_; TF_DISALLOW_COPY_AND_ASSIGN(ComputationBuilder); }; @@ -888,50 +997,83 @@ ComputationDataHandle ComputationBuilder::ConstantR2( } template -ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { +ComputationDataHandle ComputationBuilder::ConstantFromArrayWithLayout( + const Array& values, const Layout& layout) { return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateR2FromArray2DWithLayout(values, layout); + literal->PopulateFromArrayWithLayout(values, layout); }); } +template +ComputationDataHandle ComputationBuilder::ConstantFromArray( + const Array& values) { + return ConstantOp( + [&values](Literal* literal) { literal->PopulateFromArray(values); }); +} + +template +ComputationDataHandle ComputationBuilder::ConstantR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { + return ConstantFromArrayWithLayout(values, layout); +} + template ComputationDataHandle ComputationBuilder::ConstantR2FromArray2D( const Array2D& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR2FromArray2D(values); }); + return ConstantFromArray(values); } template ComputationDataHandle ComputationBuilder::ConstantR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { - return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateR3FromArray3DWithLayout(values, layout); - }); + return ConstantFromArrayWithLayout(values, layout); } template ComputationDataHandle ComputationBuilder::ConstantR3FromArray3D( const Array3D& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR3FromArray3D(values); }); + return ConstantFromArray(values); } template ComputationDataHandle ComputationBuilder::ConstantR4FromArray4DWithLayout( const Array4D& values, const Layout& layout) { - return ConstantOp([&values, &layout](Literal* literal) { - literal->PopulateR4FromArray4DWithLayout(values, layout); - }); + return ConstantFromArrayWithLayout(values, layout); } template ComputationDataHandle ComputationBuilder::ConstantR4FromArray4D( const Array4D& values) { - return ConstantOp( - [&values](Literal* literal) { literal->PopulateR4FromArray4D(values); }); + return ConstantFromArray(values); } +// RAII-style object: sets the current sharding assignment in builder on +// construction, and sets back to the previous assignment on destruction. +class ScopedShardingAssignment { + public: + ScopedShardingAssignment(xla::ComputationBuilder* builder, + tensorflow::gtl::optional sharding) + : builder_(builder), prev_sharding_(builder->sharding()) { + SetSharding(sharding); + } + + ~ScopedShardingAssignment() { SetSharding(prev_sharding_); } + + private: + void SetSharding(const tensorflow::gtl::optional& sharding) { + if (sharding.has_value()) { + builder_->SetSharding(sharding.value()); + } else { + builder_->ClearSharding(); + } + } + + xla::ComputationBuilder* const builder_; + tensorflow::gtl::optional prev_sharding_; + + TF_DISALLOW_COPY_AND_ASSIGN(ScopedShardingAssignment); +}; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_CLIENT_COMPUTATION_BUILDER_H_ diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.cc b/tensorflow/compiler/xla/client/lib/arithmetic.cc index 969b0eee1d195a36728f16a598add4b3b850ed60..24048a1e5a782661ba577ba50e3b5b2914f17c0a 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.cc +++ b/tensorflow/compiler/xla/client/lib/arithmetic.cc @@ -89,24 +89,24 @@ Computation CreateScalarMinComputation(PrimitiveType type, const ComputationDataHandle& rhs) { return b->Min(lhs, rhs); }); } -Computation CreateScalarLogicalAndComputation(ComputationBuilder* builder) { +Computation CreateScalarAndComputation(ComputationBuilder* builder) { return CreateScalarComputation( - "logical_and", PRED, builder, + "and", PRED, builder, [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->LogicalAnd(lhs, rhs); }); + const ComputationDataHandle& rhs) { return b->And(lhs, rhs); }); } -Computation CreateScalarLogicalOrComputation(ComputationBuilder* builder) { +Computation CreateScalarOrComputation(ComputationBuilder* builder) { return CreateScalarComputation( - "logical_or", PRED, builder, + "or", PRED, builder, [](ComputationBuilder* b, const ComputationDataHandle& lhs, - const ComputationDataHandle& rhs) { return b->LogicalOr(lhs, rhs); }); + const ComputationDataHandle& rhs) { return b->Or(lhs, rhs); }); } StatusOr Any(const ComputationDataHandle& predicates, ComputationBuilder* builder) { auto f = builder->ConstantR0(false); - Computation logical_or = CreateScalarLogicalOrComputation(builder); + Computation logical_or = CreateScalarOrComputation(builder); TF_ASSIGN_OR_RETURN(std::unique_ptr predicates_shape, builder->GetShape(predicates)); std::vector all_dimensions(ShapeUtil::Rank(*predicates_shape)); diff --git a/tensorflow/compiler/xla/client/lib/arithmetic.h b/tensorflow/compiler/xla/client/lib/arithmetic.h index f43d35fe4a52016d4054af28835d6b66a35217d4..ae89784bc227d837cf15f0a89687dd00dccc2745 100644 --- a/tensorflow/compiler/xla/client/lib/arithmetic.h +++ b/tensorflow/compiler/xla/client/lib/arithmetic.h @@ -45,10 +45,10 @@ Computation CreateScalarMinComputation(PrimitiveType type, ComputationBuilder* builder); // Creates a scalar logical AND computation and returns it. -Computation CreateScalarLogicalAndComputation(ComputationBuilder* builder); +Computation CreateScalarAndComputation(ComputationBuilder* builder); // Creates a scalar logical OR computation and returns it. -Computation CreateScalarLogicalOrComputation(ComputationBuilder* builder); +Computation CreateScalarOrComputation(ComputationBuilder* builder); // Returns whether any predicate in "predicates" is set. // diff --git a/tensorflow/compiler/xla/client/lib/testing.cc b/tensorflow/compiler/xla/client/lib/testing.cc index 482d53cf330f152f496b77233714f93991fef6f0..e6645e4941bd04c658b67117bb689f6fdef7dfc1 100644 --- a/tensorflow/compiler/xla/client/lib/testing.cc +++ b/tensorflow/compiler/xla/client/lib/testing.cc @@ -79,6 +79,24 @@ StatusOr> MakeFakeLiteral(const Shape& shape) { })); break; } + case S64: { + std::uniform_int_distribution generator( + std::numeric_limits::lowest(), + std::numeric_limits::max()); + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return generator(engine); + })); + break; + } + case PRED: { + std::uniform_int_distribution generator(0, 1); + TF_CHECK_OK(literal->Populate( + [&](tensorflow::gtl::ArraySlice /*indices*/) { + return generator(engine); + })); + break; + } default: return Unimplemented("Unsupported type for fake literal generation: %s", ShapeUtil::HumanString(shape).c_str()); diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc index a0fc230319374afc6746e297e817363291db672b..15c744ecd349e91dc703bec5708d78a896f132c3 100644 --- a/tensorflow/compiler/xla/client/local_client.cc +++ b/tensorflow/compiler/xla/client/local_client.cc @@ -169,16 +169,21 @@ tensorflow::Status LocalExecutable::ValidateExecutionOptions( return Status::OK(); } -StatusOr> LocalExecutable::Run( +StatusOr> LocalExecutable::Run( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& options) { TF_RETURN_IF_ERROR(ValidateExecutionOptions(arguments, options, *backend_)); ExecutableRunOptions actual_options = options; + + Backend::StreamPtr stream; if (options.stream() == nullptr) { + // NB! The lifetime of `stream` needs to match the lifetime of + // `actual_options` (otherwise we will end up using a returned stream in + // ExecuteOnStreamWrapper), which is why it isn't declared in the inner "if" + // scope. TF_ASSIGN_OR_RETURN( - Backend::StreamPtr stream, - BorrowStreamForDevice(options.device_ordinal(), backend_)); + stream, BorrowStreamForDevice(options.device_ordinal(), backend_)); actual_options.set_stream(stream.get()); } if (options.allocator() == nullptr) { @@ -197,11 +202,15 @@ StatusOr> LocalExecutable::Run( if (executable_->dumping()) { return ExecuteAndDump(&service_options, arguments); } - return executable_->ExecuteOnStreamWrapper>( - &service_options, options.execution_profile(), arguments); + TF_ASSIGN_OR_RETURN( + std::unique_ptr result, + executable_->ExecuteOnStreamWrapper>( + &service_options, options.execution_profile(), arguments)); + return ScopedShapedBuffer::MakeScoped(result.get(), + actual_options.allocator()); } -StatusOr> LocalExecutable::ExecuteAndDump( +StatusOr> LocalExecutable::ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice arguments) { executable_->session_module()->set_execution_platform( @@ -213,7 +222,7 @@ StatusOr> LocalExecutable::ExecuteAndDump( /*hlo_execution_profile=*/nullptr)); TF_RETURN_IF_ERROR(RecordResult(result.get(), executable_->session_module())); TF_RETURN_IF_ERROR(executable_->DumpSessionModule()); - return std::move(result); + return ScopedShapedBuffer::MakeScoped(result.get(), run_options->allocator()); } tensorflow::Status LocalExecutable::RecordArguments( @@ -279,11 +288,10 @@ StatusOr> LocalClient::Compile( int device_ordinal = options.device_ordinal() == -1 ? default_device_ordinal() : options.device_ordinal(); - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - local_service_->CompileExecutable(computation.handle(), argument_layouts, - options.result_layout(), device_ordinal, - options.has_hybrid_result())); + TF_ASSIGN_OR_RETURN(std::unique_ptr executable, + local_service_->CompileExecutable( + computation.handle(), argument_layouts, + options.result_layout(), device_ordinal)); return WrapUnique(new LocalExecutable(std::move(executable), local_service_->mutable_backend(), device_ordinal, options)); @@ -293,12 +301,14 @@ StatusOr> LocalClient::Compile( // ScopedShapedBuffer. The given memory allocator is used for device memory // allocation. StatusOr> -LocalClient::LiteralToShapedBuffer(const Literal& literal, - DeviceMemoryAllocator* allocator, - int device_ordinal) { - TF_ASSIGN_OR_RETURN(auto scoped_buffer, - ScopedShapedBuffer::MakeScopedShapedBuffer( - literal.shape(), allocator, device_ordinal)); +LocalClient::LiteralToShapedBuffer(const Literal& literal, int device_ordinal, + DeviceMemoryAllocator* allocator) { + if (allocator == nullptr) { + allocator = backend().memory_allocator(); + } + TF_ASSIGN_OR_RETURN( + auto scoped_buffer, + ScopedShapedBuffer::Allocate(literal.shape(), allocator, device_ordinal)); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, backend().stream_executor(device_ordinal)); TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus( diff --git a/tensorflow/compiler/xla/client/local_client.h b/tensorflow/compiler/xla/client/local_client.h index e98384238a838461ee37d0e14d9a80fd2be7c33b..9f985ed5275815de2d59f6caedbbcc8060420a13 100644 --- a/tensorflow/compiler/xla/client/local_client.h +++ b/tensorflow/compiler/xla/client/local_client.h @@ -79,7 +79,7 @@ class LocalExecutable { public: // Run the compiled computation with the given arguments and options and // return the result. - StatusOr> Run( + StatusOr> Run( const tensorflow::gtl::ArraySlice arguments, const ExecutableRunOptions& options); @@ -115,7 +115,7 @@ class LocalExecutable { // Records the computation in a SessionModule proto with the arguments used to // invoke it, and the result. Enabled by flag: --tla_dump_executions_to. - StatusOr> ExecuteAndDump( + StatusOr> ExecuteAndDump( const ServiceExecutableRunOptions* run_options, const tensorflow::gtl::ArraySlice arguments); @@ -166,11 +166,12 @@ class LocalClient : public Client { const ExecutableBuildOptions& options); // Copy the literal data to the device with the given ordinal and return as a - // ScopedShapedBuffer. The given memory allocator is used for device memory - // allocation. + // ScopedShapedBuffer. If non-null the given memory allocator is used for + // device memory allocation. If null, the default memory allocator for the + // device is used. StatusOr> LiteralToShapedBuffer( - const Literal& literal, DeviceMemoryAllocator* allocator, - int device_ordinal); + const Literal& literal, int device_ordinal, + DeviceMemoryAllocator* allocator = nullptr); // Copy the data from the device contained in the given ShapedBuffer and // return as a Literal. diff --git a/tensorflow/compiler/xla/client/padding.cc b/tensorflow/compiler/xla/client/padding.cc index 0b18d8946a2e62a810f875b4d79fd5375e787487..6a9cf466ac0a43ce214ef0e6aae9e6295f137b0f 100644 --- a/tensorflow/compiler/xla/client/padding.cc +++ b/tensorflow/compiler/xla/client/padding.cc @@ -17,17 +17,34 @@ limitations under the License. #include +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/math/math_util.h" #include "tensorflow/core/platform/logging.h" namespace xla { +Status ValidatePaddingValues( + tensorflow::gtl::ArraySlice input_dimensions, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides) { + bool ok = input_dimensions.size() == window_dimensions.size() && + input_dimensions.size() == window_strides.size(); + if (!ok) { + return InvalidArgument( + "Want input dimensions size %zu = window dimensions size %zu = window " + "strides size %zu", + input_dimensions.size(), window_dimensions.size(), + window_strides.size()); + } + return Status::OK(); +} + std::vector> MakePadding( tensorflow::gtl::ArraySlice input_dimensions, tensorflow::gtl::ArraySlice window_dimensions, tensorflow::gtl::ArraySlice window_strides, Padding padding) { - CHECK_EQ(input_dimensions.size(), window_dimensions.size()); - CHECK_EQ(input_dimensions.size(), window_strides.size()); + TF_CHECK_OK(ValidatePaddingValues(input_dimensions, window_dimensions, + window_strides)); std::vector> low_high_padding; switch (padding) { case Padding::kValid: diff --git a/tensorflow/compiler/xla/client/padding.h b/tensorflow/compiler/xla/client/padding.h index dce2d87e8da8b3d9fd138a712c459ea0081372e0..e23b0b3a90a091bf80973525810793c3eda4a036 100644 --- a/tensorflow/compiler/xla/client/padding.h +++ b/tensorflow/compiler/xla/client/padding.h @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -37,6 +38,14 @@ enum class Padding { kValid, }; +// Validates that the slices are acceptable for determining padding -- this can +// be used to check the preconditions of MakePadding below to produce an error +// message that can be returned to the user. +Status ValidatePaddingValues( + tensorflow::gtl::ArraySlice input_dimensions, + tensorflow::gtl::ArraySlice window_dimensions, + tensorflow::gtl::ArraySlice window_strides); + // Returns the padding needed for the base area, given the base area dimensions, // window dimensions, strides, and the type of padding. // @@ -51,7 +60,7 @@ enum class Padding { std::vector> MakePadding( tensorflow::gtl::ArraySlice input_dimensions, tensorflow::gtl::ArraySlice window_dimensions, - tensorflow::gtl::ArraySlice strides, Padding padding); + tensorflow::gtl::ArraySlice window_strides, Padding padding); } // namespace xla diff --git a/tensorflow/compiler/xla/iterator_util.h b/tensorflow/compiler/xla/iterator_util.h new file mode 100644 index 0000000000000000000000000000000000000000..a39999705eddc5728dce028dab64b7358395757e --- /dev/null +++ b/tensorflow/compiler/xla/iterator_util.h @@ -0,0 +1,98 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ + +#include +#include + +namespace xla { + +// UnwrappingIterator is a transforming iterator that calls get() on the +// elements it returns. +// +// Together with tensorflow::gtl::iterator_range, this lets classes which +// contain a collection of smart pointers expose a view of raw pointers to +// consumers. For example: +// +// class MyContainer { +// public: +// tensorflow::gtl::iterator_range< +// UnwrappingIterator>::iterator>> +// things() { +// return {MakeUnwrappingIterator(things_.begin()), +// MakeUnwrappingIterator(things_.end())}; +// } +// +// tensorflow::gtl::iterator_range>::const_iterator>> +// things() const { +// return {MakeUnwrappingIterator(things_.begin()), +// MakeUnwrappingIterator(things_.end())}; +// } +// +// private: +// std::vector> things_; +// }; +// +// MyContainer container = ...; +// for (Thing* t : container.things()) { +// ... +// } +// +// For simplicity, UnwrappingIterator is currently unconditionally an +// input_iterator -- it doesn't inherit any superpowers NestedIterator may have. +template +class UnwrappingIterator + : public std::iterator()->get())> { + private: + NestedIter iter_; + + public: + explicit UnwrappingIterator(NestedIter iter) : iter_(std::move(iter)) {} + + auto operator*() -> decltype(iter_->get()) { return iter_->get(); } + auto operator-> () -> decltype(iter_->get()) { return iter_->get(); } + UnwrappingIterator& operator++() { + ++iter_; + return *this; + } + UnwrappingIterator operator++(int) { + UnwrappingIterator temp(iter_); + operator++(); + return temp; + } + + friend bool operator==(const UnwrappingIterator& a, + const UnwrappingIterator& b) { + return a.iter_ == b.iter_; + } + + friend bool operator!=(const UnwrappingIterator& a, + const UnwrappingIterator& b) { + return !(a == b); + } +}; + +template +UnwrappingIterator MakeUnwrappingIterator(NestedIter iter) { + return UnwrappingIterator(std::move(iter)); +} + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_ITERATOR_UTIL_H_ diff --git a/tensorflow/compiler/xla/iterator_util_test.cc b/tensorflow/compiler/xla/iterator_util_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7bc3189507ec5233c6983eb26cfb07dc9bfadd52 --- /dev/null +++ b/tensorflow/compiler/xla/iterator_util_test.cc @@ -0,0 +1,62 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/iterator_util.h" + +#include +#include + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/test.h" + +namespace xla { +namespace { + +TEST(UnwrappingIteratorTest, Simple) { + std::vector> v; + for (int i = 0; i < 3; ++i) { + v.push_back(MakeUnique(i)); + } + int i = 0; + for (auto iter = MakeUnwrappingIterator(v.begin()); + iter != MakeUnwrappingIterator(v.end()); ++iter) { + EXPECT_EQ(*iter, v[i].get()); + ++i; + } +} + +TEST(UnwrappingIteratorTest, PostincrementOperator) { + std::vector> v; + for (int i = 0; i < 3; ++i) { + v.push_back(std::make_shared(i)); + } + auto iter = MakeUnwrappingIterator(v.begin()); + EXPECT_EQ(*(iter++), v[0].get()); + EXPECT_EQ(*iter, v[1].get()); +} + +// std::find relies on various iterator traits being properly defined. +TEST(UnwrappingIteratorTest, StdFind) { + std::list> l; + for (int i = 0; i < 3; ++i) { + l.push_back(MakeUnique(i)); + } + EXPECT_EQ(l.begin()->get(), + *std::find(MakeUnwrappingIterator(l.begin()), + MakeUnwrappingIterator(l.end()), l.begin()->get())); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc index 011fc3c194e0eb9ebd6b9e42571deddaf25c09ff..5c2cc2a7a99cc51ded3d98c9dd5903e4b3078548 100644 --- a/tensorflow/compiler/xla/layout_util.cc +++ b/tensorflow/compiler/xla/layout_util.cc @@ -83,6 +83,10 @@ Layout CreateDefaultLayoutForRank(int64 rank) { return CreateDefaultLayoutForRank(shape.dimensions_size()); } +/* static */ Layout LayoutUtil::GetDefaultLayoutForRank(int64 rank) { + return CreateDefaultLayoutForRank(rank); +} + /* static */ Layout LayoutUtil::GetDefaultLayoutForR2() { return CreateDefaultLayoutForRank(2); } diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h index 5de0a653f66688ac75fc377c18ff93012314abdd..bc42e222292933be35e82d1fe50802e8830d16b3 100644 --- a/tensorflow/compiler/xla/layout_util.h +++ b/tensorflow/compiler/xla/layout_util.h @@ -40,6 +40,7 @@ class LayoutUtil { static Layout GetDefaultLayoutForShape(const Shape& shape); // Helper functions that create default layouts for various ranks. + static Layout GetDefaultLayoutForRank(int64 rank); static Layout GetDefaultLayoutForR2(); static Layout GetDefaultLayoutForR3(); static Layout GetDefaultLayoutForR4(); diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc index 8892bfbe929d168c602af24cfbb507256dc05328..f2cdd9669c727bb778fce495ede0faaf2d9a923d 100644 --- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc +++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc @@ -206,9 +206,9 @@ void AllocateFlags() { flag_values->xla_gpu_disable_multi_streaming(), "If true, multi-streaming in the GPU backend is disabled."), tensorflow::Flag( - "xla_dump_debug_json_to", - flag_values->mutable_xla_dump_debug_json_to(), - "Dump compilation artifacts as JSON into this directory."), + "xla_dump_hlo_proto_to", + flag_values->mutable_xla_dump_hlo_proto_to(), + "Dump compilation artifacts as proto binary into this directory."), tensorflow::Flag( "xla_test_all_output_layouts", bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts), diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc index 79e40c12625c41b7234542381d0ca528be7eaed4..fda791401d567b694b3d2cabf129141a7ff2ddb2 100644 --- a/tensorflow/compiler/xla/literal_util.cc +++ b/tensorflow/compiler/xla/literal_util.cc @@ -173,6 +173,8 @@ Status Literal::Copy(const Literal& src_literal, return CopyRange(src_literal, src_base, dest_base, copy_size); case F64: return CopyRange(src_literal, src_base, dest_base, copy_size); + case C64: + return CopyRange(src_literal, src_base, dest_base, copy_size); case PRED: return CopyRange(src_literal, src_base, dest_base, copy_size); default: @@ -202,6 +204,8 @@ Status Literal::Copy(const Literal& src_literal, return *Literal::CreateR0(0); case F64: return *Literal::CreateR0(0); + case C64: + return *Literal::CreateR0(0); case PRED: return *Literal::CreateR0(false); case S16: @@ -234,6 +238,8 @@ Status Literal::Copy(const Literal& src_literal, return *Literal::CreateR0(1); case F64: return *Literal::CreateR0(1); + case C64: + return *Literal::CreateR0(1); case PRED: return *Literal::CreateR0(true); case S16: @@ -269,6 +275,8 @@ Status Literal::Copy(const Literal& src_literal, case F64: return *Literal::CreateR0( -std::numeric_limits::infinity()); + case C64: + LOG(FATAL) << "C64 element type has no minimum value"; case PRED: return *Literal::CreateR0(false); case S16: @@ -522,6 +530,10 @@ string Literal::GetAsString( return tensorflow::strings::StrCat(Get(multi_index)); case F64: return tensorflow::strings::StrCat(Get(multi_index)); + case C64: { + complex64 c = Get(multi_index); + return tensorflow::strings::StrCat("(", c.real(), ", ", c.imag(), ")"); + } case F16: return tensorflow::strings::StrCat(Get(multi_index)); default: @@ -575,11 +587,11 @@ string Literal::ToString() const { if (ShapeUtil::IsTuple(shape())) { pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" (\n"); - for (const auto& element_literal : tuple_literals()) { - pieces.push_back(element_literal.ToString()); - pieces.push_back(",\n"); - } - pieces.push_back(")"); + pieces.push_back(tensorflow::str_util::Join( + tuple_literals(), ",\n", [](string* out, const Literal& element) { + tensorflow::strings::StrAppend(out, element.ToString()); + })); + pieces.push_back("\n)"); } else if (ShapeUtil::Rank(shape()) == 0) { pieces.push_back(GetAsString({})); } else if (ShapeUtil::Rank(shape()) == 1) { @@ -597,7 +609,7 @@ string Literal::ToString() const { pieces.push_back(element_to_string({i0, i1})); } pieces.push_back(" "); - pieces.push_back("},\n"); + pieces.push_back(i0 == shape().dimensions(0) - 1 ? "}\n" : "},\n"); } pieces.push_back("}"); } else if (ShapeUtil::Rank(shape()) == 3) { @@ -619,45 +631,48 @@ string Literal::ToString() const { pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { - pieces.push_back(tensorflow::strings::Printf(" { // i0=%lld\n", i0)); + pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { pieces.push_back( - tensorflow::strings::Printf(" { // i1=%lld\n", i1)); + tensorflow::strings::Printf(" { /*i1=%lld*/\n", i1)); for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { pieces.push_back(" {"); for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) { pieces.push_back(element_to_string({i0, i1, i2, i3})); } - pieces.push_back("},\n"); + pieces.push_back(i2 == shape().dimensions(2) - 1 ? "}\n" : "},\n"); } - pieces.push_back(" },\n"); + pieces.push_back(i1 == shape().dimensions(1) - 1 ? " }\n" + : " },\n"); } - pieces.push_back(" },\n"); + pieces.push_back(i0 == shape().dimensions(0) - 1 ? " }\n" : " },\n"); } pieces.push_back("}"); } else if (ShapeUtil::Rank(shape()) == 5) { pieces.push_back(ShapeUtil::HumanString(shape())); pieces.push_back(" {\n"); for (int64 i0 = 0; i0 < shape().dimensions(0); ++i0) { - pieces.push_back(tensorflow::strings::Printf(" { // i0=%lld\n", i0)); + pieces.push_back(tensorflow::strings::Printf(" { /*i0=%lld*/\n", i0)); for (int64 i1 = 0; i1 < shape().dimensions(1); ++i1) { pieces.push_back( - tensorflow::strings::Printf(" { // i1=%lld\n", i1)); + tensorflow::strings::Printf(" { /*i1=%lld*/\n", i1)); for (int64 i2 = 0; i2 < shape().dimensions(2); ++i2) { pieces.push_back( - tensorflow::strings::Printf(" { // i2=%lld\n", i2)); + tensorflow::strings::Printf(" { /*i2=%lld*/\n", i2)); for (int64 i3 = 0; i3 < shape().dimensions(3); ++i3) { pieces.push_back(" {"); for (int64 i4 = 0; i4 < shape().dimensions(4); ++i4) { pieces.push_back(element_to_string({i0, i1, i2, i3, i4})); } - pieces.push_back("},\n"); + pieces.push_back(i3 == shape().dimensions(3) - 1 ? "}\n" : "},\n"); } - pieces.push_back(" },\n"); + pieces.push_back(i2 == shape().dimensions(2) - 1 ? " }\n" + : " },\n"); } - pieces.push_back(" },\n"); + pieces.push_back(i1 == shape().dimensions(1) - 1 ? " }\n" + : " },\n"); } - pieces.push_back(" },\n"); + pieces.push_back(i0 == shape().dimensions(0) - 1 ? " }\n" : " },\n"); } pieces.push_back("}"); } else { @@ -716,6 +731,8 @@ void* Literal::MutableInternalData() { return reinterpret_cast(f32s_.data()); case F64: return reinterpret_cast(f64s_.data()); + case C64: + return reinterpret_cast(c64s_.data()); case F16: return reinterpret_cast(f16s_.data()); default: @@ -754,6 +771,9 @@ void Literal::Reserve(int64 num_elements) { case F64: Resize(num_elements, 0); break; + case C64: + Resize(num_elements, 0); + break; case F16: Resize(num_elements, static_cast(0.0f)); break; @@ -790,6 +810,9 @@ tensorflow::Status Literal::ValidateLiteral() const { case F64: actual = f64s_size(); break; + case C64: + actual = c64s_size(); + break; case F16: actual = f16s().size() / sizeof(half); break; @@ -843,6 +866,26 @@ std::unique_ptr ConvertBetweenNativeTypes(const Literal& src_literal) { return result_literal; } +template +std::unique_ptr ConvertToC64(const Literal& src_literal) { + auto result_literal = MakeUnique(); + Shape* result_shape = result_literal->mutable_shape(); + *result_shape = src_literal.shape(); + result_shape->set_element_type(C64); + result_literal->Reserve(ShapeUtil::ElementsIn(*result_shape)); + using NativeSrcT = + typename primitive_util::PrimitiveTypeToNative::type; + tensorflow::gtl::ArraySlice src_data = + src_literal.GetArraySlice(); + tensorflow::gtl::MutableArraySlice dest_data = + result_literal->GetMutableArraySlice(); + int64 num_elements = ShapeUtil::ElementsIn(src_literal.shape()); + for (int64 i = 0; i < num_elements; ++i) { + dest_data[i] = complex64(static_cast(src_data[i]), 0); + } + return result_literal; +} + template std::unique_ptr ConvertIfTypesMatch(const Literal& src_literal) { CHECK_EQ(primitive_src_type, src_literal.shape().element_type()); @@ -870,6 +913,8 @@ StatusOr> ConvertIfDestTypeMatches( CONVERT_IF_TYPES_MATCH(F32) CONVERT_IF_TYPES_MATCH(F64) #undef CONVERT_IF_TYPES_MATCH + case C64: + return ConvertToC64(src_literal); // Other types are not yet supported. default: return InvalidArgument( @@ -966,6 +1011,8 @@ bool Literal::operator==(const Literal& other) const { return EqualElements(*this, other, 0, &multi_index); case F16: return EqualElements(*this, other, 0, &multi_index); + case C64: + return EqualElements(*this, other, 0, &multi_index); default: LOG(FATAL) << "Unimplemented: Literal::Equal for type " << PrimitiveType_Name(shape().element_type()); @@ -1065,6 +1112,12 @@ tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { values->size()); } +template <> +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { + auto values = mutable_c64s(); + return {values->data(), values->size()}; +} + template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice() { // TODO - there is an endianess problem here. fix it, or wait for uint16 @@ -1144,6 +1197,13 @@ tensorflow::gtl::ArraySlice Literal::GetArraySlice() const { f16s().size() / sizeof(half)); } +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() + const { + CHECK_EQ(shape().element_type(), C64); + return c64s(); +} + template static bool AllElementsEqualValue(const Literal& literal, NativeT value) { for (int64 i = 0; i < ShapeUtil::ElementsIn(literal.shape()); ++i) { @@ -1211,6 +1271,15 @@ bool Literal::IsAllFloat(float value) const { } } +bool Literal::IsAllComplex(complex64 value) const { + switch (shape().element_type()) { + case C64: + return AllElementsEqualValue(*this, value); + default: + return false; + } +} + bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { switch (shape().element_type()) { case U8: @@ -1229,6 +1298,8 @@ bool Literal::IsZero(tensorflow::gtl::ArraySlice indices) const { return Get(indices) == 0.0f; case F64: return Get(indices) == 0.0; + case C64: + return Get(indices) == complex64(0.0f, 0.0f); case F16: return Get(indices) == static_cast(0.0f); case PRED: @@ -1298,12 +1369,27 @@ void Literal::Resize(int64 num_elements, half value) { mutable_f16s()->resize(num_elements, value); } +template <> +void Literal::Resize(int64 num_elements, complex64 value) { + CHECK_EQ(ShapeUtil::ElementsIn(shape()), num_elements); + mutable_c64s()->resize(num_elements, value); +} + template -static void CopyToRepeatedField(RepeatedFieldT* dest, - const std::vector& src) { +void CopyToRepeatedField(RepeatedFieldT* dest, + const std::vector& src) { *dest = RepeatedFieldT(src.begin(), src.end()); } +template <> +void CopyToRepeatedField, complex64>( + tensorflow::protobuf::RepeatedField* dest, + const std::vector& src) { + *dest = tensorflow::protobuf::RepeatedField( + reinterpret_cast(src.data()), + reinterpret_cast(src.data()) + src.size() * 2); +} + LiteralProto Literal::ToProto() const { LiteralProto proto; proto.Clear(); @@ -1338,6 +1424,9 @@ LiteralProto Literal::ToProto() const { case F64: CopyToRepeatedField(proto.mutable_f64s(), f64s()); break; + case C64: + CopyToRepeatedField(proto.mutable_c64s(), c64s()); + break; case TUPLE: for (const auto& tuple : tuple_literals()) { *proto.add_tuple_literals() = tuple.ToProto(); @@ -1351,11 +1440,21 @@ LiteralProto Literal::ToProto() const { } template -static void CopyFromRepeatedField(std::vector* dest, - const RepeatedFieldT& src) { +void CopyFromRepeatedField(std::vector* dest, + const RepeatedFieldT& src) { *dest = std::vector(src.begin(), src.end()); } +template <> +void CopyFromRepeatedField, + complex64>( + std::vector* dest, + const tensorflow::protobuf::RepeatedField& src) { + *dest = std::vector( + reinterpret_cast(src.data()), + reinterpret_cast(src.data()) + src.size() / 2); +} + void Literal::CopyFromProto(const LiteralProto& literal_proto) { if (!literal_proto.has_shape()) { return; @@ -1394,6 +1493,9 @@ void Literal::CopyFromProto(const LiteralProto& literal_proto) { case F64: CopyFromRepeatedField(mutable_f64s(), literal_proto.f64s()); break; + case C64: + CopyFromRepeatedField(mutable_c64s(), literal_proto.c64s()); + break; case TUPLE: for (const auto& proto : literal_proto.tuple_literals()) { mutable_tuple_literals()->push_back(Literal(proto)); diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h index e8cee732d4cf5629c1e2b9c98d1f1bbe1e29a122..a1e288829f22835f94c6e3c041796f84d995211c 100644 --- a/tensorflow/compiler/xla/literal_util.h +++ b/tensorflow/compiler/xla/literal_util.h @@ -159,6 +159,10 @@ class Literal { const std::vector& f64s() const { return f64s_; } std::vector* mutable_f64s() { return &f64s_; } + int c64s_size() const { return c64s().size(); } + const std::vector& c64s() const { return c64s_; } + std::vector* mutable_c64s() { return &c64s_; } + int tuple_literals_size() const { return tuple_literals().size(); } const Literal& tuple_literals(int i) const { return tuple_literals_[i]; } Literal* add_tuple_literals() { @@ -334,6 +338,11 @@ class Literal { // WithLayout use the default XLA layout for the literal's linear // representation in memory. template + static std::unique_ptr CreateFromArray(const Array& values); + template + static std::unique_ptr CreateFromArrayWithLayout( + const Array& values, const Layout& layout); + template static std::unique_ptr CreateR2FromArray2D( const Array2D& values); template @@ -481,6 +490,11 @@ class Literal { std::initializer_list> values, const Layout& layout); template + void PopulateFromArray(const Array& values); + template + void PopulateFromArrayWithLayout(const Array& values, + const Layout& layout); + template void PopulateR2FromArray2D(const Array2D& values); template void PopulateR2FromArray2DWithLayout(const Array2D& values, @@ -550,6 +564,17 @@ class Literal { // e.g. -0.5. bool IsAllFloat(float value) const; + // Like IsAll(const Literal&, int8), except we check whether the literal is + // equal to a particular complex number. + // + // If the literal is not a complex value, this always returns false. + // + // This casts value to the type of literal, then compares using ==. The usual + // admonishments about floating-point equality checks apply. We expect you to + // use this to check for complex values that can be expressed precisely as + // float pairs e.g. (-0.5, 1.0). + bool IsAllComplex(complex64 value) const; + // Returns whether this literal is zero at the specified index. This literal // must be an array. bool IsZero(tensorflow::gtl::ArraySlice indices) const; @@ -600,6 +625,7 @@ class Literal { std::vector f16s_; std::vector f32s_; std::vector f64s_; + std::vector c64s_; std::vector tuple_literals_; }; @@ -648,6 +674,10 @@ tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; template <> tensorflow::gtl::ArraySlice Literal::GetArraySlice() const; +template <> +tensorflow::gtl::ArraySlice Literal::GetArraySlice() + const; + template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); @@ -684,6 +714,9 @@ tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); template <> tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); +template <> +tensorflow::gtl::MutableArraySlice Literal::GetMutableArraySlice(); + template <> void Literal::Resize(int64 num_elements, bool value); @@ -714,6 +747,9 @@ void Literal::Resize(int64 num_elements, double value); template <> void Literal::Resize(int64 num_elements, half value); +template <> +void Literal::Resize(int64 num_elements, complex64 value); + template /* static */ std::unique_ptr Literal::CreateR0(NativeT value) { auto literal = MakeUnique(); @@ -816,33 +852,42 @@ template } template -/* static */ std::unique_ptr Literal::CreateR2FromArray2DWithLayout( - const Array2D& values, const Layout& layout) { +/* static */ std::unique_ptr Literal::CreateFromArrayWithLayout( + const Array& values, const Layout& layout) { auto literal = MakeUnique(); - literal->PopulateR2FromArray2DWithLayout(values, layout); + literal->PopulateFromArrayWithLayout(values, layout); return literal; } +template +/* static */ std::unique_ptr Literal::CreateFromArray( + const Array& values) { + return CreateFromArrayWithLayout( + values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); +} + +template +/* static */ std::unique_ptr Literal::CreateR2FromArray2DWithLayout( + const Array2D& values, const Layout& layout) { + return CreateFromArrayWithLayout(values, layout); +} + template /* static */ std::unique_ptr Literal::CreateR2FromArray2D( const Array2D& values) { - return CreateR2FromArray2DWithLayout(values, - LayoutUtil::GetDefaultLayoutForR2()); + return CreateFromArray(values); } template /* static */ std::unique_ptr Literal::CreateR3FromArray3DWithLayout( const Array3D& values, const Layout& layout) { - auto literal = MakeUnique(); - literal->PopulateR3FromArray3DWithLayout(values, layout); - return literal; + return CreateFromArrayWithLayout(values, layout); } template /* static */ std::unique_ptr Literal::CreateR3FromArray3D( const Array3D& values) { - return CreateR3FromArray3DWithLayout(values, - LayoutUtil::GetDefaultLayoutForR3()); + return CreateFromArray(values); } template @@ -901,16 +946,13 @@ template template /* static */ std::unique_ptr Literal::CreateR4FromArray4D( const Array4D& values) { - return CreateR4FromArray4DWithLayout(values, - LayoutUtil::GetDefaultLayoutForR4()); + return CreateFromArray(values); } template /* static */ std::unique_ptr Literal::CreateR4FromArray4DWithLayout( const Array4D& values, const Layout& layout) { - auto literal = MakeUnique(); - literal->PopulateR4FromArray4DWithLayout(values, layout); - return literal; + return CreateFromArrayWithLayout(values, layout); } template @@ -1070,82 +1112,53 @@ void Literal::PopulateR2( } template -void Literal::PopulateR2FromArray2DWithLayout(const Array2D& values, - const Layout& layout) { +void Literal::PopulateFromArrayWithLayout(const Array& values, + const Layout& layout) { *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType(), - {values.height(), values.width()}, AsInt64Slice(layout.minor_to_major())); + primitive_util::NativeToPrimitiveType(), values.dimensions(), + AsInt64Slice(layout.minor_to_major())); + Reserve(values.num_elements()); + values.Each([this](tensorflow::gtl::ArraySlice indices, + NativeT value) { this->Set(indices, value); }); +} - const int64 dim1_size = values.width(); - const int64 dim0_size = values.height(); - CHECK_EQ(dim0_size, shape().dimensions(0)); - CHECK_EQ(dim1_size, shape().dimensions(1)); - Reserve(dim1_size * dim0_size); - for (int64 dim0 = 0; dim0 < dim0_size; ++dim0) { - for (int64 dim1 = 0; dim1 < dim1_size; ++dim1) { - Set({dim0, dim1}, values(dim0, dim1)); - } - } +template +void Literal::PopulateFromArray(const Array& values) { + PopulateFromArrayWithLayout( + values, LayoutUtil::GetDefaultLayoutForRank(values.num_dimensions())); +} + +template +void Literal::PopulateR2FromArray2DWithLayout(const Array2D& values, + const Layout& layout) { + PopulateFromArrayWithLayout(values, layout); } template void Literal::PopulateR2FromArray2D(const Array2D& values) { - PopulateR2FromArray2DWithLayout(values, LayoutUtil::GetDefaultLayoutForR2()); + PopulateFromArray(values); } template void Literal::PopulateR3FromArray3DWithLayout(const Array3D& values, const Layout& layout) { - *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType(), - {values.n1(), values.n2(), values.n3()}, - AsInt64Slice(layout.minor_to_major())); - - CHECK_EQ(values.n1(), shape().dimensions(0)); - CHECK_EQ(values.n2(), shape().dimensions(1)); - CHECK_EQ(values.n3(), shape().dimensions(2)); - Reserve(values.n1() * values.n2() * values.n3()); - for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { - for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { - for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) { - Set({dim0, dim1, dim2}, values(dim0, dim1, dim2)); - } - } - } + PopulateFromArrayWithLayout(values, layout); } template void Literal::PopulateR3FromArray3D(const Array3D& values) { - PopulateR3FromArray3DWithLayout(values, LayoutUtil::GetDefaultLayoutForR3()); + PopulateFromArray(values); } template void Literal::PopulateR4FromArray4DWithLayout(const Array4D& values, const Layout& layout) { - *mutable_shape() = ShapeUtil::MakeShapeWithLayout( - primitive_util::NativeToPrimitiveType(), - {values.planes(), values.depth(), values.height(), values.width()}, - AsInt64Slice(layout.minor_to_major())); - - CHECK_EQ(values.n1(), shape().dimensions(0)); - CHECK_EQ(values.n2(), shape().dimensions(1)); - CHECK_EQ(values.n3(), shape().dimensions(2)); - CHECK_EQ(values.n4(), shape().dimensions(3)); - Reserve(values.n1() * values.n2() * values.n3() * values.n4()); - for (int64 dim0 = 0; dim0 < values.n1(); ++dim0) { - for (int64 dim1 = 0; dim1 < values.n2(); ++dim1) { - for (int64 dim2 = 0; dim2 < values.n3(); ++dim2) { - for (int64 dim3 = 0; dim3 < values.n4(); ++dim3) { - Set({dim0, dim1, dim2, dim3}, values(dim0, dim1, dim2, dim3)); - } - } - } - } + PopulateFromArrayWithLayout(values, layout); } template void Literal::PopulateR4FromArray4D(const Array4D& values) { - PopulateR4FromArray4DWithLayout(values, LayoutUtil::GetDefaultLayoutForR4()); + PopulateFromArray(values); } template diff --git a/tensorflow/compiler/xla/literal_util_test.cc b/tensorflow/compiler/xla/literal_util_test.cc index e7dedd08218d8a17c5e332e5cda7bedcc26f6703..6d596da4ada82ea67c098eeb629d1e19b77dd4c4 100644 --- a/tensorflow/compiler/xla/literal_util_test.cc +++ b/tensorflow/compiler/xla/literal_util_test.cc @@ -107,6 +107,9 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) { auto f16_lit = Literal::CreateR0(static_cast(0.5f)); ASSERT_EQ("0.5", f16_lit->ToString()); + + auto c64_lit = Literal::CreateR0({3.14f, 2.78f}); + ASSERT_EQ("(3.14, 2.78)", c64_lit->ToString()); } TEST_F(LiteralUtilTest, LiteralVectorToString) { @@ -119,7 +122,7 @@ TEST_F(LiteralUtilTest, R2ToString) { const string expected = R"(s32[3,2] { { 1, 2 }, { 3, 4 }, - { 5, 6 }, + { 5, 6 } })"; ASSERT_EQ(expected, literal->ToString()); } @@ -145,8 +148,8 @@ TEST_F(LiteralUtilTest, TupleToString) { 1, f32[2,2] { { 1, 2 }, - { 3, 4 }, -}, + { 3, 4 } +} ))"; ASSERT_EQ(expected, tuple->ToString()); } @@ -188,18 +191,18 @@ TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) { EXPECT_THAT(literal->shape().dimensions(), ElementsAre(1, 2, 3, 2)); string result = literal->ToString(); const string expected = R"(f32[1,2,3,2] { - { // i0=0 - { // i1=0 + { /*i0=0*/ + { /*i1=0*/ {1, 2}, {1001, 1002}, - {2001, 2002}, + {2001, 2002} }, - { // i1=1 + { /*i1=1*/ {1, 2}, {1001, 1002}, - {2001, 2002}, - }, - }, + {2001, 2002} + } + } })"; ASSERT_EQ(expected, result); } @@ -209,30 +212,30 @@ TEST_F(LiteralUtilTest, LiteralR4F32Stringifies) { ElementsAre(2, 2, 3, 3)); string result = literal_r4_2x2x3x3_dim0major_->ToString(); const string expected = R"(f32[2,2,3,3] { - { // i0=0 - { // i1=0 + { /*i0=0*/ + { /*i1=0*/ {1, 2, 3}, {4, 5, 6}, - {7, 8, 9}, + {7, 8, 9} }, - { // i1=1 + { /*i1=1*/ {11, 12, 13}, {14, 15, 16}, - {17, 18, 19}, - }, + {17, 18, 19} + } }, - { // i0=1 - { // i1=0 + { /*i0=1*/ + { /*i1=0*/ {101, 102, 103}, {104, 105, 106}, - {107, 108, 109}, + {107, 108, 109} }, - { // i1=1 + { /*i1=1*/ {201, 202, 203}, {204, 205, 206}, - {207, 208, 209}, - }, - }, + {207, 208, 209} + } + } })"; ASSERT_EQ(expected, result); } @@ -331,6 +334,19 @@ TEST_F(LiteralUtilTest, TupleEquality) { EXPECT_NE(*tuple1, *different_tuple); } +TEST_F(LiteralUtilTest, C64Equality) { + // Test equality with tuples. + auto vector = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + + // Tuple with the same elements. One element is shared with the original + // tuple, the other is a clone of the element in the original tuple. + auto vector_clone = Literal::CreateR1({{1.0, 2.0}, {3.0, 4.0}}); + EXPECT_EQ(*vector, *vector_clone); + + auto vector_reversed = Literal::CreateR1({{3.0, 4.0}, {1.0, 2.0}}); + EXPECT_NE(*vector, *vector_reversed); +} + TEST_F(LiteralUtilTest, IsAllTuple) { auto element1 = Literal::CreateR0(0.0); auto element2 = Literal::CreateR2({{0.0, 0.0}, {0.0, 0.0}}); @@ -381,6 +397,9 @@ TEST_F(LiteralUtilTest, IsAll) { EXPECT_FALSE(Literal::CreateR2({{h8}, {h9}})->IsAll(8)); EXPECT_FALSE(Literal::CreateR2({{h9}, {h8}})->IsAll(8)); + complex64 c8_9 = {8, 9}; + EXPECT_FALSE(Literal::CreateR2({{c8_9}, {c8_9}})->IsAll(8)); + auto uint64_max = std::numeric_limits::max(); EXPECT_FALSE(Literal::CreateR2( {{uint64_max, uint64_max}, {uint64_max, uint64_max}}) @@ -411,6 +430,25 @@ TEST_F(LiteralUtilTest, IsAllFloat) { Literal::CreateR2({{0, 0, 0}, {0, .1, 0}})->IsAllFloat(0)); } +TEST_F(LiteralUtilTest, IsAllComplex) { + // IsAllComplex always returns false when the literal is not complex. + EXPECT_FALSE(Literal::CreateR0(false)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + EXPECT_FALSE(Literal::CreateR0(0)->IsAllComplex(0)); + + complex64 c8_9 = {8, 9}; + complex64 c7_9 = {7, 9}; + EXPECT_TRUE(Literal::CreateR2({{c8_9}, {c8_9}}) + ->IsAllComplex({8.0f, 9.0f})); + EXPECT_FALSE(Literal::CreateR2({{c7_9}, {c8_9}}) + ->IsAllComplex({8.0f, 9.0f})); + EXPECT_FALSE(Literal::CreateR2({{c8_9}, {c7_9}}) + ->IsAllComplex({8.0f, 9.0f})); +} + TEST_F(LiteralUtilTest, IsZero) { auto scalar_zero = Literal::CreateR0(0.0f); auto scalar_one = Literal::CreateR0(1.0f); @@ -422,12 +460,17 @@ TEST_F(LiteralUtilTest, IsZero) { EXPECT_TRUE(array->IsZero({0, 2})); EXPECT_TRUE(array->IsZero({1, 1})); EXPECT_FALSE(array->IsZero({1, 2})); + + auto complex_zero = Literal::CreateR0(0.0f); + auto complex_nonzero = Literal::CreateR0(0.5f); + EXPECT_TRUE(complex_zero->IsZero({})); + EXPECT_FALSE(complex_nonzero->IsZero({})); } template class LiteralUtilTestTemplated : public ::testing::Test {}; -using TestedTypes = ::testing::Types; +using TestedTypes = ::testing::Types; TYPED_TEST_CASE(LiteralUtilTestTemplated, TestedTypes); TYPED_TEST(LiteralUtilTestTemplated, Relayout2x2) { @@ -626,13 +669,28 @@ TEST_F(LiteralUtilTest, PopulateR1S64) { EXPECT_EQ(output, *expected); } -TEST_F(LiteralUtilTest, PopulateR2U64) { +TEST_F(LiteralUtilTest, PopulateR1U64) { Literal output; output.PopulateR1({{77, 88}}); auto expected = Literal::CreateR1({{77, 88}}); EXPECT_EQ(output, *expected); } +TEST_F(LiteralUtilTest, PopulateR1C64) { + Literal output; + output.PopulateR1({{77, 88}}); + auto expected = Literal::CreateR1({{77, 88}}); + EXPECT_EQ(output, *expected); +} + +TEST_F(LiteralUtilTest, PopulateR2C64) { + Literal output; + output.PopulateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); + auto expected = + Literal::CreateR2({{{7, 8}, {9, 10}}, {{1, 2}, {3, 4}}}); + EXPECT_EQ(output, *expected); +} + TEST_F(LiteralUtilTest, PopulateWithValueR0F32) { Literal output; output.PopulateWithValue(2.5f, {}); @@ -654,6 +712,14 @@ TEST_F(LiteralUtilTest, PopulateWithValueR2U64) { EXPECT_EQ(output, *expected); } +TEST_F(LiteralUtilTest, PopulateWithValueR2C64) { + Literal output; + output.PopulateWithValue({4, 2}, {2, 2}); + auto expected = + Literal::CreateR2({{{4, 2}, {4, 2}}, {{4, 2}, {4, 2}}}); + EXPECT_EQ(output, *expected); +} + TEST_F(LiteralUtilTest, PopulateWithValueR0F16) { Literal output; half h(0.25f); @@ -919,6 +985,11 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { {{0.0, 19.0, 0.0, 21.0}, {22.0, 0.0, 24.0, 0.0}}, {{26.0, 0.0, 28.0, 0.0}, {0.0, 31.0, 0.0, 33.0}}, }}, layout_r4_dim0major_); + auto c64 = Literal::CreateR4WithLayout({{ + {{10.0f, 0.0f, 12.0f, 0.0f}, {0.0f, 15.0f, 0.0f, 17.0f}}, + {{0.0f, 19.0f, 0.0f, 21.0f}, {22.0f, 0.0f, 24.0f, 0.0f}}, + {{26.0f, 0.0f, 28.0f, 0.0f}, {0.0f, 31.0f, 0.0f, 33.0f}}, + }}, layout_r4_dim0major_); // clang-format on std::unique_ptr conv; @@ -961,12 +1032,22 @@ TEST_F(LiteralUtilTest, ConvertIfTypesMatch) { conv = u32->Convert(F16).ConsumeValueOrDie(); EXPECT_EQ(*conv, *f16); + conv = s32->Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(*conv, *c64); + + conv = f16->Convert(C64).ConsumeValueOrDie(); + EXPECT_EQ(*conv, *c64); + EXPECT_EQ(s32->Convert(TUPLE).status().code(), tensorflow::error::INVALID_ARGUMENT); EXPECT_EQ(s32->Convert(S16).status().code(), tensorflow::error::INVALID_ARGUMENT); EXPECT_EQ(s32->Convert(U16).status().code(), tensorflow::error::INVALID_ARGUMENT); + EXPECT_EQ(c64->Convert(F32).status().code(), + tensorflow::error::INVALID_ARGUMENT); + EXPECT_EQ(c64->Convert(S32).status().code(), + tensorflow::error::INVALID_ARGUMENT); } TEST_F(LiteralUtilTest, CopyFromProto_Bool) { diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index e4e37177a2d74e6da20300f1439942a146ad8d49..2113b5e06f3eb0169be50c0ee731a903c0eece9d 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -83,10 +83,17 @@ PrimitiveType NativeToPrimitiveType() { return F16; } +template <> +PrimitiveType NativeToPrimitiveType() { + return C64; +} + bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64; } +bool IsComplexType(PrimitiveType type) { return type == C64; } + bool IsSignedIntegralType(PrimitiveType type) { return type == S8 || type == S16 || type == S32 || type == S64; } @@ -121,6 +128,7 @@ int BitWidth(PrimitiveType type) { case U64: case S64: case F64: + case C64: return 64; case TUPLE: @@ -134,5 +142,15 @@ int BitWidth(PrimitiveType type) { } } +PrimitiveType ComplexComponentType(PrimitiveType complex_type) { + switch (complex_type) { + case C64: + return F32; + default: + LOG(FATAL) << "Primitive type is not complex: " + << PrimitiveType_Name(complex_type); + } +} + } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index 162a11c7d2966346979b98c804917203f82c806c..a49c8b86fcfe156ea3733ce05c0fb7337cf60dce 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -78,8 +78,14 @@ PrimitiveType NativeToPrimitiveType(); template <> PrimitiveType NativeToPrimitiveType(); +// Complex +template <> +PrimitiveType NativeToPrimitiveType(); + bool IsFloatingPointType(PrimitiveType type); +bool IsComplexType(PrimitiveType type); + bool IsSignedIntegralType(PrimitiveType type); bool IsUnsignedIntegralType(PrimitiveType type); @@ -89,6 +95,10 @@ bool IsIntegralType(PrimitiveType type); // Returns the number of bits in the representation for a given type. int BitWidth(PrimitiveType type); +// Returns the real, imag component type underlying the given complex type. +// LOG(FATAL)'s if complex_type is not complex. +PrimitiveType ComplexComponentType(PrimitiveType complex_type); + // Returns the native type (eg, float) corresponding to the given template // parameter XLA primitive type (eg, F32). template @@ -157,6 +167,11 @@ struct PrimitiveTypeToNative { using type = half; }; +// Complex +template <> +struct PrimitiveTypeToNative { + using type = complex64; +}; } // namespace primitive_util } // namespace xla diff --git a/tensorflow/compiler/xla/protobuf_util.cc b/tensorflow/compiler/xla/protobuf_util.cc index cdc4139cd69c3d6eb4afc2e5d25f9446ffad0a11..787725e884c810fd724ab88ad7d4beaf3e0a6cc7 100644 --- a/tensorflow/compiler/xla/protobuf_util.cc +++ b/tensorflow/compiler/xla/protobuf_util.cc @@ -37,34 +37,27 @@ bool ProtobufEquals(const tensorflow::protobuf::Message& m1, return (serialized1 == serialized2); } -StatusOr ToJson(const tensorflow::protobuf::Message& message) { - string json_output; - tensorflow::protobuf::util::JsonPrintOptions json_options; - json_options.add_whitespace = true; - json_options.always_print_primitive_fields = true; - auto status = tensorflow::protobuf::util::MessageToJsonString( - message, &json_output, json_options); - if (!status.ok()) { - return InternalError("MessageToJsonString failed: %s", - status.error_message().data()); - } - return json_output; -} - -Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message, - const string& directory, const string& file_name) { - TF_ASSIGN_OR_RETURN(const string json_output, ToJson(message)); +namespace { - tensorflow::Env* env = tensorflow::Env::Default(); - TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); - string safe_file_name = file_name + ".json"; +string SanitizeFilename(const string& file_name) { + string safe_file_name = file_name; for (char& c : safe_file_name) { if (c == '/' || c == '\\') { c = '_'; } } + return safe_file_name; +} + +} // namespace + +Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, + const string& directory, const string& file_name) { + tensorflow::Env* env = tensorflow::Env::Default(); + TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory)); + string safe_file_name = SanitizeFileName(file_name) + ".pb"; const string path = tensorflow::io::JoinPath(directory, safe_file_name); - return tensorflow::WriteStringToFile(env, path, json_output); + return tensorflow::WriteBinaryProto(env, path, message); } } // namespace protobuf_util diff --git a/tensorflow/compiler/xla/protobuf_util.h b/tensorflow/compiler/xla/protobuf_util.h index 1a895c3585902e8fbc0d20475c2817ef4caa4c71..3667621367c7639c40ff17aee7b77305d4d34e33 100644 --- a/tensorflow/compiler/xla/protobuf_util.h +++ b/tensorflow/compiler/xla/protobuf_util.h @@ -32,15 +32,12 @@ namespace protobuf_util { extern bool ProtobufEquals(const tensorflow::protobuf::Message& m1, const tensorflow::protobuf::Message& m2); -// Returns 'message' as a JSON string. -StatusOr ToJson(const tensorflow::protobuf::Message& message); - -// Converts 'message' to JSON, and dumps it to the path formed by joining -// 'directory/file_name.json'. The 'directory' is recursively created if it -// doesn't already exist, and the 'file_name' is sanitized by replacing illegal -// characters with underscore '_'. -Status DumpJsonToDirectory(const tensorflow::protobuf::Message& message, - const string& directory, const string& file_name); +// Writes the given message in binary proto to the path formed by joining +// 'directory/file_name.pb'. The 'directory' is recursively created if it +// doesn't already exist, and the 'file_name' is sanitized by replacing +// illegal characters with underscore '_'. +Status DumpProtoToDirectory(const tensorflow::protobuf::Message& message, + const string& directory, const string& file_name); } // namespace protobuf_util } // namespace xla diff --git a/tensorflow/compiler/xla/reference_util_test.cc b/tensorflow/compiler/xla/reference_util_test.cc index 35b5e8cd52ab0ec21a4bd2df3e9fa0538ae60816..eb6a71242ffa1499876b90f14f8a60ffdbdd069c 100644 --- a/tensorflow/compiler/xla/reference_util_test.cc +++ b/tensorflow/compiler/xla/reference_util_test.cc @@ -322,8 +322,10 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithSamePadding) { // Set the convolution dimension numbers. ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_batch_dimension(2); - dimension_numbers.set_feature_dimension(0); + dimension_numbers.set_input_batch_dimension(2); + dimension_numbers.set_input_feature_dimension(0); + dimension_numbers.set_output_batch_dimension(2); + dimension_numbers.set_output_feature_dimension(0); dimension_numbers.add_spatial_dimensions(1); dimension_numbers.add_spatial_dimensions(3); dimension_numbers.set_kernel_output_feature_dimension(0); @@ -374,8 +376,10 @@ TEST_F(ReferenceUtilTest, ConvGeneralDimensionsWithValidPadding) { // Set the convolution dimension numbers. ConvolutionDimensionNumbers dimension_numbers; - dimension_numbers.set_batch_dimension(2); - dimension_numbers.set_feature_dimension(0); + dimension_numbers.set_input_batch_dimension(2); + dimension_numbers.set_input_feature_dimension(0); + dimension_numbers.set_output_batch_dimension(2); + dimension_numbers.set_output_feature_dimension(0); dimension_numbers.add_spatial_dimensions(1); dimension_numbers.add_spatial_dimensions(3); diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index a88eac031baab8bb5c075806e0b86565f1a74a4c..521fe411a4beed8b075568a41bce116bb528624f 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -115,10 +115,11 @@ tf_cc_test( "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/client:computation_builder", - "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", + "//tensorflow/core:test", ], ) @@ -130,6 +131,7 @@ cc_library( "hlo_instruction.cc", "hlo_module.cc", "hlo_opcode.cc", + "hlo_sharding.cc", ], hdrs = [ "dfs_hlo_visitor.h", @@ -138,6 +140,7 @@ cc_library( "hlo_instruction.h", "hlo_module.h", "hlo_opcode.h", + "hlo_sharding.h", ], deps = [ ":hlo_module_config", @@ -145,6 +148,7 @@ cc_library( ":hlo_reachability", ":name_uniquer", ":versioned_computation_handle", + "//tensorflow/compiler/xla:array", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", "//tensorflow/compiler/xla:shape_tree", @@ -235,6 +239,22 @@ tf_cc_test( ], ) +tf_cc_test( + name = "hlo_sharding_test", + srcs = ["hlo_sharding_test.cc"], + deps = [ + ":hlo", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:protobuf_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + ], +) + cc_library( name = "call_graph", srcs = ["call_graph.cc"], @@ -511,9 +531,9 @@ cc_library( cc_library( name = "cpu_plugin", deps = [ - ":cpu_transfer_manager", ":service", "//tensorflow/compiler/xla/service/cpu:cpu_compiler", + "//tensorflow/compiler/xla/service/cpu:cpu_transfer_manager", "//tensorflow/core:stream_executor_no_cuda", ], ) @@ -521,9 +541,9 @@ cc_library( cc_library( name = "gpu_plugin", deps = [ - ":gpu_transfer_manager", ":service", "//tensorflow/compiler/xla/service/gpu:gpu_compiler", + "//tensorflow/compiler/xla/service/gpu:gpu_transfer_manager", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core/platform/default/build_config:stream_executor_cuda", ], @@ -532,9 +552,9 @@ cc_library( cc_library( name = "interpreter_plugin", deps = [ - ":interpreter_transfer_manager", ":service", "//tensorflow/compiler/xla/service/interpreter:compiler", + "//tensorflow/compiler/xla/service/interpreter:interpreter_transfer_manager", "//tensorflow/compiler/xla/service/interpreter:platform", "//tensorflow/core:stream_executor_no_cuda", ], @@ -546,6 +566,7 @@ cc_library( hdrs = ["shaped_buffer.h"], deps = [ ":device_memory_allocator", + ":transfer_manager", "//tensorflow/compiler/xla:shape_tree", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -577,12 +598,14 @@ cc_library( ":shaped_buffer", ":versioned_computation_handle", "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/compiler/xla:status", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", ], @@ -715,6 +738,18 @@ cc_library( ], ) +tf_cc_test( + name = "name_uniquer_test", + srcs = ["name_uniquer_test.cc"], + deps = [ + ":name_uniquer", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:test_helpers", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "liveness_util", srcs = ["liveness_util.cc"], @@ -1047,9 +1082,93 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_test_base", - "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "while_loop_simplifier", + srcs = ["while_loop_simplifier.cc"], + hdrs = ["while_loop_simplifier.h"], + deps = [ + ":call_inliner", + ":hlo", + ":hlo_evaluator", + ":hlo_pass", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "while_loop_simplifier_test", + srcs = ["while_loop_simplifier_test.cc"], + deps = [ + ":hlo_matchers", + ":while_loop_simplifier", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + "//tensorflow/core:test", + ], +) + +cc_library( + name = "defuser", + srcs = ["defuser.cc"], + hdrs = ["defuser.h"], + deps = [ + ":call_graph", + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "defuser_test", + srcs = ["defuser_test.cc"], + deps = [ + ":defuser", + ":hlo_matchers", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", + ], +) + +cc_library( + name = "tuple_simplifier", + srcs = ["tuple_simplifier.cc"], + hdrs = ["tuple_simplifier.h"], + deps = [ + ":hlo", + ":hlo_pass", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "tuple_simplifier_test", + srcs = ["tuple_simplifier_test.cc"], + deps = [ + ":hlo", + ":hlo_matchers", + ":tuple_simplifier", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:test", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/core:test", ], ) @@ -1081,7 +1200,7 @@ tf_cc_test( "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/tests:hlo_test_base", + "//tensorflow/compiler/xla/tests:hlo_verified_test_base", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", ], @@ -1172,75 +1291,17 @@ cc_library( alwayslink = True, # Contains per-platform transfer manager registration ) -cc_library( - name = "cpu_transfer_manager", - srcs = ["cpu_transfer_manager.cc"], - hdrs = ["cpu_transfer_manager.h"], - deps = [ - ":generic_transfer_manager", - ":transfer_manager", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service/cpu:cpu_runtime", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", - ], - alwayslink = True, # Contains per-platform transfer manager registration -) - -cc_library( - name = "gpu_transfer_manager", - srcs = ["gpu_transfer_manager.cc"], - hdrs = ["gpu_transfer_manager.h"], - deps = [ - ":generic_transfer_manager", - ":transfer_manager", - "//tensorflow/compiler/xla:literal_util", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla:statusor", - "//tensorflow/compiler/xla:types", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service/gpu:infeed_manager", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", - ], - alwayslink = True, # Contains per-platform transfer manager registration -) - -cc_library( - name = "interpreter_transfer_manager", - srcs = ["interpreter_transfer_manager.cc"], - hdrs = ["interpreter_transfer_manager.h"], - deps = [ - ":generic_transfer_manager", - ":transfer_manager", - "//tensorflow/compiler/xla:util", - "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service/interpreter:platform_id", - "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", - ], - alwayslink = True, # Contains per-platform transfer manager registration -) - tf_cc_test( name = "transfer_manager_test", srcs = ["transfer_manager_test.cc"], deps = [ - ":cpu_transfer_manager", ":generic_transfer_manager", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service/cpu:cpu_transfer_manager", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", "//tensorflow/core:lib", @@ -1447,6 +1508,7 @@ cc_library( ":hlo", ":hlo_buffer", ":hlo_dataflow_analysis", + ":hlo_ordering", ":hlo_value", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status", @@ -2072,6 +2134,30 @@ tf_cc_test( ], ) +cc_library( + name = "hlo_runner", + srcs = ["hlo_runner.cc"], + hdrs = ["hlo_runner.h"], + deps = [ + ":executable", + ":hlo", + ":transfer_manager", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:backend", + "//tensorflow/compiler/xla/service:compiler", + "//tensorflow/compiler/xla/tools/parser:hlo_parser", + "//tensorflow/core:core_cpu_internal", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "//third_party/eigen3", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc index f7551bfb6c777faa545e288468144a9b4922a316..35fe0d1a5192b93c0be47ecc1b1bdb753da792af 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc @@ -39,12 +39,16 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { +using tensorflow::gtl::nullopt; +using tensorflow::gtl::optional; + // Returns whether operand is a literal with the given value. bool IsLiteralWithValue(const HloInstruction* operand, int8 value) { return operand->opcode() == HloOpcode::kConstant && @@ -92,11 +96,11 @@ bool ReshapeIsBitcast( HloComputation* CreateScalarBinaryComputation(HloModule* module, PrimitiveType primitive_type, HloOpcode opcode) { - HloComputation::Builder b("scalar computation"); + HloComputation::Builder b("scalar_computation"); auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar lhs")); + 0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs")); auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {}), "scalar rhs")); + 1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs")); auto scalar_op = b.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), opcode, scalar_lhs, scalar_rhs)); @@ -117,71 +121,54 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleAdd(HloInstruction* add, HloInstruction* lhs, - HloInstruction* rhs) override; + Status HandleAdd(HloInstruction* add) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleBroadcast(HloInstruction* broadcast) override; - Status HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands) override; + Status HandleConcatenate(HloInstruction* concatenate) override; - Status HandleConstant(HloInstruction* constant, - const Literal& literal) override; + Status HandleConstant(HloInstruction* constant) override; Status HandleCopy(HloInstruction* copy) override; Status HandleConvert(HloInstruction* convert) override; - Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, - HloInstruction* rhs, const Window& window) override; + Status HandleReal(HloInstruction* real) override; + Status HandleImag(HloInstruction* imag) override; + + Status HandleConvolution(HloInstruction* convolution) override; - Status HandleDivide(HloInstruction* divide, HloInstruction* lhs, - HloInstruction* rhs) override; + Status HandleDivide(HloInstruction* divide) override; - Status HandleDot(HloInstruction* dot, HloInstruction* lhs, - HloInstruction* rhs) override; + Status HandleDot(HloInstruction* dot) override; - Status HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* operand) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; - Status HandleLog(HloInstruction* log, HloInstruction* operand) override; + Status HandleLog(HloInstruction* log) override; - Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, - HloInstruction* rhs) override; + Status HandleMultiply(HloInstruction* multiply) override; Status HandlePad(HloInstruction* pad) override; - Status HandlePower(HloInstruction* power, HloInstruction* lhs, - HloInstruction* rhs) override; + Status HandlePower(HloInstruction* power) override; Status HandleReshape(HloInstruction* reshape) override; - Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, - HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, - HloComputation* function) override; - - Status HandleReduceWindow(HloInstruction* reduce_window, - HloInstruction* operand, const Window& window, - HloComputation* function) override; - - Status HandleReverse(HloInstruction* reverse, - HloInstruction* operand) override; - Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; - Status HandleDynamicSlice(HloInstruction* slice, HloInstruction* operand, - HloInstruction* start_indices) override; - Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, - HloInstruction* operand, - HloInstruction* update, - HloInstruction* start_indices) override; + Status HandleReduce(HloInstruction* reduce) override; + + Status HandleReduceWindow(HloInstruction* reduce_window) override; + + Status HandleReverse(HloInstruction* reverse) override; + Status HandleSlice(HloInstruction* slice) override; + Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; + Status HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) override; Status HandleTranspose(HloInstruction* transpose) override; - Status HandleSubtract(HloInstruction* sub, HloInstruction* lhs, - HloInstruction* rhs) override; + Status HandleSubtract(HloInstruction* sub) override; Status HandleMaximum(HloInstruction* maximum) override; Status HandleMinimum(HloInstruction* minimum) override; @@ -193,17 +180,18 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { static bool Run( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification); + bool enable_dot_simplification, bool enable_conv_simplification); private: explicit AlgebraicSimplifierVisitor( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification) + bool enable_dot_simplification, bool enable_conv_simplification) : computation_(computation), is_layout_sensitive_(is_layout_sensitive), valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_simplification_(enable_dot_simplification) {} + enable_dot_simplification_(enable_dot_simplification), + enable_conv_simplification_(enable_conv_simplification) {} // Convenience method for replacing an instruction with a bitcast. void ReplaceWithBitcast(HloInstruction* instruction); @@ -279,15 +267,18 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault { // Disable dot simplication on platforms where it causes a slowdown. bool enable_dot_simplification_; + + // Disable convolution simplication on platforms where it causes a slowdown. + bool enable_conv_simplification_; }; bool AlgebraicSimplifierVisitor::Run( HloComputation* computation, bool is_layout_sensitive, AlgebraicSimplifier::ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification) { - AlgebraicSimplifierVisitor visitor(computation, is_layout_sensitive, - std::move(valid_bitcast_callback), - enable_dot_simplification); + bool enable_dot_simplification, bool enable_conv_simplification) { + AlgebraicSimplifierVisitor visitor( + computation, is_layout_sensitive, std::move(valid_bitcast_callback), + enable_dot_simplification, enable_conv_simplification); TF_CHECK_OK(computation->Accept(&visitor)); return visitor.changed_; } @@ -324,9 +315,9 @@ bool AlgebraicSimplifierVisitor::ReplaceInstructionIfSameShape( return true; } -Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add, - HloInstruction* lhs, - HloInstruction* rhs) { +Status AlgebraicSimplifierVisitor::HandleAdd(HloInstruction* add) { + auto lhs = add->mutable_operand(0); + auto rhs = add->mutable_operand(1); // A + 0 => A VLOG(10) << "trying transform [A + 0 => A]: " << add->ToString(); if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(add, lhs)) { @@ -369,8 +360,9 @@ Status AlgebraicSimplifierVisitor::HandleCopy(HloInstruction* copy) { } Status AlgebraicSimplifierVisitor::HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands) { + HloInstruction* concatenate) { + tensorflow::gtl::ArraySlice operands( + concatenate->operands()); if (operands.size() == 1) { // Unary concatenates are useless. ReplaceInstructionIfSameShape(concatenate, operands[0]); @@ -451,20 +443,19 @@ static HloInstruction* BuildTupleConstant(HloComputation* computation, } } -Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant, - const Literal& literal) { +Status AlgebraicSimplifierVisitor::HandleConstant(HloInstruction* constant) { // Tuple constants aren't directly supported by any backend. Expand them into // explicit Tuple instructions. if (ShapeUtil::IsTuple(constant->shape())) { - return ReplaceInstruction(constant, - BuildTupleConstant(computation_, literal)); + return ReplaceInstruction( + constant, BuildTupleConstant(computation_, constant->literal())); } return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub, - HloInstruction* lhs, - HloInstruction* rhs) { +Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub) { + auto lhs = sub->mutable_operand(0); + auto rhs = sub->mutable_operand(1); // A - 0 => A VLOG(10) << "trying transform [A - 0 => A]: " << sub->ToString(); if (IsAll(rhs, 0) && ReplaceInstructionIfSameShape(sub, lhs)) { @@ -474,9 +465,9 @@ Status AlgebraicSimplifierVisitor::HandleSubtract(HloInstruction* sub, return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, - HloInstruction* lhs, - HloInstruction* rhs) { +Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide) { + auto lhs = divide->mutable_operand(0); + auto rhs = divide->mutable_operand(1); // A/1 => A VLOG(10) << "trying transform [A/1 => A]: " << divide->ToString(); if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(divide, lhs)) { @@ -511,11 +502,16 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, // A/pow(B,C) => A*pow(B,-C) if (rhs->opcode() == HloOpcode::kPower) { VLOG(10) << "transform [A/pow(B,C) => A*pow(B,-C)]: " << divide->ToString(); + // The output shape of the created negate operator should be the same as the + // input. + const Shape& negate_shape = rhs->operand(1)->shape(); HloInstruction* negate = computation_->AddInstruction(HloInstruction::CreateUnary( - divide->shape(), HloOpcode::kNegate, rhs->mutable_operand(1))); + negate_shape, HloOpcode::kNegate, rhs->mutable_operand(1))); + // And the power operator should retain the output shape of the old one. + const Shape& new_power_shape = rhs->shape(); HloInstruction* new_power = computation_->AddInstruction( - HloInstruction::CreateBinary(divide->shape(), HloOpcode::kPower, + HloInstruction::CreateBinary(new_power_shape, HloOpcode::kPower, rhs->mutable_operand(0), negate)); return ReplaceWithNewInstruction( divide, HloInstruction::CreateBinary( @@ -578,9 +574,9 @@ Status AlgebraicSimplifierVisitor::HandleDivide(HloInstruction* divide, return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, - HloInstruction* lhs, - HloInstruction* rhs) { +Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) { + auto lhs = dot->mutable_operand(0); + auto rhs = dot->mutable_operand(1); if (!enable_dot_simplification_) { return Status::OK(); } @@ -709,9 +705,9 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot, return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply, - HloInstruction* lhs, - HloInstruction* rhs) { +Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply) { + auto lhs = multiply->mutable_operand(0); + auto rhs = multiply->mutable_operand(1); // A*1 => A VLOG(10) << "trying transform [A*1 => A]: " << multiply->ToString(); if (IsAll(rhs, 1) && ReplaceInstructionIfSameShape(multiply, lhs)) { @@ -735,10 +731,10 @@ Status AlgebraicSimplifierVisitor::HandleMultiply(HloInstruction* multiply, return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log, - HloInstruction* operand) { +Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log) { // ln(exp(A)) => A VLOG(10) << "trying transform [ln(exp(A)) => A]: " << log->ToString(); + auto operand = log->mutable_operand(0); if (operand->opcode() == HloOpcode::kExp && ReplaceInstructionIfSameShape(log, operand->mutable_operand(0))) { return Status::OK(); @@ -758,7 +754,8 @@ Status AlgebraicSimplifierVisitor::HandleLog(HloInstruction* log, } Status AlgebraicSimplifierVisitor::HandleGetTupleElement( - HloInstruction* get_tuple_element, HloInstruction* operand) { + HloInstruction* get_tuple_element) { + auto operand = get_tuple_element->mutable_operand(0); if (operand->opcode() == HloOpcode::kTuple) { // get_tuple_element(make_tuple({A_0, A_1, ..., A_n}), i) => A_i VLOG(10) << "trying transform " @@ -904,9 +901,10 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { // A Broadcast that feeds a unary element-wise operation can sink the // broadcast after the unary element-wise operation. TF_ASSIGN_OR_RETURN( - changed_, + bool sink_succeeded, TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(broadcast)); - if (changed_) { + changed_ |= sink_succeeded; + if (sink_succeeded) { return Status::OK(); } @@ -926,11 +924,11 @@ Status AlgebraicSimplifierVisitor::HandleBroadcast(HloInstruction* broadcast) { << "a single broadcast"; HloInstruction* new_broadcast = computation_->AddInstruction( HloInstruction::CreateBroadcast(user->shape(), operand, {})); - // Use ReplaceUsesOfInstruction instead of ReplaceWithNewInstruction - // because we are replacing an instruction other than the visited - // instruction. + // Use HloInstruction::ReplaceAllUsesWith instead of + // HloComputation::ReplaceWithNewInstruction because we are replacing an + // instruction other than the visited instruction. changed_ = true; - return computation_->ReplaceUsesOfInstruction(user, new_broadcast); + return user->ReplaceAllUsesWith(new_broadcast); } } } @@ -949,6 +947,24 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) { return Status::OK(); } +// Real(Complex(r, i)) -> r +Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) { + auto operand = real->mutable_operand(0); + if (operand->opcode() == HloOpcode::kComplex) { + return ReplaceInstruction(real, operand->mutable_operand(0)); + } + return Status::OK(); +} + +// Imag(Complex(r, i)) -> i +Status AlgebraicSimplifierVisitor::HandleImag(HloInstruction* imag) { + auto operand = imag->mutable_operand(0); + if (operand->opcode() == HloOpcode::kComplex) { + return ReplaceInstruction(imag, operand->mutable_operand(1)); + } + return Status::OK(); +} + Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { // Eliminate nop pads (padding all zero), and replace a pad with negative // padding with a pad with non-negative padding followed by a slice. @@ -1039,10 +1055,10 @@ Status AlgebraicSimplifierVisitor::HandlePad(HloInstruction* pad) { return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power, - HloInstruction* lhs, - HloInstruction* rhs) { +Status AlgebraicSimplifierVisitor::HandlePower(HloInstruction* power) { VLOG(10) << "trying transform [pow(A, 0) => 1]: " << power->ToString(); + auto lhs = power->mutable_operand(0); + auto rhs = power->mutable_operand(1); if (IsAll(rhs, 0)) { auto one = HloInstruction::CreateConstant( Literal::One(power->shape().element_type()).CloneToUnique()); @@ -1163,8 +1179,7 @@ StatusOr AlgebraicSimplifierVisitor:: } VLOG(4) << " new reshape/broadcast: " << new_reshape_or_broadcast->ToString(); - TF_RETURN_IF_ERROR( - computation_->ReplaceUsesOfInstruction(user, new_reshape_or_broadcast)); + TF_RETURN_IF_ERROR(user->ReplaceAllUsesWith(new_reshape_or_broadcast)); changed = true; } return changed; @@ -1210,9 +1225,10 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { // A Reshape that feeds a unary element-wise operation can sink the // reshape after the unary element-wise operation. TF_ASSIGN_OR_RETURN( - changed_, + bool sink_succeeded, TryToSinkReshapeOrBroadcastAfterOpWithUniqueNonScalarOperand(reshape)); - if (changed_) { + changed_ |= sink_succeeded; + if (sink_succeeded) { return Status::OK(); } @@ -1226,8 +1242,7 @@ Status AlgebraicSimplifierVisitor::HandleReshape(HloInstruction* reshape) { return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse, - HloInstruction* operand) { +Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse) { // When all the dimensions to reverse are trivial (i.e. the bound is 1), // there is nothing to be done. auto dim_is_one = [&](int64 i) -> bool { @@ -1235,42 +1250,61 @@ Status AlgebraicSimplifierVisitor::HandleReverse(HloInstruction* reverse, }; if (std::all_of(reverse->dimensions().begin(), reverse->dimensions().end(), dim_is_one)) { - return ReplaceInstruction(reverse, operand); + return ReplaceInstruction(reverse, reverse->mutable_operand(0)); } return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice, - HloInstruction* operand) { +Status AlgebraicSimplifierVisitor::HandleSlice(HloInstruction* slice) { // Delete no-op slices, i.e. where shape = operand shape. - if (ReplaceInstructionIfSameShape(slice, operand)) { + if (ReplaceInstructionIfSameShape(slice, slice->mutable_operand(0))) { return Status::OK(); } return Status::OK(); } Status AlgebraicSimplifierVisitor::HandleDynamicSlice( - HloInstruction* dynamic_slice, HloInstruction* operand, - HloInstruction* start_indices) { + HloInstruction* dynamic_slice) { + auto operand = dynamic_slice->mutable_operand(0); + auto start_indices = dynamic_slice->operand(1); if (ShapeUtil::IsScalar(dynamic_slice->shape())) { return ReplaceInstruction(dynamic_slice, operand); } + // DynamicSlice where operand has the same size as the output and + // start_indices are all zero is simply equal to operand. + if (IsAll(start_indices, 0) && SameShape(operand, dynamic_slice)) { + return ReplaceInstruction(dynamic_slice, operand); + } return Status::OK(); } Status AlgebraicSimplifierVisitor::HandleDynamicUpdateSlice( - HloInstruction* dynamic_update_slice, HloInstruction* operand, - HloInstruction* update, HloInstruction* start_indices) { + HloInstruction* dynamic_update_slice) { + auto update = dynamic_update_slice->mutable_operand(1); + auto start_indices = dynamic_update_slice->operand(2); // DynamicUpdateSlice on a scalar just passes through the update argument. if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { return ReplaceInstruction(dynamic_update_slice, update); } + + // DynamicUpdateSlice where operand and update have the same size and + // start_indices are all zero is simply equal to update. + // + // (We require start_indices to be all zero because we want this optimization + // not to affect the visible behavior of this op even when the indices are out + // of range. Currently dynamic-update-slice wraps out-of-range indices, so + // we can only remove the op if its indices never wrap.) + if (IsAll(start_indices, 0) && SameShape(dynamic_update_slice, update)) { + return ReplaceInstruction(dynamic_update_slice, update); + } return Status::OK(); } -Status AlgebraicSimplifierVisitor::HandleReduce( - HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, HloComputation* function) { +Status AlgebraicSimplifierVisitor::HandleReduce(HloInstruction* reduce) { + auto arg = reduce->mutable_operand(0); + auto init_value = reduce->mutable_operand(1); + tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + HloComputation* function = reduce->to_apply(); if (ShapeUtil::HasZeroElements(arg->shape()) || ShapeUtil::HasZeroElements(reduce->shape())) { return ReplaceWithNewInstruction( @@ -1348,8 +1382,10 @@ Status AlgebraicSimplifierVisitor::HandleReduce( } Status AlgebraicSimplifierVisitor::HandleReduceWindow( - HloInstruction* reduce_window, HloInstruction* operand, - const Window& window, HloComputation* function) { + HloInstruction* reduce_window) { + auto operand = reduce_window->mutable_operand(0); + const Window& window = reduce_window->window(); + auto function = reduce_window->to_apply(); VLOG(10) << "Considering folding Pad: " << operand->ToString() << "\ninto reduce-window: " << reduce_window->ToString(); @@ -1432,8 +1468,13 @@ Status AlgebraicSimplifierVisitor::HandleTranspose(HloInstruction* transpose) { } Status AlgebraicSimplifierVisitor::HandleConvolution( - HloInstruction* convolution, HloInstruction* lhs, HloInstruction* rhs, - const Window& window) { + HloInstruction* convolution) { + auto lhs = convolution->mutable_operand(0); + auto rhs = convolution->mutable_operand(1); + const auto& window = convolution->window(); + if (!enable_conv_simplification_) { + return Status::OK(); + } // HandleConvolution tries to replace a convolution with a DOT instruction. // // Only add when bitcasts can be used: @@ -1486,7 +1527,10 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // still convert Conv into more efficient Matmul with operand transposition // (such as the transposition flags in cuBLAS SGEMM). if (!LayoutUtil::Equal(input_shape.layout(), convolution_shape.layout()) || - input_shape.layout().minor_to_major(0) != dnums.feature_dimension() || + input_shape.layout().minor_to_major(0) != + dnums.input_feature_dimension() || + convolution_shape.layout().minor_to_major(0) != + dnums.output_feature_dimension() || // The input feature dimension should come later in the minor-to-major // order. (PositionInContainer(filter_shape.layout().minor_to_major(), @@ -1505,14 +1549,14 @@ Status AlgebraicSimplifierVisitor::HandleConvolution( // Replace it with a dot, with bitcasts around it to get the right shape. const int64 input_channels = - input_shape.dimensions(dnums.feature_dimension()); + input_shape.dimensions(dnums.input_feature_dimension()); const int64 output_channels = filter_shape.dimensions(dnums.kernel_output_feature_dimension()); // Computes the product of the non-feature dimensions. int64 conv_width = 1; for (int i = 0; i < input_shape.dimensions_size(); ++i) { - if (i != dnums.feature_dimension()) { + if (i != dnums.input_feature_dimension()) { conv_width *= input_shape.dimensions(i); } } @@ -1629,19 +1673,10 @@ StatusOr AlgebraicSimplifier::Run(HloModule* module) { XLA_VLOG_LINES(2, "AlgebraicSimplifier::Run(), before:\n" + module->ToString()); bool changed = false; - // Make a copy of the computations because we may add computations to the - // module, invalidating iteration. - std::vector computations; - for (auto& comp : module->computations()) { - if (comp->IsFusionComputation()) { - continue; - } - computations.push_back(comp.get()); - } - for (auto& comp : computations) { - if (AlgebraicSimplifierVisitor::Run(comp, is_layout_sensitive_, - valid_bitcast_callback_, - enable_dot_simplification_)) { + for (auto* comp : module->MakeNonfusionComputations()) { + if (AlgebraicSimplifierVisitor::Run( + comp, is_layout_sensitive_, valid_bitcast_callback_, + enable_dot_simplification_, enable_conv_simplification_)) { changed = true; } } diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.h b/tensorflow/compiler/xla/service/algebraic_simplifier.h index 4295a3227a837ffc8483b3be59994c9e6ac96aec..a9f476178c7af74c275a10de7727ea64e17d590f 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier.h +++ b/tensorflow/compiler/xla/service/algebraic_simplifier.h @@ -40,11 +40,13 @@ class AlgebraicSimplifier : public HloPassInterface { // bitcasts. AlgebraicSimplifier(bool is_layout_sensitive, ValidBitcastCallback valid_bitcast_callback, - bool enable_dot_simplification = true) + bool enable_dot_simplification = true, + bool enable_conv_simplification = true) : is_layout_sensitive_(is_layout_sensitive), valid_bitcast_callback_(std::move(valid_bitcast_callback)), - enable_dot_simplification_(enable_dot_simplification) {} - ~AlgebraicSimplifier() override {} + enable_dot_simplification_(enable_dot_simplification), + enable_conv_simplification_(enable_conv_simplification) {} + ~AlgebraicSimplifier() override = default; tensorflow::StringPiece name() const override { return "algsimp"; } // Run algebraic simplification on the given computation. Returns whether the @@ -57,6 +59,9 @@ class AlgebraicSimplifier : public HloPassInterface { // Enable dot simplication on platforms where it is profitable. bool enable_dot_simplification_; + + // Enable convolution simplication on platforms where it is profitable. + bool enable_conv_simplification_; }; } // namespace xla diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc index f968ec693f5789647f5bbc3892f933d51d177c09..c06e330bc12ec73ae46b84505b34c16e3591aaa5 100644 --- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc +++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc @@ -28,16 +28,17 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_pass_fix.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/lib/strings/str_util.h" -namespace op = xla::testing::opcode_matchers; - namespace xla { namespace { +namespace op = xla::testing::opcode_matchers; + AlgebraicSimplifier::ValidBitcastCallback bitcasting_callback() { return [](const Shape&, const Shape&) { return true; }; } @@ -46,7 +47,7 @@ AlgebraicSimplifier::ValidBitcastCallback non_bitcasting_callback() { return [](const Shape&, const Shape&) { return false; }; } -using AlgebraicSimplifierTest = HloTestBase; +class AlgebraicSimplifierTest : public HloVerifiedTestBase {}; // Test that A + 0 is simplified to A TEST_F(AlgebraicSimplifierTest, AddZero) { @@ -290,6 +291,42 @@ TEST_F(AlgebraicSimplifierTest, DivOfPower) { op::Multiply(param0, op::Power(param1, op::Negate(param2)))); } +// Test that broadcasting is done on the right step when simplifying A/pow(B,C) +// to A*pow(B,-C). +TEST_F(AlgebraicSimplifierTest, DivOfBroadcastingPower) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + Shape r1f32 = ShapeUtil::MakeShape(F32, {7}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r1f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r1f32, "param1")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, r0f32, "param2")); + HloInstruction* power = builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kPower, param1, param2)); + builder.AddInstruction( + HloInstruction::CreateBinary(r1f32, HloOpcode::kDivide, param0, power)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Divide(param0, op::Power(param1, param2))); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + + ASSERT_THAT(computation->root_instruction(), + op::Multiply(param0, op::Power(param1, op::Negate(param2)))); + + const HloInstruction* negate = + computation->root_instruction()->operand(1)->operand(1); + const Shape& negate_shape = negate->shape(); + EXPECT_EQ(0, negate_shape.dimensions_size()); +} + // Test that A/1 is simplified to A for a scalar. TEST_F(AlgebraicSimplifierTest, DivOneScalar) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -334,6 +371,56 @@ TEST_F(AlgebraicSimplifierTest, DivOneArray) { EXPECT_EQ(root, param0); } +// Test that real(complex(r,i)) is simplified to r. +TEST_F(AlgebraicSimplifierTest, RealOfComplex) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r2f32, "param1")); + HloInstruction* cplx = builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64), + HloOpcode::kComplex, param0, param1)); + HloInstruction* real = builder.AddInstruction( + HloInstruction::CreateUnary(r2f32, HloOpcode::kReal, cplx)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, real); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param0); +} + +// Test that imag(complex(r,i)) is simplified to i. +TEST_F(AlgebraicSimplifierTest, ImagOfComplex) { + Shape r2f32 = ShapeUtil::MakeShape(F32, {2, 2}); + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, r2f32, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, r2f32, "param1")); + HloInstruction* cplx = builder.AddInstruction( + HloInstruction::CreateBinary(ShapeUtil::ChangeElementType(r2f32, C64), + HloOpcode::kComplex, param0, param1)); + HloInstruction* imag = builder.AddInstruction( + HloInstruction::CreateUnary(r2f32, HloOpcode::kImag, cplx)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + HloInstruction* root = computation->root_instruction(); + EXPECT_EQ(root, imag); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + root = computation->root_instruction(); + EXPECT_EQ(root, param1); +} + // Test that get_element(make_tuple({A,B}),1) is simplified to B TEST_F(AlgebraicSimplifierTest, SelectMakeTuple) { Shape r0f32 = ShapeUtil::MakeShape(F32, {}); @@ -1014,6 +1101,54 @@ TEST_F(AlgebraicSimplifierTest, ReshapeToScalarNotHoistedAfterEffectiveUnary) { op::Maximum(op::Reshape(param), zero)); } +// Regression test for a bug where if we failed to sink a reshape, we'd set the +// 'changed' bit in AlgebraicSimplifier to false. +TEST_F(AlgebraicSimplifierTest, FailureToSinkReshapeDoesntAffectChangedBit) { + HloComputation::Builder builder(TestName()); + + // This add (param0 + 0) can be simplified. + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, + builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")), + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{0, 0}, {0, 0}}))))); + + builder.AddInstruction( + HloInstruction::CreateReshape(ShapeUtil::MakeShape(F32, {4}), add)); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); +} + +// Regression test for a bug where if we failed to sink a reshape, we'd set the +// 'changed' bit in AlgebraicSimplifier to false. +TEST_F(AlgebraicSimplifierTest, FailureToSinkBroadcastDoesntAffectChangedBit) { + HloComputation::Builder builder(TestName()); + + // This add (param0 + 0) can be simplified. + Shape shape = ShapeUtil::MakeShape(F32, {2, 2}); + HloInstruction* add = builder.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kAdd, + builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "param0")), + builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{0, 0}, {0, 0}}))))); + + builder.AddInstruction(HloInstruction::CreateBroadcast( + ShapeUtil::MakeShape(F32, {2, 2, 2}), add, /*broadcast_dimensions=*/{0})); + + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + bitcasting_callback()); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + EXPECT_TRUE(simplifier.Run(module.get()).ValueOrDie()); +} + TEST_F(AlgebraicSimplifierTest, TransposeEqualsBitcast1) { HloComputation::Builder builder(TestName()); HloInstruction* param = @@ -1467,7 +1602,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { for (int i = 0; i < strlen(options.dim_order); ++i) { char ch = options.dim_order[i]; if (ch == 'N') { - dnums.set_batch_dimension(i); + dnums.set_input_batch_dimension(i); + dnums.set_output_batch_dimension(i); in_dims.push_back(options.in_batch); } else if (ch == 'H') { dnums.set_spatial_dimensions(0, i); @@ -1476,7 +1612,8 @@ TEST_F(AlgebraicSimplifierTest, ConvertConvToMatmul) { dnums.set_spatial_dimensions(1, i); in_dims.push_back(options.in_width); } else if (ch == 'C') { - dnums.set_feature_dimension(i); + dnums.set_input_feature_dimension(i); + dnums.set_output_feature_dimension(i); in_dims.push_back(options.in_channels); in_channel_idx = i; } @@ -1978,7 +2115,7 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) { HloInstruction::CreateConstant(Literal::CreateR1({0.0f}))); HloInstruction* one = call_builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR1({1.0f}))); - builder.AddInstruction( + call_builder.AddInstruction( HloInstruction::CreateCall(r1f32, {zero, one}, dot_computation.get())); auto module = CreateNewModule(); @@ -2009,5 +2146,63 @@ TEST_F(AlgebraicSimplifierTest, ConstantTupleBecomesTupleOfConstants) { op::Tuple(op::Constant(), op::Constant())); } +// A dynamic-slice is trivial if its start indices are all zeroes and the size +// of its input equals the size of its output. In this case, the dynamic slice +// is equal to its input. +TEST_F(AlgebraicSimplifierTest, TrivialDynamicSlice) { + HloComputation::Builder builder(TestName()); + + Shape shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); + builder.AddInstruction(HloInstruction::CreateDynamicSlice( + shape, + builder.AddInstruction( + HloInstruction::CreateParameter(0, shape, "slice_from")), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))), + /*slice_sizes=*/{10, 100, 1000})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), op::Parameter()); +} + +// A dynamic-update-slice is trivial if its start indices are all zeroes and the +// size of its "update" equals the size of its output. In this case, the +// dynamic-update-slice is equal to its update. +TEST_F(AlgebraicSimplifierTest, TrivialDynamicUpdateSlice) { + HloComputation::Builder builder(TestName()); + + Shape full_shape = ShapeUtil::MakeShape(F32, {10, 100, 1000}); + Shape slice_shape = ShapeUtil::MakeShape(F32, {10, 1, 1000}); + + HloInstruction* slice = + builder.AddInstruction(HloInstruction::CreateDynamicSlice( + slice_shape, + builder.AddInstruction( + HloInstruction::CreateParameter(0, full_shape, "slice_from")), + builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(U32, {3}), "slice_indices")), + /*slice_sizes=*/{10, 1, 1000})); + + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + slice_shape, + builder.AddInstruction( + HloInstruction::CreateParameter(2, slice_shape, "to_update")), + slice, + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({0, 0, 0}))))); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + AlgebraicSimplifier simplifier(/*is_layout_sensitive=*/false, + non_bitcasting_callback()); + ASSERT_TRUE(simplifier.Run(module.get()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + op::DynamicSlice(op::Parameter(), op::Parameter())); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc index 41d32d0c8b1cc31522f2c8012fb5350816cadbec..abe881cd1a58a6173b9b93f10a7308d70106c889 100644 --- a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc +++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc @@ -83,11 +83,11 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault { HloComputation* GetScalarBinaryComputation(PrimitiveType primitive_type, HloOpcode opcode) { - HloComputation::Builder b("scalar computation"); + HloComputation::Builder b("scalar_computation"); auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(F32, {}), "scalar lhs")); + 0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs")); auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {}), "scalar rhs")); + 1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs")); auto scalar_op = b.AddInstruction( HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}), opcode, scalar_lhs, scalar_rhs)); @@ -531,16 +531,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad( StatusOr BatchNormRewriter::Run(HloModule* module) { XLA_VLOG_LINES(2, "BatchNormRewriter::Run(), before:\n" + module->ToString()); bool changed = false; - // Make a copy of the computations because we may add computations to the - // module, invalidating iteration. - std::vector computations; - for (auto& comp : module->computations()) { - if (comp->IsFusionComputation()) { - continue; - } - computations.push_back(comp.get()); - } - for (auto& comp : computations) { + for (auto* comp : module->MakeNonfusionComputations()) { if (BatchNormRewriterVisitor::Run(comp, rewrite_training_op_, rewrite_inference_op_, rewrite_grad_op_, use_fusion_)) { diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc index 6bc0ca4f827b44c78336100b5380ac4c86e8df01..b422b22df9cfbefb6611fcb229ed42e67fe3a0d8 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment.cc @@ -101,6 +101,11 @@ BufferAllocationProto BufferAllocation::ToProto() const { proto_assigned->set_offset(buffer_offset_size.second.offset); proto_assigned->set_size(buffer_offset_size.second.size); } + std::sort(proto.mutable_assigned()->begin(), proto.mutable_assigned()->end(), + [](const BufferAllocationProto::Assigned& assign1, + const BufferAllocationProto::Assigned& assign2) { + return assign1.logical_buffer_id() < assign2.logical_buffer_id(); + }); return proto; } @@ -388,10 +393,10 @@ Status BufferAssignment::ComputeSummaryStats() { const std::vector* sequence = liveness_->hlo_ordering().SequentialOrder(*computation); if (sequence != nullptr) { - module_sequence.emplace(computation.get(), *sequence); + module_sequence.emplace(computation, *sequence); } } - if (module_sequence.size() == module_->computations().size()) { + if (module_sequence.size() == module_->computation_count()) { TF_ASSIGN_OR_RETURN( const int64 min_size, MinimumMemoryForSequence(module_sequence, buffer_size_)); @@ -535,7 +540,7 @@ Status GatherComputationsByAllocationType( global_set.insert(computation); } - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { for (HloComputation* subcomputation : instruction->called_computations()) { switch (instruction->opcode()) { @@ -688,13 +693,13 @@ Status BufferAssigner::AssignBuffersForComputation( // Buffers are sorted and assigned to BufferAllocations in decreasing order of // size. std::vector sorted_buffers; - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { // Add all buffers which this instruction defines. Instruction which don't // define buffers (eg, bitcast which just forwards a pointer) don't need // any allocations. for (const LogicalBuffer* buffer : assignment->points_to_analysis().GetBuffersDefinedByInstruction( - instruction.get())) { + instruction)) { sorted_buffers.push_back(buffer); } } @@ -1121,6 +1126,7 @@ void BufferAssigner::AddWhileSetToColocatedBufferSets( // Scan 'colocated_buffer_sets' in reverse order for locality; colocated sets // are added in postorder over computations and instructions. const int64 init_buffer_size = buffer_size(*while_init_buffer); + const bool is_live_out = buffer_liveness.MaybeLiveOut(*while_result_buffer); for (int i = colocated_buffer_sets->size() - 1; i >= 0; --i) { const ColocatedBufferSet& predecessor_set = (*colocated_buffer_sets)[i]; @@ -1141,6 +1147,20 @@ void BufferAssigner::AddWhileSetToColocatedBufferSets( continue; } + // Skip predecessor sets with entry parameter if the while result is live + // out. + if (is_live_out && + std::any_of(predecessor_set.begin(), predecessor_set.end(), + [](const LogicalBuffer* buffer) { + auto* instruction = buffer->instruction(); + auto* computation = instruction->parent(); + auto* module = computation->parent(); + return instruction->opcode() == HloOpcode::kParameter && + computation == module->entry_computation(); + })) { + continue; + } + // Build vector of predecessor while result and init buffers, which are // checked for liveness interference below. We must check both the result // and init buffers because they're aliased together, but diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h index 688aff89125ce3e30be8918a9dfe9f17e22e6243..08a53af8baa3f250919517c87c023c329b129024 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment.h +++ b/tensorflow/compiler/xla/service/buffer_assignment.h @@ -320,6 +320,13 @@ class BufferAssignment { const HloInstruction* hlo_b, const ShapeIndex& shape_index_b) const; + // Returns true if the top-level buffers of hlo_a and hlo_b are the same. + // REQUIRES: HasTopLevelAllocation(hlo_a) && HasTopLevelAllocation(hlo_b). + bool SharesTopLevelSlice(const HloInstruction* hlo_a, + const HloInstruction* hlo_b) const { + return SharesSliceAtIndex(hlo_a, {}, hlo_b, {}); + } + // Returns the underlying points-to analysis used for this assignment. const TuplePointsToAnalysis& points_to_analysis() const { return liveness_->points_to_analysis(); diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc index ca07a02814b43edec90691d1b145357ba4323254..89410f42bd7b5fa8f9b380c868fcd4fedb54576c 100644 --- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc +++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc @@ -1179,7 +1179,7 @@ TEST_F(BufferAssignmentTest, TupleCallAsOutput) { auto assignment = RunBufferAssignment(module.get()); EXPECT_EQ(3, assignment->Allocations().size()); - // Buffers for call are co-located with the sub-computation. + // Buffers for call are colocated with the sub-computation. EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{}), GetAllocation(*assignment, sub_tuple, /*index=*/{})); EXPECT_EQ(GetAllocation(*assignment, call, /*index=*/{0}), @@ -1238,7 +1238,7 @@ TEST_F(BufferAssignmentTest, TupleChainedCallAsOutput) { auto assignment = RunBufferAssignment(module.get()); - // Buffers for call are co-located with the sub-computations. + // Buffers for call are colocated with the sub-computations. EXPECT_EQ(GetAllocation(*assignment, a_call, /*index=*/{}), GetAllocation(*assignment, b_call, /*index=*/{})); EXPECT_EQ(GetAllocation(*assignment, b_call, /*index=*/{}), @@ -1764,5 +1764,62 @@ TEST_F(WhileBufferAssignmentTest, DISABLED_TwoWhiles) { EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment)); } +TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) { + auto module = MakeUnique(TestName()); + auto builder = HloComputation::Builder("entry"); + + auto input0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, data_shape_, "input0")); + auto weights0 = builder.AddInstruction( + HloInstruction::CreateParameter(1, data_shape_, "weights0")); + + auto zero = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0.0))); + auto output0 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + auto output1 = builder.AddInstruction( + HloInstruction::CreateBroadcast(data_shape_, zero, {1})); + + auto cond0 = + module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); + auto body0 = + module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + + auto tuple0 = builder.AddInstruction( + HloInstruction::CreateTuple({input0, weights0, output0})); + auto while0 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond0, body0, tuple0)); + + // Get output of 'while0' and feed as input to 'while1'. + auto while0_out = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, while0, 2)); + + auto cond1 = + module->AddEmbeddedComputation(BuildWhileConditionComputation("cond")); + auto body1 = + module->AddEmbeddedComputation(BuildWhileBodyComputation("body")); + + auto tuple1 = builder.AddInstruction( + HloInstruction::CreateTuple({while0_out, weights0, output1})); + auto while1 = builder.AddInstruction( + HloInstruction::CreateWhile(loop_state_shape_, cond1, body1, tuple1)); + + // Get output of 'while1' so that it is live out of computation. + auto while1_out = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(data_shape_, while1, 2)); + + module->AddEntryComputation(builder.Build()); + RunCopyInsertion(module.get()); + auto assignment = RunBufferAssignment(module.get()); + // Get BufferAllocation for root instruction. + auto* root_alloc = assignment->GetUniqueTopLevelSlice(while1_out) + .ConsumeValueOrDie() + .allocation(); + // Test that root instruction allocation is live out. + EXPECT_TRUE(root_alloc->maybe_live_out()); + // Test that root instruction allocation is not an entry parameter. + EXPECT_FALSE(root_alloc->is_entry_computation_parameter()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/buffer_liveness.cc b/tensorflow/compiler/xla/service/buffer_liveness.cc index 8610080203760d576e10df5b7d1610041c6d9b8e..513bfa3b7f7b45696093d03c1dd8250c548d260a 100644 --- a/tensorflow/compiler/xla/service/buffer_liveness.cc +++ b/tensorflow/compiler/xla/service/buffer_liveness.cc @@ -46,7 +46,7 @@ StatusOr> BufferLiveness::Run( tensorflow::Status BufferLiveness::Analyze() { TF_ASSIGN_OR_RETURN(points_to_analysis_, TuplePointsToAnalysis::Run(module_)); - for (auto& computation : module_->computations()) { + for (auto* computation : module_->computations()) { if (computation->IsFusionComputation()) { continue; } @@ -55,15 +55,15 @@ tensorflow::Status BufferLiveness::Analyze() { // element in other instruction's output. for (const auto& instruction : computation->instructions()) { for (const LogicalBuffer* aliased_buffer : - points_to_analysis_->GetPointsToSet(instruction.get()) + points_to_analysis_->GetPointsToSet(instruction) .CreateFlattenedSet()) { - if (aliased_buffer->instruction() != instruction.get()) { + if (aliased_buffer->instruction() != instruction) { aliased_buffers_.insert(aliased_buffer); } } } - if (computation.get() == module_->entry_computation()) { + if (computation == module_->entry_computation()) { const HloInstruction* root = computation->root_instruction(); maybe_live_out_buffers_ = points_to_analysis_->GetPointsToSet(root).CreateFlattenedSet(); diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc index c0f3bcdc2218199288eaa3d0010ee70632c8f959..1adecdb939cb2c1259003d3be2c90b5a299b0f30 100644 --- a/tensorflow/compiler/xla/service/call_graph.cc +++ b/tensorflow/compiler/xla/service/call_graph.cc @@ -189,9 +189,8 @@ void CallGraph::SetCallContexts() { // Initialize worklist with all roots of the call graph (computations without // callers). - for (const std::unique_ptr& computation : - module_->computations()) { - CallGraphNode& node = GetNode(computation.get()); + for (const HloComputation* computation : module_->computations()) { + CallGraphNode& node = GetNode(computation); if (node.callers().empty()) { node.set_context(CallContext::kSequential); worklist.push(&node); @@ -228,9 +227,8 @@ void CallGraph::SetCallContexts() { } // No node should have a kNone calling context. - for (const std::unique_ptr& computation : - module_->computations()) { - CHECK_NE(GetNode(computation.get()).context(), CallContext::kNone); + for (const HloComputation* computation : module_->computations()) { + CHECK_NE(GetNode(computation).context(), CallContext::kNone); } } @@ -243,27 +241,24 @@ std::unique_ptr CallGraph::Build(const HloModule* module) { XLA_VLOG_LINES(2, module->ToString()); // Construct nodes of the call graph and populate the callsites. - for (const std::unique_ptr& computation : - module->computations()) { + for (HloComputation* computation : module->computations()) { auto it_added = call_graph->node_indices_.insert( - {computation.get(), call_graph->nodes_.size()}); + {computation, call_graph->nodes_.size()}); // All computations should be unique, so the computation should not already // exist in the map. CHECK(it_added.second); - call_graph->nodes_.emplace_back(computation.get()); + call_graph->nodes_.emplace_back(computation); // Add all callsites in this computation. - for (const std::unique_ptr& instruction : - computation->instructions()) { - call_graph->nodes_.back().AddCallSiteForInstruction(instruction.get()); + for (HloInstruction* instruction : computation->instructions()) { + call_graph->nodes_.back().AddCallSiteForInstruction(instruction); } } // Add caller callsites to each node. - for (const std::unique_ptr& computation : - module->computations()) { + for (const HloComputation* computation : module->computations()) { for (const CallSite& callsite : - call_graph->GetNode(computation.get()).callsites()) { + call_graph->GetNode(computation).callsites()) { for (auto* callee : callsite.called_computations()) { // Add caller callsites. call_graph->GetNode(callee).AddCallerCallSite(callsite); diff --git a/tensorflow/compiler/xla/service/call_inliner.cc b/tensorflow/compiler/xla/service/call_inliner.cc index 65472d9ac92416859afc88b43eb150ab8730fc2d..3aa7f5c4d5829ccc0e8df697c1363754128ff436 100644 --- a/tensorflow/compiler/xla/service/call_inliner.cc +++ b/tensorflow/compiler/xla/service/call_inliner.cc @@ -26,8 +26,7 @@ namespace { // Traverses the callee computation, inlining cloned nodes into the caller // computation and connecting them to producers/consumers appropriately. // When the traversal has completed, the provided call instruction is entriely -// replaced in the caller's graph, and any calls encountered in the callee -// computation have been added to the work_queue. +// replaced in the caller's graph. class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { public: // call is the call operation -- it will be replaced with the body of the @@ -79,6 +78,7 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(HloInstruction * new_root, Resolve(root)); VLOG(1) << "Replacing all uses of " << call_->ToString() << " with new root " << new_root->ToString(); + call_->ClearCalledComputations(); return outer_->ReplaceInstruction(call_, new_root); } @@ -114,11 +114,21 @@ class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault { HloComputation* outer_; std::unordered_map subcomputation_hlo_to_new_hlo_; - std::deque* work_queue_; }; } // namespace +/* static */ Status CallInliner::Inline(HloInstruction* call) { + TF_RET_CHECK(call->opcode() == HloOpcode::kCall) + << "Instruction was not a call op: " << call->opcode(); + const auto& callees = call->called_computations(); + TF_RET_CHECK(callees.size() == 1); + HloComputation* callee = callees[0]; + // We visit the callee, cloning its body into its caller. + SubcomputationInsertionVisitor visitor(call); + return callee->Accept(&visitor); +} + StatusOr CallInliner::Run(HloModule* module) { std::unique_ptr call_graph = CallGraph::Build(module); // Because call graph nodes are visited in post-order (callees before callers) @@ -129,13 +139,9 @@ StatusOr CallInliner::Run(HloModule* module) { for (const CallSite& callsite : node.caller_callsites()) { VLOG(1) << "Visiting callsite: " << callsite.ToString(); if (callsite.instruction()->opcode() == HloOpcode::kCall) { + HloInstruction* call = callsite.instruction(); + TF_RETURN_IF_ERROR(Inline(call)); did_mutate = true; - const auto& callees = callsite.called_computations(); - TF_RET_CHECK(callees.size() == 1); - HloComputation* callee = callees[0]; - // We visit the callee, cloning its body into its caller. - SubcomputationInsertionVisitor visitor(callsite.instruction()); - TF_RETURN_IF_ERROR(callee->Accept(&visitor)); } } return Status::OK(); diff --git a/tensorflow/compiler/xla/service/call_inliner.h b/tensorflow/compiler/xla/service/call_inliner.h index 8660200bc405dac93daf8f41443290f4fa089bd8..2dbd38bf1ac90d3efa1453e6af6f791668d5e72a 100644 --- a/tensorflow/compiler/xla/service/call_inliner.h +++ b/tensorflow/compiler/xla/service/call_inliner.h @@ -27,6 +27,9 @@ namespace xla { // called function, and proceed recursively. class CallInliner : public HloPassInterface { public: + // Inlines one call instruction. + static Status Inline(HloInstruction* call); + ~CallInliner() override = default; tensorflow::StringPiece name() const override { return "CallInliner"; } diff --git a/tensorflow/compiler/xla/service/call_inliner_test.cc b/tensorflow/compiler/xla/service/call_inliner_test.cc index f3e7407c544c0443ffb24536d94787f8d12a7bbe..865ed993da121d26ceb61123f1822d93814cbb9b 100644 --- a/tensorflow/compiler/xla/service/call_inliner_test.cc +++ b/tensorflow/compiler/xla/service/call_inliner_test.cc @@ -115,5 +115,54 @@ TEST_F(CallInlinerTest, CallsWithinWhileBodiesAreInlined) { op::Constant()); } +// Check CallInliner::Inline, which inlines a specific call without running the +// whole pass. +TEST_F(CallInlinerTest, InlineWithoutRunningPass) { + const Shape pred = ShapeUtil::MakeShape(PRED, {}); + auto module = CreateNewModule(); + + HloComputation::Builder just_false(TestName() + ".false"); + auto* true_constant = just_false.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({true}))); + auto* false_constant = just_false.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + TF_ASSERT_OK(false_constant->AddControlDependencyTo(true_constant)); + HloComputation* false_computation = + module->AddEmbeddedComputation(just_false.Build()); + + HloComputation::Builder call_false_builder(TestName() + ".call_false"); + HloInstruction* call = call_false_builder.AddInstruction( + HloInstruction::CreateCall(pred, {}, false_computation)); + auto computation = module->AddEntryComputation(call_false_builder.Build()); + + TF_ASSERT_OK(CallInliner::Inline(call)); + EXPECT_THAT(computation->root_instruction(), op::Constant()); + EXPECT_THAT(computation->root_instruction()->control_successors(), + ElementsAre(op::Constant())); +} + +TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) { + const Shape f32 = ShapeUtil::MakeShape(F32, {}); + auto module = CreateNewModule(); + + HloComputation::Builder outfeeder(TestName() + ".outfeeder"); + auto value = outfeeder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42.0))); + outfeeder.AddInstruction( + HloInstruction::CreateOutfeed(f32, value, /*outfeed_config=*/"")); + + auto outfeed_computation = module->AddEmbeddedComputation(outfeeder.Build()); + + HloComputation::Builder outer(TestName() + ".outer"); + outer.AddInstruction(HloInstruction::CreateCall( + ShapeUtil::MakeNil(), /*operands=*/{}, outfeed_computation)); + + module->AddEntryComputation(outer.Build()); + + CallInliner call_inliner; + TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get())); + ASSERT_TRUE(mutated); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/channel_tracker.cc b/tensorflow/compiler/xla/service/channel_tracker.cc index b3784c36ff68a175c2f85b6e49419cc090766b80..a5b392cbc33c12c3255f3c06e9842fc116e672e5 100644 --- a/tensorflow/compiler/xla/service/channel_tracker.cc +++ b/tensorflow/compiler/xla/service/channel_tracker.cc @@ -69,7 +69,10 @@ Status ChannelTracker::RegisterSendInternal(const ChannelHandle& handle) { } Channel& channel = opaque_to_channel_[handle.handle()]; if (channel.has_sender) { - return FailedPrecondition("channel handle is already used by a sender"); + return FailedPrecondition( + "when registering send, passed a channel handle that is already used " + "by a sender: %lld", + handle.handle()); } channel.has_sender = true; return Status::OK(); @@ -82,7 +85,10 @@ Status ChannelTracker::RegisterRecvInternal(const ChannelHandle& handle) { Channel& channel = opaque_to_channel_[handle.handle()]; // TODO(b/33942691): Allow more than 1 receivers for broadcast. if (channel.receiver_count >= 1) { - return FailedPrecondition("channel handle is already used by a receiver"); + return FailedPrecondition( + "when registering recv, passed a channel handle that is already used " + "by a receiver: %lld", + handle.handle()); } channel.receiver_count += 1; return Status::OK(); diff --git a/tensorflow/compiler/xla/service/compile_only_service.cc b/tensorflow/compiler/xla/service/compile_only_service.cc index c95670b1954bada51488a8b3722ca911b98b69a2..9e96898d9b4215e67c8686d372e4b4e6edd1d88b 100644 --- a/tensorflow/compiler/xla/service/compile_only_service.cc +++ b/tensorflow/compiler/xla/service/compile_only_service.cc @@ -101,8 +101,7 @@ CompileOnlyService::CompileAheadOfTime( TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, CreateModuleConfig(*program_shape, instance.argument_layouts, - &execution_options, - /*has_hybrid_result=*/false)); + &execution_options)); TF_ASSIGN_OR_RETURN(std::unique_ptr hlo_module, computation_tracker_.BuildHloModule( diff --git a/tensorflow/compiler/xla/service/compiler.cc b/tensorflow/compiler/xla/service/compiler.cc index f71b2b6b9c65c63e6ca211004b1df5cc39aef5fa..3b1900428af1863c73efe67c27061d979557b3a4 100644 --- a/tensorflow/compiler/xla/service/compiler.cc +++ b/tensorflow/compiler/xla/service/compiler.cc @@ -58,7 +58,8 @@ Compiler::GetPlatformCompilers() { LazyInitMutex(); tensorflow::mutex_lock lock(*platform_compiler_mutex_); auto* factories = GetPlatformCompilerFactories(); - CHECK(factories->find(platform_id) == factories->end()); + CHECK(factories->find(platform_id) == factories->end()) + << "Compiler factory already registered for platform"; (*factories)[platform_id] = std::move(compiler_factory); } diff --git a/tensorflow/compiler/xla/service/compiler.h b/tensorflow/compiler/xla/service/compiler.h index d5bd9214be44f4abd5f672168335ae1a259c9118..4c2d9600d909e82dcb62f508a10445c08c1cdee6 100644 --- a/tensorflow/compiler/xla/service/compiler.h +++ b/tensorflow/compiler/xla/service/compiler.h @@ -114,7 +114,8 @@ class Compiler { // sequence of executable objects. virtual StatusOr>> Compile( std::vector> modules, - std::vector stream_exec) = 0; + std::vector> + stream_exec) = 0; // Compiles the HLO module for ahead-of-time execution. This is intended for // use in static compilation. diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 628f729e0b4388cf258ec7f393f14c48042c1e3e..0453a698a09b740d68b35258ede7c537fcf290d4 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -532,11 +532,11 @@ StatusOr CopyInsertion::Run(HloModule* module) { // Gather all while body computations and while instructions. FlatSet while_body_computations; std::vector while_instructions; - for (auto& computation : module->computations()) { - for (auto& instruction : computation->instructions()) { + for (auto* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kWhile) { while_body_computations.insert(instruction->while_body()); - while_instructions.push_back(instruction.get()); + while_instructions.push_back(instruction); } } } @@ -546,14 +546,11 @@ StatusOr CopyInsertion::Run(HloModule* module) { // Add copies of computation root instructions, if needed. FlatMap> while_body_read_only_indices; - for (auto& computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } + for (auto* computation : module->MakeNonfusionComputations()) { VLOG(2) << "computation " << computation->name(); InstructionCopier root_copier(computation->root_instruction(), /*copy_users=*/{}); - if (while_body_computations.count(computation.get()) > 0) { + if (while_body_computations.count(computation) > 0) { // Record root indices to copy for while body sub-computations. We do not // need to call RecordIndicesWhichPointToParamOrConstant for the while // body root instruction here, because any necessary copies needed to @@ -563,7 +560,7 @@ StatusOr CopyInsertion::Run(HloModule* module) { ShapeTree read_only_indices(while_body_param->shape()); TF_RETURN_IF_ERROR(root_copier.RecordIndicesToCopyForColocatingBuffers( *liveness, while_body_param, &read_only_indices)); - while_body_read_only_indices[computation.get()] = read_only_indices; + while_body_read_only_indices[computation] = read_only_indices; // Mark control predecessors, based on the body param, for any copies // we'll be inserting. This ensures the copy doesn't run too early. diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index d7a363b8783b5c86ce3b4edeedc5d4a3dbd2d159..6213baee2fa5c4af7c650d0be4af619deba2709a 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -27,6 +27,50 @@ filegroup( ]), ) +cc_library( + name = "cpu_transfer_manager", + srcs = ["cpu_transfer_manager.cc"], + hdrs = ["cpu_transfer_manager.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:generic_transfer_manager", + "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/compiler/xla/service/cpu:cpu_runtime", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], + alwayslink = True, # Contains per-platform transfer manager registration +) + +cc_library( + name = "external_constant_pool", + srcs = ["external_constant_pool.cc"], + hdrs = ["external_constant_pool.h"], + deps = [ + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:util", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "external_constant_pool_test", + srcs = ["external_constant_pool_test.cc"], + deps = [ + ":external_constant_pool", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:test", + ], +) + cc_library( name = "cpu_compiler", srcs = ["cpu_compiler.cc"], @@ -43,6 +87,7 @@ cc_library( ":ir_emitter", ":layout_assignment", ":parallel_cpu_executable", + ":parallel_task_assignment", ":simple_orc_jit", "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:protobuf_util", @@ -76,6 +121,8 @@ cc_library( "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", + "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", # fixdeps: keep "//tensorflow/core:lib", # fixdeps: keep "//tensorflow/core:stream_executor_no_cuda", @@ -100,15 +147,17 @@ cc_library( name = "simple_orc_jit", srcs = ["simple_orc_jit.cc"], hdrs = ["simple_orc_jit.h"], - linkopts = ["-ldl"], deps = [ ":compiler_functor", ":cpu_runtime", ":cpu_runtime_avx", ":cpu_runtime_neon", ":cpu_runtime_sse4_1", + ":custom_call_target_registry", ":disassembler", + ":external_constant_pool", ":runtime_conv2d", + ":runtime_fork_join", ":runtime_matmul", ":runtime_single_threaded_conv2d", ":runtime_single_threaded_matmul", @@ -195,7 +244,9 @@ cc_library( ":cpu_options", ":cpu_runtime", ":dot_op_emitter", + ":external_constant_pool", ":ir_emission_utils", + ":shape_partition", ":simple_orc_jit", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", @@ -216,6 +267,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ops", + "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "@llvm//:core", "@llvm//:support", @@ -457,9 +509,24 @@ cc_library( ], ) +cc_library( + name = "runtime_fork_join", + srcs = ["runtime_fork_join.cc"], + hdrs = ["runtime_fork_join.h"], + copts = runtime_copts(), + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/compiler/xla:executable_run_options", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//third_party/eigen3", + ], +) + tf_cc_test( name = "cpu_runtime_test", srcs = ["cpu_runtime_test.cc"], + tags = ["optonly"], deps = [ ":cpu_runtime", ":runtime_matmul", @@ -521,6 +588,7 @@ cc_library( ], deps = [ ":ir_emission_utils", + ":parallel_task_assignment", ":shape_partition", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", @@ -630,6 +698,19 @@ tf_cc_test( ], ) +cc_library( + name = "parallel_task_assignment", + srcs = ["parallel_task_assignment.cc"], + hdrs = ["parallel_task_assignment.h"], + deps = [ + ":ir_emission_utils", + ":shape_partition", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_cost_analysis", + "//tensorflow/compiler/xla/service:hlo_pass", + ], +) + cc_library( name = "cpu_options", srcs = ["cpu_options.cc"], @@ -639,6 +720,17 @@ cc_library( ], ) +cc_library( + name = "custom_call_target_registry", + srcs = [ + "custom_call_target_registry.cc", + ], + hdrs = [ + "custom_call_target_registry.h", + ], + visibility = ["//visibility:public"], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc index 069979c6611e90ed2d95cbbe341198577cdf56cf..44cd2171afdc6eecc22f3f920276a4d95f930573 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization.cc @@ -36,8 +36,8 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { !PotentiallyImplementedAsEigenConvolution(*hlo)) { const ConvolutionDimensionNumbers& dnums = hlo->convolution_dimension_numbers(); - auto batch_dim = dnums.batch_dimension(); - auto feature_dim = dnums.feature_dimension(); + auto input_batch_dim = dnums.input_batch_dimension(); + auto input_feature_dim = dnums.input_feature_dimension(); auto kernel_input_feature_dim = dnums.kernel_input_feature_dimension(); auto kernel_output_feature_dim = dnums.kernel_output_feature_dimension(); @@ -59,15 +59,16 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { std::vector new_input_dim_order(num_dims); std::vector new_input_dims(num_dims); - new_input_dim_order[0] = batch_dim; - new_input_dims[0] = input->shape().dimensions(batch_dim); + new_input_dim_order[0] = input_batch_dim; + new_input_dims[0] = input->shape().dimensions(input_batch_dim); for (int i = 0; i < num_spatial_dims; ++i) { new_input_dim_order[i + 1] = dnums.spatial_dimensions(i); new_input_dims[i + 1] = input->shape().dimensions(dnums.spatial_dimensions(i)); } - new_input_dim_order[num_dims - 1] = feature_dim; - new_input_dims[num_dims - 1] = input->shape().dimensions(feature_dim); + new_input_dim_order[num_dims - 1] = input_feature_dim; + new_input_dims[num_dims - 1] = + input->shape().dimensions(input_feature_dim); Shape new_input_shape = ShapeUtil::MakeShape(input->shape().element_type(), new_input_dims); @@ -98,22 +99,26 @@ StatusOr ConvCanonicalization::Run(HloModule* module) { new_kernel_dim_order)); std::vector new_conv_dims(num_dims); - new_conv_dims[0] = hlo->shape().dimensions(batch_dim); + auto output_batch_dim = dnums.output_batch_dimension(); + auto output_feature_dim = dnums.output_feature_dimension(); + new_conv_dims[0] = hlo->shape().dimensions(output_batch_dim); for (int i = 0; i < num_spatial_dims; ++i) { new_conv_dims[i + 1] = hlo->shape().dimensions(dnums.spatial_dimensions(i)); } - new_conv_dims[num_dims - 1] = hlo->shape().dimensions(feature_dim); + new_conv_dims[num_dims - 1] = hlo->shape().dimensions(output_feature_dim); Shape new_conv_shape = ShapeUtil::MakeShape(hlo->shape().element_type(), new_conv_dims); ConvolutionDimensionNumbers new_dnums; - new_dnums.set_batch_dimension(0); + new_dnums.set_input_batch_dimension(0); + new_dnums.set_output_batch_dimension(0); for (int i = 0; i < num_spatial_dims; ++i) { new_dnums.add_spatial_dimensions(i + 1); new_dnums.add_kernel_spatial_dimensions(i); } - new_dnums.set_feature_dimension(num_dims - 1); + new_dnums.set_input_feature_dimension(num_dims - 1); + new_dnums.set_output_feature_dimension(num_dims - 1); new_dnums.set_kernel_input_feature_dimension(num_dims - 2); new_dnums.set_kernel_output_feature_dimension(num_dims - 1); diff --git a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc index 9e8b785f30559f493bcec546e0612f2290af031d..d593ba26b655d00a0f0f0b9a94c9e62fa1835080 100644 --- a/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc +++ b/tensorflow/compiler/xla/service/cpu/conv_canonicalization_test.cc @@ -67,10 +67,12 @@ TEST_F(ConvCanonicalizationTest, NonCanonicalToCanonical) { kOutputFeatureCount, kInputFeatureCount, kWindowSize, kWindowSize)))); ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(1); + dnums.set_input_batch_dimension(1); + dnums.set_output_batch_dimension(1); dnums.add_spatial_dimensions(2); dnums.add_spatial_dimensions(3); - dnums.set_feature_dimension(0); + dnums.set_input_feature_dimension(0); + dnums.set_output_feature_dimension(0); dnums.add_kernel_spatial_dimensions(2); dnums.add_kernel_spatial_dimensions(3); dnums.set_kernel_input_feature_dimension(1); @@ -121,10 +123,12 @@ TEST_F(ConvCanonicalizationTest, CanonicalStaysTheSame) { kWindowSize, kWindowSize, kInputFeatureCount, kOutputFeatureCount)))); ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); dnums.add_kernel_spatial_dimensions(0); dnums.add_kernel_spatial_dimensions(1); dnums.set_kernel_input_feature_dimension(2); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc index 5b90b6b7f0d88b55430af837f7a6da580ed14d88..487ea003be643a9c7d48dc2d0037ba6a0ae498dd 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc @@ -58,6 +58,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/ir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/layout_assignment.h" #include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h" +#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/flatten_call_graph.h" @@ -80,15 +81,15 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/lib/strings/strcat.h" -#include "tensorflow/core/platform/env.h" namespace se = ::perftools::gputools; @@ -222,14 +223,9 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { } // Skip constants, there is nothing to profile. - Status HandleConstant(HloInstruction* /*constant*/, - const Literal& /*literal*/) override { - return Status::OK(); - } + Status HandleConstant(HloInstruction*) override { return Status::OK(); } // Skip parameters, they are a simple load. - Status HandleParameter(HloInstruction* /*parameter*/) override { - return Status::OK(); - } + Status HandleParameter(HloInstruction*) override { return Status::OK(); } // It is important to recurse for "while" or else we risk overly coarse // profiling information. Status HandleWhile(HloInstruction* xla_while) override { @@ -249,7 +245,7 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault { }; } // namespace -Status CpuCompiler::RunHloPasses(HloModule* module) { +Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) { // Optimization pipeline. HloPassPipeline pipeline("CPU"); pipeline.AddInvariantChecker(ShapeSizeBytesFunction()); @@ -270,6 +266,8 @@ Status CpuCompiler::RunHloPasses(HloModule* module) { { auto& pass = pipeline.AddPass>("simplification"); + pass.AddInvariantChecker(ShapeSizeBytesFunction()); + pass.AddPass( /*rewrite_training_op=*/true, /*rewrite_inference_op=*/true, @@ -279,6 +277,9 @@ Status CpuCompiler::RunHloPasses(HloModule* module) { /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }, /*enable_dot_simplification=*/false); + pass.AddPass(); + pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); } @@ -314,6 +315,14 @@ Status CpuCompiler::RunHloPasses(HloModule* module) { if (options::CpuParallelBackendRequested(module->config())) { pipeline.AddPass(max_parallelism, ShapeSizeBytesFunction()); + } else if (!is_aot_compile) { + // Run ParallelTaskAssigner to assign parallel tasks to HLOs in module. + // Note this is not run for AOT because it would bring in thread pool + // and thread synchronization dependencies which would likely increase + // binary size (and most AOT applications are single-threaded). + // TODO(29630486) Support multi-threaded AOT. + pipeline.AddPass(max_parallelism, + ShapeSizeBytesFunction(), module); } // Copy insertion should be performed immediately before IR emission to avoid // inserting unnecessary copies (later pass adds an instruction which @@ -365,68 +374,50 @@ llvm::CodeGenOpt::Level CodeGenOptLevel(const HloModuleConfig& module_config) { } } -Status AppendIRToFile(const string& file_name, const string& ir_module_string) { - std::unique_ptr f; - TF_RETURN_IF_ERROR( - tensorflow::Env::Default()->NewWritableFile(file_name, &f)); - TF_RETURN_IF_ERROR(f->Append(ir_module_string)); - TF_RETURN_IF_ERROR(f->Close()); - return Status::OK(); -} - Status InitializeModuleHooks( - const HloModule& module, + const HloModule& hlo_module, const LLVMCompiler::ModuleHook& user_pre_optimization_hook, const LLVMCompiler::ModuleHook& user_post_optimization_hook, LLVMCompiler::ModuleHook* pre_optimization_ir_hook, LLVMCompiler::ModuleHook* post_optimization_ir_hook) { - const string& dump_ir_to = module.config().debug_options().xla_dump_ir_to(); - if (dump_ir_to.empty()) { + const string& ir_dump_directory = + hlo_module.config().debug_options().xla_dump_ir_to(); + if (ir_dump_directory.empty()) { *pre_optimization_ir_hook = user_pre_optimization_hook; *post_optimization_ir_hook = user_post_optimization_hook; return Status::OK(); } - // Initialize the output directory and create the output file names. - TF_RETURN_IF_ERROR( - tensorflow::Env::Default()->RecursivelyCreateDir(dump_ir_to)); - string safe_file_name_base = module.name(); - std::replace_if(safe_file_name_base.begin(), safe_file_name_base.end(), - [](char c) { return c == '/' || c == '\\'; }, '_'); - - string unoptimized_ir_file_name = tensorflow::io::JoinPath( - dump_ir_to, - tensorflow::strings::StrCat("ir-", safe_file_name_base, "-no-opt.ll")); - string optimized_ir_file_name = tensorflow::io::JoinPath( - dump_ir_to, - tensorflow::strings::StrCat("ir-", safe_file_name_base, "-opt.ll")); + const string& hlo_module_name = hlo_module.name(); // Create the IR hooks. If applicable, each IR hook does the following: - // * Call the user supplied module hook. - // * Write to the output directory. Files will be appended to. We still want - // to append to avoid overwriting possibly important information due to - // operator error. + // + // * Calls the user supplied module hook. + // * Writes out the IR to a file in the output directory designated by + // --xla_dump_ir_to *pre_optimization_ir_hook = - [user_pre_optimization_hook, - unoptimized_ir_file_name](const llvm::Module& module) { + [user_pre_optimization_hook, ir_dump_directory, + hlo_module_name](const llvm::Module& llvm_module) { if (user_pre_optimization_hook) { - TF_RETURN_IF_ERROR(user_pre_optimization_hook(module)); + TF_RETURN_IF_ERROR(user_pre_optimization_hook(llvm_module)); } - TF_RETURN_IF_ERROR(AppendIRToFile(unoptimized_ir_file_name, - llvm_ir::DumpModuleToString(module))); - return Status::OK(); + return llvm_ir::DumpIRToDirectory(/*directory_name=*/ir_dump_directory, + /*hlo_module_name=*/hlo_module_name, + llvm_module, + /*optimized=*/false); }; *post_optimization_ir_hook = - [user_post_optimization_hook, - optimized_ir_file_name](const llvm::Module& module) { + [user_post_optimization_hook, ir_dump_directory, + hlo_module_name](const llvm::Module& llvm_module) { if (user_post_optimization_hook) { - TF_RETURN_IF_ERROR(user_post_optimization_hook(module)); + TF_RETURN_IF_ERROR(user_post_optimization_hook(llvm_module)); } - TF_RETURN_IF_ERROR(AppendIRToFile(optimized_ir_file_name, - llvm_ir::DumpModuleToString(module))); - return Status::OK(); + return llvm_ir::DumpIRToDirectory(/*directory_name=*/ir_dump_directory, + /*hlo_module_name=*/hlo_module_name, + llvm_module, + /*optimized=*/true); }; return Status::OK(); @@ -466,7 +457,13 @@ StatusOr> CpuCompiler::Compile( llvm_module->setDataLayout(jit->data_layout()); llvm_module->setTargetTriple(jit->target_triple().getTriple()); - TF_RETURN_IF_ERROR(RunHloPasses(module.get())); + VLOG(2) << "Before optimization:"; + XLA_VLOG_LINES(2, module->ToString()); + + TF_RETURN_IF_ERROR(RunHloPasses(module.get(), /*is_aot_compile=*/false)); + + VLOG(2) << "After optimization:"; + XLA_VLOG_LINES(2, module->ToString()); HloComputation* computation = module->entry_computation(); std::unordered_map hlo_to_profile_idx; @@ -482,8 +479,8 @@ StatusOr> CpuCompiler::Compile( // ownership is std::moved. const bool embed_ir_in_executable = module->config().debug_options().xla_embed_ir_in_executable(); - const string dump_debug_json_to = - module->config().debug_options().xla_dump_debug_json_to(); + const string xla_dump_hlo_proto_to = + module->config().debug_options().xla_dump_hlo_proto_to(); if (options::CpuParallelBackendRequested(module->config())) { VLOG(1) << "Using parallel cpu backend"; @@ -503,10 +500,10 @@ StatusOr> CpuCompiler::Compile( // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); - if (!dump_debug_json_to.empty()) { + if (!xla_dump_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, dump_debug_json_to, module->name())); + TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( + proto, xla_dump_hlo_proto_to, module->name())); } // If we are using the parallel CPU backend, we need to create map from @@ -540,10 +537,11 @@ StatusOr> CpuCompiler::Compile( } IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - &hlo_to_profile_idx, jit->target_machine()); + &hlo_to_profile_idx, jit->target_machine(), + jit->external_constant_pool()); - std::unique_ptr> function_names( - new std::map()); + std::unique_ptr> function_names( + new HloInstructionMap()); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { if (embedded_computation->IsFusionComputation()) { @@ -609,18 +607,18 @@ StatusOr> CpuCompiler::Compile( // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); - if (!dump_debug_json_to.empty()) { + if (!xla_dump_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, dump_debug_json_to, module->name())); + TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( + proto, xla_dump_hlo_proto_to, module->name())); } - // Each computation is a single function. Emit all embedded computations // before the entry computation. The order of computations returned from // GetEmbeddedComputations guarantees that a called computation occurs // before a caller computation. IrEmitter ir_emitter(*module, *assignment, llvm_module.get(), - &hlo_to_profile_idx, jit->target_machine()); + &hlo_to_profile_idx, jit->target_machine(), + jit->external_constant_pool()); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { @@ -667,7 +665,7 @@ StatusOr> CpuCompiler::Compile( StatusOr>> CpuCompiler::Compile( std::vector> modules, - std::vector stream_execs) { + std::vector> stream_execs) { return Unimplemented( "Compilation of multiple HLO modules is not yet supported on CPU."); } @@ -763,7 +761,13 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, HloModule* module = modules[i].get(); VLOG(1) << "Compiling ahead-of-time: " << module->name(); - TF_RETURN_IF_ERROR(RunHloPasses(module)); + VLOG(2) << "Before optimization:"; + XLA_VLOG_LINES(2, module->ToString()); + + TF_RETURN_IF_ERROR(RunHloPasses(module, /*is_aot_compile=*/true)); + + VLOG(2) << "After optimization:"; + XLA_VLOG_LINES(2, module->ToString()); TF_ASSIGN_OR_RETURN( SequentialHloOrdering::HloModuleSequence module_sequence, @@ -780,16 +784,17 @@ CpuCompiler::CompileAheadOfTime(std::vector> modules, // print one ourselves. XLA_VLOG_LINES(2, assignment->ToString()); - const string dump_debug_json_to = - module->config().debug_options().xla_dump_debug_json_to(); - if (!dump_debug_json_to.empty()) { + const string xla_dump_hlo_proto_to = + module->config().debug_options().xla_dump_hlo_proto_to(); + if (!xla_dump_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, dump_debug_json_to, module->name())); + TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( + proto, xla_dump_hlo_proto_to, module->name())); } IrEmitter ir_emitter(*module, *assignment, &llvm_module, - /*hlo_to_profile_idx=*/nullptr, target_machine.get()); + /*hlo_to_profile_idx=*/nullptr, target_machine.get(), + /*external_constant_pool=*/nullptr); HloComputation* computation = module->entry_computation(); for (auto embedded_computation : computation->MakeEmbeddedComputationsList()) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h index bd3541500dae9d9d59c56bfb062912a1b85c2219..d09130247421b11d6d4879466f39b89167eb9564 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.h @@ -115,7 +115,8 @@ class CpuCompiler : public LLVMCompiler { StatusOr>> Compile( std::vector> modules, - std::vector stream_exec) override; + std::vector> + stream_execs) override; StatusOr>> CompileAheadOfTime(std::vector> modules, @@ -131,7 +132,7 @@ class CpuCompiler : public LLVMCompiler { // Runs the HLO passes which are necessary for both optimizations and // correctness. - Status RunHloPasses(HloModule* module); + Status RunHloPasses(HloModule* module, bool is_aot_compile); TF_DISALLOW_COPY_AND_ASSIGN(CpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc index 6cc1d65c7afe50cbe7ee84a3d8c5cbfee1993f9a..f62353bee7b1058dc237169b70341c33ab19fc52 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc @@ -51,8 +51,9 @@ namespace cpu { CpuExecutable::CpuExecutable( std::unique_ptr jit, - std::unique_ptr assignment, - std::unique_ptr hlo_module, const string& entry_function_name, + std::unique_ptr assignment, + std::unique_ptr hlo_module, + const string& entry_function_name, std::unordered_map hlo_to_profile_idx) : Executable(std::move(hlo_module)), jit_(std::move(jit)), @@ -147,7 +148,6 @@ Status CpuExecutable::ExecuteComputeFunction( HloExecutionProfile* hlo_execution_profile) { std::vector argument_buffers; for (int i = 0; i < arguments.size(); ++i) { - TF_RET_CHECK(!ShapeUtil::IsTuple(arguments[i]->shape())); argument_buffers.push_back(arguments[i]->buffer(/*index=*/{})); } return ExecuteComputeFunction(run_options, argument_buffers, buffers, @@ -234,7 +234,7 @@ Status CpuExecutable::ExecuteComputeFunction( for (auto hlo_prof_idx : hlo_to_profile_idx_) { const HloInstruction* hlo = hlo_prof_idx.first; uint64 cycles_taken = profile_counters[hlo_prof_idx.second]; - hlo_execution_profile->AddProfileResult(hlo, cycles_taken); + hlo_execution_profile->SetCyclesTakenBy(hlo, cycles_taken); } } return Status::OK(); @@ -298,10 +298,10 @@ StatusOr> CpuExecutable::ExecuteOnStream( DeviceMemoryAllocator* memory_allocator = run_options->allocator(); std::vector buffers(assignment_->Allocations().size()); - TF_ASSIGN_OR_RETURN(std::unique_ptr result_buffer, - ShapedBuffer::MakeShapedBuffer( - result_shape(), stream->parent()->platform(), - stream->parent()->device_ordinal())); + auto result_buffer = + MakeUnique(result_shape(), stream->parent()->platform(), + stream->parent()->device_ordinal()); + TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); TF_RETURN_IF_ERROR(ExecuteComputeFunction( @@ -315,32 +315,29 @@ StatusOr> CpuExecutable::ExecuteOnStream( ->ForEachMutableElementWithStatus( [&buffers, &buffers_in_result, &result_buffer, this]( const ShapeIndex& index, size_t* buffer_entry) { - if (ShapeUtil::IsLeafIndex(result_buffer->shape(), index)) { - const auto& sources = - this->GetRootPointsToSet().element(index); - // The points to set is unambiguous so the set should be a - // singleton. - CHECK_EQ(1, sources.size()); - const LogicalBuffer* buffer_source = sources[0]; - HloInstruction* src = buffer_source->instruction(); - - // The source for this result buffer can be a nested buffer - // such as a tuple element. - - // The source instruction should have a non-parameter buffer - // assigned. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice( - src, buffer_source->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - - const BufferAllocation::Index buffer_index = slice.index(); - const se::DeviceMemoryBase& buffer = buffers[buffer_index]; - CHECK(!buffer.is_null() || buffer.size() == 0); - *buffer_entry = result_buffer->mutable_buffers()->size(); - result_buffer->mutable_buffers()->push_back(buffer); - buffers_in_result[buffer_index] = true; - } + const auto& sources = this->GetRootPointsToSet().element(index); + // The points to set is unambiguous so the set should be a + // singleton. + CHECK_EQ(1, sources.size()); + const LogicalBuffer* buffer_source = sources[0]; + HloInstruction* src = buffer_source->instruction(); + + // The source for this result buffer can be a nested buffer + // such as a tuple element. + + // The source instruction should have a non-parameter buffer + // assigned. + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice( + src, buffer_source->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); + + const BufferAllocation::Index buffer_index = slice.index(); + const se::DeviceMemoryBase& buffer = buffers[buffer_index]; + CHECK(!buffer.is_null() || buffer.size() == 0); + *buffer_entry = result_buffer->mutable_buffers()->size(); + result_buffer->mutable_buffers()->push_back(buffer); + buffers_in_result[buffer_index] = true; return Status::OK(); })); diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h index a64537eaa3e3baefaefcc618ac971b7559badd94..238bc9b46ae2bf1b519eaf137d9ae063e769bd2e 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h @@ -49,8 +49,8 @@ class CpuExecutable : public Executable { public: CpuExecutable( std::unique_ptr jit, - std::unique_ptr assignment, - std::unique_ptr hlo_module, + std::unique_ptr assignment, + std::unique_ptr hlo_module, const string& entry_function_name, std::unordered_map hlo_to_profile_idx); ~CpuExecutable() override {} @@ -87,6 +87,17 @@ class CpuExecutable : public Executable { std::unique_ptr CreateCostAnalysis() const override; + // Type of the computation function we expect in the JIT. + using ComputeFunctionType = void (*)( + void* /*result*/, const ExecutableRunOptions* /*run_options*/, + const void** /*args*/, void** /*temps*/, uint64* /*profile_counters*/); + + const ComputeFunctionType& compute_function() const { + return compute_function_; + } + + const BufferAssignment& buffer_assignment() const { return *assignment_; } + private: // Allocate buffers required for execution and assign them to the elements of // "buffers". "buffers" should be sized to the number of buffers in buffer @@ -118,10 +129,10 @@ class CpuExecutable : public Executable { const PointsToSet& GetRootPointsToSet() const; // The JIT containing compiled modules. - std::unique_ptr jit_; + const std::unique_ptr jit_; // Buffer assignment for the buffers we need to allocate. - std::unique_ptr assignment_; + const std::unique_ptr assignment_; // The LLVM IR, in string format, of the unoptimized module generated for this // CpuExecutable. We save a string instead of an llvm::Module* because leaving @@ -129,11 +140,6 @@ class CpuExecutable : public Executable { // positives. string ir_module_string_; - // Type of the computation function we expect in the JIT. - // void function(void* result, const void* run_options, - // const void** args_array, void** temps_array) - using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, - uint64*); ComputeFunctionType compute_function_; // Entry function name for the computation. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc index e23fd3d35807a9a4e7571d62b06474b2c2fad733..f87ee3cecd932faac140636a3db7cd4aa0371b85 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.cc @@ -29,13 +29,17 @@ int64 BytesInDimension(const Shape& shape, int64 dimension) { bool IsFusile(const HloInstruction& hlo) { // These are the only ones we fuse since we rely on effective elemental IR // generation. - return (hlo.opcode() == HloOpcode::kBroadcast || - hlo.opcode() == HloOpcode::kReshape || - hlo.opcode() == HloOpcode::kBitcast || - hlo.opcode() == HloOpcode::kReverse || - hlo.opcode() == HloOpcode::kSlice || - hlo.opcode() == HloOpcode::kDynamicSlice || - hlo.opcode() == HloOpcode::kTranspose || hlo.IsElementwise()); + return hlo.IsElementwise() || // + hlo.opcode() == HloOpcode::kBitcast || + hlo.opcode() == HloOpcode::kBroadcast || + hlo.opcode() == HloOpcode::kConcatenate || + hlo.opcode() == HloOpcode::kDynamicSlice || + hlo.opcode() == HloOpcode::kDynamicUpdateSlice || + hlo.opcode() == HloOpcode::kPad || + hlo.opcode() == HloOpcode::kReshape || + hlo.opcode() == HloOpcode::kReverse || + hlo.opcode() == HloOpcode::kSlice || + hlo.opcode() == HloOpcode::kTranspose; } } // namespace @@ -113,15 +117,8 @@ bool CpuInstructionFusion::ShouldFuse(HloInstruction* consumer, return true; } - if (consumer->IsElementwise()) { - VLOG(2) << "Fusing: consumer is elementwise."; - return true; - } - - // TODO(b/66271886): Figure out which consumers should be fused into. At the - // moment, this is ad-hoc. - if (consumer->opcode() == HloOpcode::kDynamicUpdateSlice) { - VLOG(2) << "Fusing: consumer is dynamic-update-slice."; + if (IsFusile(*consumer)) { + VLOG(2) << "Fusing: consumer is elementwise or fusile."; return true; } diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc index 9e40c3b52035219cca3ccadb5fc0b2930b272bcd..b9e4d006d77ae76e33ac51440349400ea4eff118 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc @@ -198,12 +198,10 @@ class OpcodeFusionTest : public InstructionFusionTest { ASSERT_THAT(root, op::Fusion()); EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop); - std::vector fused_opcodes(root->fused_instructions().size()); + std::vector fused_opcodes(root->fused_instruction_count()); std::transform(root->fused_instructions().begin(), root->fused_instructions().end(), fused_opcodes.begin(), - [](const std::unique_ptr& hlo) { - return hlo->opcode(); - }); + [](const HloInstruction* hlo) { return hlo->opcode(); }); EXPECT_EQ( std::multiset(fused_opcodes.begin(), fused_opcodes.end()), @@ -502,6 +500,114 @@ TEST_F(OpcodeFusionTest, DynamicSliceWithDynamicUpdateSlice) { HloOpcode::kParameter, HloOpcode::kParameter}); } +TEST_F(OpcodeFusionTest, MessOfFusileNodes) { + auto module = CreateNewModule(); + HloComputation::Builder builder(TestName()); + + Shape full_shape = ShapeUtil::MakeShape(F32, {4, 100, 10, 100, 50}); + + auto loop_idx = builder.AddInstruction(HloInstruction::CreateReshape( + ShapeUtil::MakeShape(S32, {1}), + builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(S32, {}), "param0")))); + + auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( + 1, ShapeUtil::MakeShape(S32, {1}), "param1")); + auto concat = builder.AddInstruction(HloInstruction::CreateConcatenate( + ShapeUtil::MakeShape(S32, {5}), + {loop_idx, param1, param1, param1, param1}, /*dimension=*/0)); + + auto idx_choice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(S32, {1}), + builder.AddInstruction(HloInstruction::CreateParameter( + 2, ShapeUtil::MakeShape(S32, {4}), "param2")), + loop_idx, + /*slice_sizes=*/{1})); + + PaddingConfig padding_config; + padding_config.add_dimensions()->set_edge_padding_high(4); + auto pad = builder.AddInstruction(HloInstruction::CreatePad( + ShapeUtil::MakeShape(S32, {5}), idx_choice, + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + padding_config)); + + auto slice = builder.AddInstruction(HloInstruction::CreateDynamicSlice( + ShapeUtil::MakeShape(F32, {1, 100, 10, 100, 50}), + builder.AddInstruction(HloInstruction::CreateParameter( + 3, ShapeUtil::MakeShape(F32, {100, 100, 10, 100, 50}), "param3")), + pad, /*slice_sizes=*/{1, 100, 10, 100, 50})); + + builder.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( + full_shape, + builder.AddInstruction( + HloInstruction::CreateParameter(4, full_shape, "param4")), + slice, concat)); + + module->AddEntryComputation(builder.Build()); + RunFusionAndCheckOpcodesWereFused( + module.get(), + {HloOpcode::kConcatenate, HloOpcode::kPad, HloOpcode::kDynamicSlice, + HloOpcode::kDynamicSlice, HloOpcode::kDynamicUpdateSlice, + HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter, + HloOpcode::kParameter, HloOpcode::kParameter, HloOpcode::kParameter}); +} + +// Tests that we do not fuse instructions in cases where instructions in the +// fusion would reuse elements from its operand due to an implicit broadcast. +TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastUnary) { + Shape small_shape = ShapeUtil::MakeShape(F32, {1, 4}); + Shape large_shape = ShapeUtil::MakeShape(F32, {3, 4}); + + HloComputation::Builder builder(TestName()); + + HloInstruction* small_param = + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, small_shape, "param")); + HloInstruction* small_exp = builder.AddInstruction( + HloInstruction::CreateUnary(small_shape, HloOpcode::kExp, small_param)); + builder.AddInstruction( + HloInstruction::CreateUnary(large_shape, HloOpcode::kExp, small_exp)); + + std::unique_ptr module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto did_fusion = CpuInstructionFusion().Run(module.get()); + ASSERT_TRUE(did_fusion.ok()); + EXPECT_FALSE(did_fusion.ValueOrDie()); + ASSERT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + +// Like ReuseViaImplicitBroadcastUnary but with a binary operation. +TEST_F(OpcodeFusionTest, ReuseViaImplicitBroadcastBinary) { + Shape small_shape = ShapeUtil::MakeShape(F32, {1, 4}); + Shape large_shape = ShapeUtil::MakeShape(F32, {3, 4}); + + HloComputation::Builder builder(TestName()); + + HloInstruction* small_param = + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, small_shape, "param")); + HloInstruction* large_param = + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, large_shape, "param")); + HloInstruction* small_exp = builder.AddInstruction( + HloInstruction::CreateUnary(small_shape, HloOpcode::kExp, small_param)); + + builder.AddInstruction(HloInstruction::CreateBinary( + large_shape, HloOpcode::kAdd, small_exp, large_param)); + + std::unique_ptr module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + auto did_fusion = CpuInstructionFusion().Run(module.get()); + ASSERT_TRUE(did_fusion.ok()); + EXPECT_FALSE(did_fusion.ValueOrDie()); + ASSERT_THAT(module->entry_computation()->root_instruction(), + Not(op::Fusion())); +} + } // namespace } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc index 0283cc64341a38824a29e95dca6e53039becfd9e..662ee609232f5582ce74f4f515637b2623175e94 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" #include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -109,34 +110,15 @@ StatusOr ParallelizationPreparation::RunParallelTaskAssignment( HloModule* module) { VLOG(1) << "RunParallelTaskAssignment max_parallelism_: " << max_parallelism_; bool changed = false; - // Run cost analysis on entry computation. - HloCostAnalysis cost_analysis(shape_size_); + // Initialize ParallelTaskAssignment. + ParallelTaskAssignment parallel_task_assignment(max_parallelism_, shape_size_, + module); + // Assign parallel tasks to HLOs in entry computation. HloComputation* computation = module->entry_computation(); - Status cost_status = computation->root_instruction()->Accept(&cost_analysis); - for (auto& instruction : computation->instructions()) { - // Currently, we do not assign parallel tasks to instructions with at least - // one of the following properties: - // *) Internal threading (library calls to kConv, kDot, and kCustomCall). - // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). - // *) Tuple-shaped. - // TODO(b/27458679) Parallelize instructions which are skipped here. - if (instruction->opcode() == HloOpcode::kParameter || - instruction->opcode() == HloOpcode::kConstant || - instruction->opcode() == HloOpcode::kCall || - instruction->opcode() == HloOpcode::kCustomCall || - instruction->opcode() == HloOpcode::kSelectAndScatter || - (instruction->opcode() == HloOpcode::kConvolution && - PotentiallyImplementedAsEigenConvolution(*instruction)) || - PotentiallyImplementedAsEigenDot(*instruction) || - (instruction->opcode() == HloOpcode::kFusion && - instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || - ShapeUtil::IsTuple(instruction->shape())) { - continue; - } - + for (auto* instruction : computation->instructions()) { // Calculate target parallel task count in [1, max_parallelism_]. - const int64 target_parallel_task_count = GetTargetParallelTaskCount( - cost_status.ok() ? &cost_analysis : nullptr, instruction.get()); + const int64 target_parallel_task_count = + parallel_task_assignment.GetTargetParallelTaskCount(instruction); if (target_parallel_task_count == 1) { continue; } @@ -159,30 +141,6 @@ StatusOr ParallelizationPreparation::RunParallelTaskAssignment( return changed; } -int64 ParallelizationPreparation::GetTargetParallelTaskCount( - const HloCostAnalysis* cost_analysis, HloInstruction* instruction) { - // Default to a simple cost model based on hlo size and typical L2 cache size. - // Note that 'cost_analysis' can be 'nullptr' if HloCostAnalysis returns an - // error status (likely because HLOs like CustomCall are not yet implemented - // in the HloCostAnalysis). - int64 instruction_cost = shape_size_(instruction->shape()); - int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. - if (cost_analysis != nullptr) { - // Calculate the instruction cost in cycles. - // TODO(29630486) Improve on this linear cost model. - // Consider making 'min_cost_per_thread' be a function of the target - // bandwidth limit for instructions with low arithmetic complexity. - instruction_cost = 1 * cost_analysis->flop_count(*instruction) + - 2 * cost_analysis->transcendental_count(*instruction) + - 10 * cost_analysis->bytes_accessed(*instruction); - // Minimum per-thread cost is 100us of work on a 2GHz core. - min_cost_per_thread = 100000; - } - // Return target parallel task count in [1, max_parallelism_]. - return std::min(max_parallelism_, - std::max(1LL, instruction_cost / min_cost_per_thread)); -} - bool ParallelizationPreparation::OutlineParallelizableInstruction( HloInstruction* instruction) { if (instruction->outer_dimension_partitions().empty()) { diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h index d53fc461509cad51778dba37922212731236952f..87be758ef5d0535fdce3a65e54ce225042019cdb 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h @@ -55,12 +55,6 @@ class ParallelizationPreparation : public HloPassInterface { // Returns true on success or error status otherwise. StatusOr RunParallelTaskAssignment(HloModule* module); - // Returns the target parallel task count for 'instruction'. - // Utilizes 'cost_analysis' if non-null. - // Otherwise defaults to a simple HLO output size-based cost model. - int64 GetTargetParallelTaskCount(const HloCostAnalysis* cost_analysis, - HloInstruction* instruction); - // Outlines 'instruction' from entry computation, if it had // been assigned parallel tasks in an earlier pass through the computation. // Returns true if 'instruction' was successfully outlined, false otherwise. diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc index c7155b858bda5e5640e9a6719fb394ca1360d128..7908dc173d79a4a9dcb6127ac344267e27d2b5f2 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc @@ -51,6 +51,9 @@ extern const char* const kAcquireOutfeedBufferForPopulationSymbolName = "__xla_cpu_runtime_AcquireOutfeedBufferForPopulation"; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName = "__xla_cpu_runtime_ReleaseOutfeedBufferAfterPopulation"; +extern const char* const kParallelForkJoinSymbolName = + "__xla_cpu_runtime_ParallelForkJoin"; + extern const char* const kXlaCpuRuntimeSymbolNamePrefix = "__xla_cpu_runtime_"; } // namespace runtime } // namespace cpu diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h index 29feb7267fe97f6876827b6cbfa6217a0cecf238..2ade455b8a0a43dda8c93bbb79891439da2e4f75 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.h +++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.h @@ -51,6 +51,7 @@ extern const char* const kAcquireInfeedBufferForDequeueSymbolName; extern const char* const kReleaseInfeedBufferAfterDequeueSymbolName; extern const char* const kAcquireOutfeedBufferForPopulationSymbolName; extern const char* const kReleaseOutfeedBufferAfterPopulationSymbolName; +extern const char* const kParallelForkJoinSymbolName; // All symbol names for XLA CPU runtime functions need to start with this // prefix. diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc similarity index 98% rename from tensorflow/compiler/xla/service/cpu_transfer_manager.cc rename to tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc index b1b0cfdbe772bba918929afca6f2a3708ed789db..b53719fcc260d706eab3d7460c42af4a1b5e775f 100644 --- a/tensorflow/compiler/xla/service/cpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/cpu_transfer_manager.h" +#include "tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h" #include #include @@ -87,7 +87,8 @@ class CpuOutfeedBuffer : public cpu::runtime::XfeedBuffer { } // namespace CpuTransferManager::CpuTransferManager() - : GenericTransferManager(se::host::kHostPlatformId) {} + : GenericTransferManager(se::host::kHostPlatformId, + /*pointer_size=*/sizeof(void*)) {} Status CpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, const Literal& literal) { diff --git a/tensorflow/compiler/xla/service/cpu_transfer_manager.h b/tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h similarity index 100% rename from tensorflow/compiler/xla/service/cpu_transfer_manager.h rename to tensorflow/compiler/xla/service/cpu/cpu_transfer_manager.h diff --git a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.cc b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.cc new file mode 100644 index 0000000000000000000000000000000000000000..5f5803874b7886e56da47250d0dbe297f5db16c5 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.cc @@ -0,0 +1,39 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" + +namespace xla { +namespace cpu { + +CustomCallTargetRegistry* CustomCallTargetRegistry::Global() { + static auto* registry = new CustomCallTargetRegistry; + return registry; +} + +void CustomCallTargetRegistry::Register(const std::string& symbol, + void* address) { + std::lock_guard lock(mu_); + registered_symbols_[symbol] = address; +} + +void* CustomCallTargetRegistry::Lookup(const std::string& symbol) const { + std::lock_guard lock(mu_); + auto it = registered_symbols_.find(symbol); + return it == registered_symbols_.end() ? nullptr : it->second; +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h new file mode 100644 index 0000000000000000000000000000000000000000..2994642356d55df26c31553ef28dc653503d05be --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h @@ -0,0 +1,74 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ + +// This file is depended on by kernels that have to build for mobile devices. +// For this reason, we avoid relying on TensorFlow and instead only use the +// standard C++ library. + +#include // NOLINT +#include +#include + +namespace xla { +namespace cpu { + +// The CPU JIT compiler uses this registry to resolve symbolic CustomCall +// targets; so when using the CPU JIT, CustomCall targets need to be registered +// here with the symbol name used in the CustomCall. +// +// The XLA AOT compiler links using a standard offline linker; so when compiling +// in AOT mode, you *also* need to make sure the name of the callee (presumably +// implemented in C++) matches up with the symbolic name used in the CustomCall. +// +// We maintain the registry in both the JIT and the AOT cases for simplicity, +// but we only use it when running in JIT mode. +class CustomCallTargetRegistry { + public: + static CustomCallTargetRegistry* Global(); + + void Register(const std::string& symbol, void* address); + void* Lookup(const std::string& symbol) const; + + private: + std::unordered_map registered_symbols_; + mutable std::mutex mu_; +}; + +class RegisterCustomCallTarget { + public: + explicit RegisterCustomCallTarget(const std::string& name, void* address) { + CustomCallTargetRegistry::Global()->Register(name, address); + } +}; + +#define REGISTER_CUSTOM_CALL_CONCAT(a, b) a##b + +#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, counter) \ + static ::xla::cpu::RegisterCustomCallTarget REGISTER_CUSTOM_CALL_CONCAT( \ + custom_call_target_register, counter)(symbol, \ + reinterpret_cast(address)) + +#define REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(symbol, address) \ + REGISTER_CUSTOM_CALL_TARGET_WITH_SYM_HELPER(symbol, address, __COUNTER__) + +#define REGISTER_CUSTOM_CALL_TARGET(function) \ + REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(#function, function) + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CUSTOM_CALL_TARGET_REGISTRY_H_ diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index d3b94d75411218346cd25b0d3ecc3a9f30b56ba3..e57d49172b18beb75cfbb482c5d732ef679ebe41 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -63,7 +63,7 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs, llvm::Value* executable_run_options_value, llvm::IRBuilder<>* ir_builder, const HloModuleConfig& hlo_module_config) { PrimitiveType type = target_array.GetShape().element_type(); - TF_RET_CHECK(F32 == type || F64 == type); + TF_RET_CHECK(F32 == type || F64 == type || C64 == type); DotOpEmitter dot_emitter(dot, transpose_lhs, transpose_rhs, target_array, lhs_array, rhs_array, executable_run_options_value, ir_builder, hlo_module_config); @@ -176,7 +176,7 @@ tensorflow::Status DotOpEmitter::Emit() { llvm::BasicBlock* preheader_bb = reduction_loop->GetPreheaderBasicBlock(); ir_builder_->SetInsertPoint(preheader_bb->getTerminator()); - ir_builder_->CreateStore(llvm::ConstantFP::get(accum_type, 0.0), + ir_builder_->CreateStore(llvm::Constant::getNullValue(accum_type), accum_address); // Body basic block of reduction loop: @@ -191,9 +191,29 @@ tensorflow::Status DotOpEmitter::Emit() { llvm::Value* rhs_element = rhs_array_.EmitReadArrayElement(rhs_index, ir_builder_); - llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element); llvm::Value* accum = ir_builder_->CreateLoad(accum_address); - llvm::Value* updated_accum = ir_builder_->CreateFAdd(accum, product); + llvm::Value* updated_accum; + if (ShapeUtil::ElementIsComplex(lhs_shape)) { + auto real = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {0}); + }; + auto imag = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {1}); + }; + llvm::Value* product_real = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(real(lhs_element), real(rhs_element)), + ir_builder_->CreateFMul(imag(lhs_element), imag(rhs_element))); + llvm::Value* product_imag = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(real(lhs_element), imag(rhs_element)), + ir_builder_->CreateFMul(imag(lhs_element), real(rhs_element))); + updated_accum = ir_builder_->CreateInsertValue( + accum, ir_builder_->CreateFAdd(real(accum), product_real), {0}); + updated_accum = ir_builder_->CreateInsertValue( + updated_accum, ir_builder_->CreateFAdd(imag(accum), product_imag), {1}); + } else { + llvm::Value* product = ir_builder_->CreateFMul(lhs_element, rhs_element); + updated_accum = ir_builder_->CreateFAdd(accum, product); + } ir_builder_->CreateStore(updated_accum, accum_address); // Exit basic block of reduction loop. @@ -230,11 +250,28 @@ tensorflow::Status DotOpEmitter::Emit() { tensorflow::Status DotOpEmitter::EmitScalarDot() { // A scalar dot is just a scalar multiply. + llvm::Value* result; llvm::Value* lhs_value = lhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_); llvm::Value* rhs_value = rhs_array_.EmitReadArrayElement(/*index=*/{}, ir_builder_); - llvm::Value* result = ir_builder_->CreateFMul(lhs_value, rhs_value); + if (ShapeUtil::ElementIsComplex(lhs_array_.GetShape())) { +#define REAL(x) ir_builder_->CreateExtractValue(x, {0}) +#define IMAG(x) ir_builder_->CreateExtractValue(x, {1}) + llvm::Value* real = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(REAL(lhs_value), REAL(rhs_value)), + ir_builder_->CreateFMul(IMAG(lhs_value), IMAG(rhs_value))); + llvm::Value* imag = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(REAL(lhs_value), IMAG(rhs_value)), + ir_builder_->CreateFMul(IMAG(lhs_value), REAL(rhs_value))); +#undef IMAG +#undef REAL + result = llvm::ConstantAggregateZero::get(lhs_array_.GetElementLlvmType()); + result = ir_builder_->CreateInsertValue(result, real, {0}); + result = ir_builder_->CreateInsertValue(result, imag, {1}); + } else { + result = ir_builder_->CreateFMul(lhs_value, rhs_value); + } target_array_.EmitWriteArrayElement(/*index=*/{}, result, ir_builder_); return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc index 73e039250ba62b1313c98965421f6d823ca6a3b0..ba693ec89ab7c4090f8c9d1e4d65f17a80d0ac55 100644 --- a/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.cc @@ -46,8 +46,8 @@ StatusOr CpuElementalIrEmitter::EmitFloatUnaryOp( } // Create function type for the function. llvm::FunctionType* function_type = llvm::FunctionType::get( - llvm_ir::PrimitiveTypeToIrType(element_type, ir_builder_), - llvm_ir::PrimitiveTypeToIrType(element_type, ir_builder_), + llvm_ir::PrimitiveTypeToIrType(element_type, module_), + llvm_ir::PrimitiveTypeToIrType(element_type, module_), /*isVarArg=*/false); // Create function declaration for 'tanhf'. llvm::Function* function = diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc new file mode 100644 index 0000000000000000000000000000000000000000..c9f8e5584965d0c73771750e26bd63c401d5b0c0 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.cc @@ -0,0 +1,53 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" + +#include +#include +#include + +#include "tensorflow/compiler/xla/map_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/core/lib/gtl/flatset.h" + +namespace xla { +namespace cpu { +void ExternalConstantPool::Insert(string name, const Literal& literal, + int64 alignment) { + CHECK(!ShapeUtil::IsTuple(literal.shape())); + CHECK(alignment > 0 && IsPowerOfTwo(static_cast(alignment))); + CHECK(entries_.find(name) == entries_.end()); + + int64 literal_size = ShapeUtil::ByteSizeOf(literal.shape()); + void* raw_pointer; + CHECK_EQ( + posix_memalign(&raw_pointer, std::max(alignment, sizeof(void*)), + literal_size), + 0) + << "failed to allocate " << literal_size << " bytes with alignment of " + << alignment; + + std::memcpy(raw_pointer, literal.InternalData(), literal_size); + entries_.emplace(std::move(name), static_cast(raw_pointer)); +} + +const uint8* ExternalConstantPool::Find(const string& name) { + auto it = entries_.find(name); + return it == entries_.end() ? nullptr : it->second.get(); +} +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool.h b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h new file mode 100644 index 0000000000000000000000000000000000000000..ade28cbcbcfda05a9ad0adab1139bf316720e11f --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool.h @@ -0,0 +1,64 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ + +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" + +namespace xla { +namespace cpu { +// An ExternalConstantPool maintains a set of constants kept external to +// generated LLVM IR. These constants are accessed from the IR via globals with +// extern linkage. This current incarnation of ExternalConstantPool only +// supports the JIT CPU backend; the AOT backend is not supported. +// +// Implementation-wise, this is a simple wrapper around a map of strings to byte +// buffers. This simply implementation works in a JIT scenario. This class +// will have to become smarter if we decide to support external constant pools +// on AOT compiles in the future. +class ExternalConstantPool { + public: + // Inserts a buffer with the contents of `literal` into the constant pool with + // the name `name`. It is an error to try to insert two constants with the + // same `name` into the same constant pool. The buffer for literal is aligned + // to `aligment` bytes, and `alignment` must be a power of 2. + // + // The constant pool copies out the contents of `literal` into a buffer it + // owns -- it does not keep pointers to `literal`, or to memory owned by + // `literal`. + void Insert(string name, const Literal& literal, int64 alignment); + + // Find the constant with name `name` in this constant pool. If there isn't + // such constant, return nullptr. + const uint8* Find(const string& name); + + private: + // We need to `free()` pointers allocated into `entries_` since we allocate + // them with `posix_memalign`. + struct FreeDeleter { + void operator()(void* ptr) { free(ptr); } + }; + + tensorflow::gtl::FlatMap> + entries_; +}; +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_EXTERNAL_CONSTANT_POOL_H_ diff --git a/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc b/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9290a4e5dfc03ddb86e9d82f1f0f4f9a8ceebb88 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/external_constant_pool_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace cpu { +namespace { +class ExternalConstantPoolTest : public ::testing::Test {}; + +template +T GetFromBuffer(const uint8* buffer, int64 index) { + T result; + std::memcpy(&result, buffer + index * sizeof(T), sizeof(T)); + return result; +} + +TEST(ExternalConstantPoolTest, Basic) { + ExternalConstantPool constant_pool; + EXPECT_EQ(constant_pool.Find("name-0"), nullptr); + const auto literal = Literal::CreateR2({{1, 2}, {3, 4}}); + constant_pool.Insert("name-0", *literal, 4); + const uint8* constant = constant_pool.Find("name-0"); + ASSERT_NE(constant, nullptr); + + EXPECT_EQ(GetFromBuffer(constant, 0), 1); + EXPECT_EQ(GetFromBuffer(constant, 1), 2); + EXPECT_EQ(GetFromBuffer(constant, 2), 3); + EXPECT_EQ(GetFromBuffer(constant, 3), 4); + + EXPECT_EQ(constant_pool.Find("name-1"), nullptr); +} + +TEST(ExternalConstantPoolTest, RowMinorLayout) { + ExternalConstantPool constant_pool; + EXPECT_EQ(constant_pool.Find("name-0"), nullptr); + const auto literal = Literal::CreateR2WithLayout( + {{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1})); + constant_pool.Insert("name-0", *literal, 4); + const uint8* constant = constant_pool.Find("name-0"); + ASSERT_NE(constant, nullptr); + + EXPECT_EQ(GetFromBuffer(constant, 0), 1); + EXPECT_EQ(GetFromBuffer(constant, 1), 3); + EXPECT_EQ(GetFromBuffer(constant, 2), 2); + EXPECT_EQ(GetFromBuffer(constant, 3), 4); +} + +TEST(ExternalConstantPoolTest, Alignment) { + ExternalConstantPool constant_pool; + EXPECT_EQ(constant_pool.Find("name-0"), nullptr); + + for (int i = 0; i < 8; i++) { + int64 alignment = 1 << i; + string name = tensorflow::strings::StrCat("name-", i); + + const auto literal = Literal::CreateR2({{1, 2}, {3, 4}}); + constant_pool.Insert(name, *literal, alignment); + + const uint8* constant = constant_pool.Find(name); + ASSERT_NE(constant, nullptr); + EXPECT_EQ(reinterpret_cast(constant) % alignment, 0); + } +} + +} // namespace +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc index 91b09f2472e4001d8df8aa1ce4dc2796af2a32e7..b99b36a55eee40bc66dcb1b7b1a464bf764ef0ea 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emission_utils.cc @@ -41,6 +41,12 @@ bool PotentiallyImplementedAsEigenConvolution( ShapeUtil::HasZeroElements(kernel_shape)) { return false; } + // TODO(b/65408531): Explore using Eigen dot for complex64 type. + if (ShapeUtil::ElementIsComplex(input_shape) || + ShapeUtil::ElementIsComplex(kernel_shape)) { + return false; + } + const ConvolutionDimensionNumbers& dnums = convolution.convolution_dimension_numbers(); // Only 1D and 2D convolutions are supported at the moment. @@ -55,8 +61,12 @@ bool PotentiallyImplementedAsEigenConvolution( std::is_sorted(dnums.kernel_spatial_dimensions().begin(), dnums.kernel_spatial_dimensions().end()); - return dnums.batch_dimension() == 0 && - dnums.feature_dimension() == input_shape.dimensions_size() - 1 && + const Shape& output_shape = convolution.shape(); + return dnums.input_batch_dimension() == 0 && + dnums.input_feature_dimension() == input_shape.dimensions_size() - 1 && + dnums.output_batch_dimension() == 0 && + dnums.output_feature_dimension() == + output_shape.dimensions_size() - 1 && input_spatial_dims_ascending == kernel_spatial_dims_ascending && dnums.kernel_input_feature_dimension() == kernel_shape.dimensions_size() - 2 && @@ -113,8 +123,9 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) { if (hlo.opcode() == HloOpcode::kFusion && hlo.fusion_kind() == HloInstruction::FusionKind::kTransposeDot && hlo.fused_expression_root()->opcode() == HloOpcode::kDot) { - const Shape& lhs_shape = hlo.operand(0)->shape(); - const Shape& rhs_shape = hlo.operand(1)->shape(); + auto* dot = hlo.fused_expression_root(); + const Shape& lhs_shape = dot->operand(0)->shape(); + const Shape& rhs_shape = dot->operand(1)->shape(); if (ShapeUtil::HasZeroElements(lhs_shape) || ShapeUtil::HasZeroElements(rhs_shape)) { return false; diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc index 9d219a829669baf90d47c7c292188dd39b415c2b..a20ce6826ca0a86f8c0d441c1e89f091cfb434f1 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc @@ -42,6 +42,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h" #include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" #include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h" #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" @@ -49,6 +50,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" @@ -75,7 +77,8 @@ IrEmitter::IrEmitter( const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, const std::unordered_map* hlo_to_profile_idx, - llvm::TargetMachine* target_machine) + llvm::TargetMachine* target_machine, + ExternalConstantPool* external_constant_pool) : assignment_(assignment), module_(llvm_module), arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()), @@ -86,7 +89,8 @@ IrEmitter::IrEmitter( parallel_cpu_backend_( options::CpuParallelBackendRequested(hlo_module_config_)), is_top_level_computation_(false), - target_machine_features_(target_machine) { + target_machine_features_(target_machine), + external_constant_pool_(external_constant_pool) { ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( /*fast_math_enabled=*/hlo_module_config_.debug_options() .xla_enable_fast_math())); @@ -183,20 +187,9 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { // Even though the type of params and temps is void** in the host's view, in // LLVM IR this is represented by i8*, similarly to void*. It's up to the code // to use GEPs to unravel the indirection layers. - llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); - llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); - llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); - std::vector compute_function_params( - {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); - if (IsParallelContext()) { - compute_function_params.push_back(i64_ptr_type); - } - if (hlo_to_profile_idx_) { - compute_function_params.push_back(i64_ptr_type); - } llvm::FunctionType* compute_function_type = llvm::FunctionType::get( /*Result=*/llvm::Type::getVoidTy(module_->getContext()), - /*Params=*/compute_function_params, + /*Params=*/GetComputeFunctionParams(), /*isVarArg=*/false); // Functions with local linkage get an inlining bonus. Because we know @@ -218,7 +211,7 @@ void IrEmitter::InitializeIrFunction(const string& function_name) { (++arg_iter)->setName("run_options"); (++arg_iter)->setName("params"); (++arg_iter)->setName("temps"); - if (IsParallelContext()) { + if (num_dynamic_loop_bounds_ > 0) { (++arg_iter)->setName("dynamic_loop_bounds"); } if (hlo_to_profile_idx_) { @@ -269,18 +262,42 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } -Status IrEmitter::HandleConstant(HloInstruction* constant, - const Literal& literal) { +Status IrEmitter::HandleConstant(HloInstruction* constant) { VLOG(2) << "HandleConstant: " << constant->ToString(); - llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_); - llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - /*Module=*/*module_, - /*Type=*/initializer->getType(), - /*isConstant=*/true, - /*Linkage=*/llvm::GlobalValue::PrivateLinkage, - /*Initializer=*/initializer, - /*Name=*/""); + const Literal& literal = constant->literal(); + llvm::GlobalVariable* global_for_const; + + // We avoid creating large constants in the LLVM IR since LLVM is not + // efficient for large constant arrays. We still emit "small enough" constant + // arrays into the Ir, in the off chance the LLVM optimizer can do something + // interesting with it. + const int kMaxInternalConstantSizeInBytes = 128; + if (external_constant_pool_ && + ByteSizeOf(literal.shape()) >= kMaxInternalConstantSizeInBytes) { + string global_name = tensorflow::strings::StrCat( + "constant_global_", external_global_constant_counter_++); + global_for_const = new llvm::GlobalVariable( + /*Module=*/*module_, + /*Type=*/IrShapeType(literal.shape()), + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::ExternalLinkage, + /*Initializer=*/nullptr, + /*Name=*/AsStringRef(global_name)); + global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape())); + external_constant_pool_->Insert(global_name, literal, + MinimumAlignmentForShape(literal.shape())); + } else { + llvm::Constant* initializer = + llvm_ir::ConvertLiteralToIrConstant(literal, module_); + global_for_const = new llvm::GlobalVariable( + /*Module=*/*module_, + /*Type=*/initializer->getType(), + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/initializer, + /*Name=*/""); + global_for_const->setAlignment(MinimumAlignmentForShape(literal.shape())); + } emitted_value_[constant] = global_for_const; VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*global_for_const); VLOG(2) << " its type: " @@ -291,8 +308,7 @@ Status IrEmitter::HandleConstant(HloInstruction* constant, Status IrEmitter::HandleCopy(HloInstruction* copy) { if (ShapeUtil::IsTuple(copy->shape())) { // kCopy shallow copies a tuple so just memcpy the top-level buffer. - TF_ASSIGN_OR_RETURN(llvm::Value * copy_value, EmitTargetAddressForOp(copy)); - emitted_value_[copy] = copy_value; + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(copy)); return EmitMemcpy(*(copy->operand(0)), *copy); } else { // Use the elemental emitter for non-tuple shapes. @@ -304,17 +320,23 @@ Status IrEmitter::HandleCopy(HloInstruction* copy) { int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) { // GLibc returns a pointer with alignment 8 on 32-bit platforms and 16 on // 64-bit platforms. TCMalloc returns a pointer with alignment 8 for - // allocations smaller than 16 bytes and at least alignment 16 for allocations - // greater than or equal to 16 bytes. N.B. We could improve on this lower - // bound by explicitly allocating the memory with posix_memalign. This is + // allocations smaller than kMallocAlignmentThreshold bytes and at least + // alignment 16 for allocations greater than or equal to + // kMallocAlignmentThreshold bytes. N.B. We could improve on this lower bound + // by explicitly allocating the memory with posix_memalign. This is // complicated by our desire to allow parameter buffers created by clients to // be consumed directly by the JIT. if (buffer_size == 0) { // No need to align empty buffers. return 1; } + + const int64 kMallocAlignmentThreshold = 512; + int pointer_size = module_->getDataLayout().getPointerSize(); - int buffer_alignment = buffer_size >= 16 ? 2 * pointer_size : 8; + int buffer_alignment = buffer_size >= kMallocAlignmentThreshold + ? 2 * pointer_size + : pointer_size; DCHECK_GT(buffer_alignment, 0); return buffer_alignment; @@ -370,31 +392,30 @@ void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load, } } -Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* operand) { +Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { // A tuple is an array of pointers, one for each operand. Each pointer points // to the output buffer of its corresponding operand. A GetTupleElement // instruction forwards a pointer to the tuple element buffer at the given // index. + auto operand = get_tuple_element->operand(0); const Shape& shape = get_tuple_element->shape(); emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement( shape, get_tuple_element->tuple_index(), MinimumAlignmentForShape(shape), - GetEmittedValueFor(operand), &ir_builder_); + GetEmittedValueFor(operand), &ir_builder_, module_); return Status::OK(); } -Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred, - HloInstruction* on_true, - HloInstruction* on_false) { +Status IrEmitter::HandleSelect(HloInstruction* select) { + auto pred = select->operand(0); + auto on_true = select->operand(1); + auto on_false = select->operand(2); TF_RET_CHECK(pred->shape().element_type() == PRED); if (ShapeUtil::IsTuple(select->shape())) { - TF_ASSIGN_OR_RETURN(llvm::Value * output_address, - EmitTargetAddressForOp(select)); - emitted_value_[select] = output_address; - llvm_ir::EmitTupleSelect(GetIrArrayForOp(select), GetIrArrayForOp(pred), - GetEmittedValueFor(on_true), - GetEmittedValueFor(on_false), &ir_builder_); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(select)); + llvm_ir::EmitTupleSelect( + GetIrArrayFor(select), GetIrArrayFor(pred), GetEmittedValueFor(on_true), + GetEmittedValueFor(on_false), &ir_builder_, module_); return Status::OK(); } @@ -408,8 +429,8 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { // The infeed operation produces data (dequeued from the infeed queue) at this // address, which has been provided by buffer assignment. - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(infeed)); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(infeed)); + llvm_ir::IrArray infeed_array = GetIrArrayFor(infeed); if (ShapeUtil::IsTuple(shape)) { TF_RET_CHECK(!ShapeUtil::IsNestedTuple(shape)); @@ -427,9 +448,9 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { ShapeUtil::GetTupleElementShape(shape, i); // Only the outer tuple buffer's target address is obtained from - // EmitTargetAddressForOp to handle the case when Infeed is the - // root instruction. Target addresses for internal elements can - // be obtained from EmitTempBufferPointer. + // GetEmittedValueFor, to handle the case when Infeed is the root + // instruction. Target addresses for internal elements can be obtained + // from EmitTempBufferPointer. llvm::Value* tuple_element_address = EmitTempBufferPointer(buffer, tuple_element_shape); @@ -439,15 +460,13 @@ Status IrEmitter::HandleInfeed(HloInstruction* infeed) { tuple_element_addresses.push_back(tuple_element_address); } - llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, shape), - tuple_element_addresses, &ir_builder_); + llvm_ir::EmitTuple(infeed_array, tuple_element_addresses, &ir_builder_, + module_); } else { - TF_RETURN_IF_ERROR( - EmitXfeedTransfer(XfeedKind::kInfeed, shape, target_address)); + TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, shape, + GetEmittedValueFor(infeed))); } - emitted_value_[infeed] = target_address; - return Status::OK(); } @@ -545,7 +564,7 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { ShapeUtil::GetTupleElementShape(operand_shape, i); llvm::Value* tuple_element = llvm_ir::EmitGetTupleElement( tuple_element_shape, i, MinimumAlignmentForShape(tuple_element_shape), - value, &ir_builder_); + value, &ir_builder_, module_); TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed, tuple_element_shape, tuple_element)); } @@ -553,30 +572,24 @@ Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { return Status::OK(); } -Status IrEmitter::HandleSort(HloInstruction* sort, HloInstruction* operand) { +Status IrEmitter::HandleSort(HloInstruction* sort) { // TODO(b/26783907): Implement sort on CPU. return Unimplemented("Sort is not supported on CPU (b/26783907)."); } -Status IrEmitter::HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) { - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(tuple)); +Status IrEmitter::HandleTuple(HloInstruction* tuple) { + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(tuple)); std::vector base_ptrs; - for (auto operand : operands) { + for (auto operand : tuple->operands()) { base_ptrs.push_back(GetEmittedValueFor(operand)); } - llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, tuple->shape()), - base_ptrs, &ir_builder_); - emitted_value_[tuple] = target_address; + llvm_ir::EmitTuple(GetIrArrayFor(tuple), base_ptrs, &ir_builder_, module_); return Status::OK(); } -Status IrEmitter::HandleMap( - HloInstruction* map, tensorflow::gtl::ArraySlice operands, - HloComputation* function, - tensorflow::gtl::ArraySlice /*static_operands*/) { +Status IrEmitter::HandleMap(HloInstruction* map) { + tensorflow::gtl::ArraySlice operands(map->operands()); + HloComputation* function = map->to_apply(); // The called computation should have been emitted previously. llvm::Function* mapped_ir_function = FindOrDie(emitted_functions_, function); @@ -584,7 +597,7 @@ Status IrEmitter::HandleMap( const llvm_ir::IrArray::Index& index) { std::vector parameter_addresses; for (const HloInstruction* operand : operands) { - const llvm_ir::IrArray& array = GetIrArrayForOp(operand); + const llvm_ir::IrArray& array = GetIrArrayFor(operand); parameter_addresses.push_back( array.EmitArrayElementAddress(index, &ir_builder_)); } @@ -593,10 +606,10 @@ Status IrEmitter::HandleMap( }); } -Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window, - HloInstruction* operand, - const Window& window, - HloComputation* function) { +Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window) { + auto operand = reduce_window->operand(0); + const Window& window = reduce_window->window(); + HloComputation* function = reduce_window->to_apply(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*reduce_window, /*operands=*/{operand}, /*supported_types=*/{F32})); @@ -630,7 +643,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window, // the initial value on the reduce_window. PrimitiveType operand_element_type = operand->shape().element_type(); llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_), + llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), "reduce_window_accumulator_address", &ir_builder_, MinimumAlignmentForPrimitiveType(operand_element_type)); ir_builder_.CreateStore(ir_builder_.CreateLoad(GetEmittedValueFor( @@ -680,7 +693,7 @@ Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window, SetToFirstInsertPoint(if_data.true_block, &ir_builder_); // We are not in the padding, so carry out the computation. - llvm_ir::IrArray input_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray input_array(GetIrArrayFor(operand)); llvm::Value* input_value_address = input_array.EmitArrayElementAddress(input_index, &ir_builder_); llvm::Value* result = EmitElementFunctionCall( @@ -755,7 +768,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // Allocate space to keep the currently selected value, its index, and // the boolean initialized_flag, which is initially set to false. llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_), + llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), "selected_value_address", &ir_builder_, MinimumAlignmentForPrimitiveType(operand_element_type)); llvm::Value* selected_index_address = @@ -817,7 +830,7 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { ir_builder_.CreateStore(operand_index[i], selected_index_address_slot); } }; - llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &ir_builder_); ir_builder_.CreateStore(operand_data, selected_value_address); @@ -837,8 +850,8 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. llvm::Value* cond = ir_builder_.CreateICmpNE( - result, llvm::ConstantInt::get( - llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), 0), + result, + llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0), "boolean_predicate"); llvm_ir::LlvmIfData if_select_lhs = llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_); @@ -860,10 +873,10 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { selected_index.push_back( ir_builder_.CreateLoad(selected_index_address_slot)); } - llvm_ir::IrArray source_array(GetIrArrayForOp(source)); + llvm_ir::IrArray source_array(GetIrArrayFor(source)); llvm::Value* source_value_address = source_array.EmitArrayElementAddress(source_index, &ir_builder_); - llvm_ir::IrArray output_array(GetIrArrayForOp(select_and_scatter)); + llvm_ir::IrArray output_array(GetIrArrayFor(select_and_scatter)); llvm::Value* output_value_address = output_array.EmitArrayElementAddress(selected_index, &ir_builder_); llvm::Value* scatter_value = EmitElementFunctionCall( @@ -877,20 +890,18 @@ Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) { return Status::OK(); } -Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs, - HloInstruction* rhs) { +Status IrEmitter::HandleDot(HloInstruction* dot) { + auto lhs = dot->operand(0); + auto rhs = dot->operand(1); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*dot, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F32, F64})); + /*supported_types=*/{F32, F64, C64})); - llvm_ir::IrArray lhs_array(GetIrArrayForOp(lhs)); - llvm_ir::IrArray rhs_array(GetIrArrayForOp(rhs)); + llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); + llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); - Shape target_shape = dot->shape(); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(dot)); - llvm_ir::IrArray target_array(target_address, target_shape); - AddAliasingInformationToIrArray(*dot, &target_array); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dot)); + llvm_ir::IrArray target_array = GetIrArrayFor(dot); VLOG(2) << "HandleDot: "; VLOG(2) << " lhs operand: " @@ -901,21 +912,19 @@ Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs, << llvm_ir::DumpToString(*target_array.GetBasePointer()); // Dot operation is complicated so we delegate to a helper class. - TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( + return DotOpEmitter::EmitDotOperation( *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array, lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, - hlo_module_config_)); - - emitted_value_[dot] = target_address; - return Status::OK(); + hlo_module_config_); } -Status IrEmitter::HandleConvolution(HloInstruction* convolution, - HloInstruction* lhs, HloInstruction* rhs, - const Window& window) { +Status IrEmitter::HandleConvolution(HloInstruction* convolution) { + auto lhs = convolution->operand(0); + auto rhs = convolution->operand(1); + const auto& window = convolution->window(); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( /*instruction=*/*convolution, /*operands=*/{lhs, rhs}, - /*supported_types=*/{F32})); + /*supported_types=*/{F32, C64})); const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); @@ -935,21 +944,21 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, bool one_dim_convolution = lhs_shape.dimensions_size() == 3; llvm::Value* lhs_address = GetEmittedValueFor(lhs); llvm::Value* rhs_address = GetEmittedValueFor(rhs); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(convolution)); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(convolution)); const ConvolutionDimensionNumbers& dnums = convolution->convolution_dimension_numbers(); // Input tensor. const Shape& input_shape = convolution->operand(0)->shape(); - int64 input_batch = input_shape.dimensions(dnums.batch_dimension()); + int64 input_batch = input_shape.dimensions(dnums.input_batch_dimension()); int64 input_rows = input_shape.dimensions(dnums.spatial_dimensions(0)); int64 input_cols = one_dim_convolution ? 1 : input_shape.dimensions(dnums.spatial_dimensions(1)); - int64 input_channels = input_shape.dimensions(dnums.feature_dimension()); + int64 input_channels = + input_shape.dimensions(dnums.input_feature_dimension()); // Kernel tensor. const Shape& kernel_shape = convolution->operand(1)->shape(); @@ -1018,35 +1027,33 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, conv_func->setDoesNotThrow(); conv_func->setOnlyAccessesArgMemory(); ir_builder_.CreateCall( - conv_func, - { - GetExecutableRunOptionsArgument(), - ir_builder_.CreateBitCast(target_address, float_ptr_type), - ir_builder_.CreateBitCast(lhs_address, float_ptr_type), - ir_builder_.CreateBitCast(rhs_address, float_ptr_type), - ir_builder_.getInt64(input_batch), - ir_builder_.getInt64(input_rows), - ir_builder_.getInt64(input_cols), - ir_builder_.getInt64(input_channels), - ir_builder_.getInt64(kernel_rows), - ir_builder_.getInt64(kernel_cols), - ir_builder_.getInt64(kernel_channels), - ir_builder_.getInt64(kernel_filters), - ir_builder_.getInt64(output_rows), - ir_builder_.getInt64(output_cols), - ir_builder_.getInt64(row_stride), - ir_builder_.getInt64(col_stride), - ir_builder_.getInt64(padding_top), - ir_builder_.getInt64(padding_bottom), - ir_builder_.getInt64(padding_left), - ir_builder_.getInt64(padding_right), - ir_builder_.getInt64(lhs_row_dilation), - ir_builder_.getInt64(lhs_col_dilation), - ir_builder_.getInt64(rhs_row_dilation), - ir_builder_.getInt64(rhs_col_dilation), - }); - target_address->setName(AsStringRef(IrName(convolution))); - emitted_value_[convolution] = target_address; + conv_func, { + GetExecutableRunOptionsArgument(), + ir_builder_.CreateBitCast( + GetEmittedValueFor(convolution), float_ptr_type), + ir_builder_.CreateBitCast(lhs_address, float_ptr_type), + ir_builder_.CreateBitCast(rhs_address, float_ptr_type), + ir_builder_.getInt64(input_batch), + ir_builder_.getInt64(input_rows), + ir_builder_.getInt64(input_cols), + ir_builder_.getInt64(input_channels), + ir_builder_.getInt64(kernel_rows), + ir_builder_.getInt64(kernel_cols), + ir_builder_.getInt64(kernel_channels), + ir_builder_.getInt64(kernel_filters), + ir_builder_.getInt64(output_rows), + ir_builder_.getInt64(output_cols), + ir_builder_.getInt64(row_stride), + ir_builder_.getInt64(col_stride), + ir_builder_.getInt64(padding_top), + ir_builder_.getInt64(padding_bottom), + ir_builder_.getInt64(padding_left), + ir_builder_.getInt64(padding_right), + ir_builder_.getInt64(lhs_row_dilation), + ir_builder_.getInt64(lhs_col_dilation), + ir_builder_.getInt64(rhs_row_dilation), + ir_builder_.getInt64(rhs_col_dilation), + }); return Status::OK(); } @@ -1066,14 +1073,14 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, for (int i = 0; i < num_spatial_dims; ++i) { output_spatial[i] = index[dnums.spatial_dimensions(i)]; } - llvm::Value* output_feature = index[dnums.feature_dimension()]; - llvm::Value* batch = index[dnums.batch_dimension()]; + llvm::Value* output_feature = index[dnums.output_feature_dimension()]; + llvm::Value* batch = index[dnums.output_batch_dimension()]; // We will accumulate the products into this sum to calculate // the output entry at the given index. PrimitiveType lhs_element_type = lhs->shape().element_type(); llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(lhs_element_type, &ir_builder_), + llvm_ir::PrimitiveTypeToIrType(lhs_element_type, module_), "convolution_sum_address", &ir_builder_, MinimumAlignmentForPrimitiveType(lhs_element_type)); ir_builder_.CreateStore( @@ -1091,8 +1098,9 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, } llvm::Value* input_feature = loops - .AddLoop(0, lhs->shape().dimensions(dnums.feature_dimension()), - "iz") + .AddLoop( + 0, lhs->shape().dimensions(dnums.input_feature_dimension()), + "iz") ->GetIndVarValue(); SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); @@ -1172,10 +1180,10 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, for (int i = 0; i < num_spatial_dims; ++i) { input_index[dnums.spatial_dimensions(i)] = input_spatial[i]; } - input_index[dnums.feature_dimension()] = input_feature; - input_index[dnums.batch_dimension()] = batch; + input_index[dnums.input_feature_dimension()] = input_feature; + input_index[dnums.input_batch_dimension()] = batch; - llvm_ir::IrArray kernel_array(GetIrArrayForOp(rhs)); + llvm_ir::IrArray kernel_array(GetIrArrayFor(rhs)); llvm_ir::IrArray::Index kernel_index(num_dims); for (int i = 0; i < num_spatial_dims; ++i) { kernel_index[dnums.kernel_spatial_dimensions(i)] = kernel_spatial[i]; @@ -1183,7 +1191,7 @@ Status IrEmitter::HandleConvolution(HloInstruction* convolution, kernel_index[dnums.kernel_input_feature_dimension()] = input_feature; kernel_index[dnums.kernel_output_feature_dimension()] = output_feature; - llvm_ir::IrArray input_array(GetIrArrayForOp(lhs)); + llvm_ir::IrArray input_array(GetIrArrayFor(lhs)); llvm::Value* product = ir_builder_.CreateFMul( input_array.EmitReadArrayElement(input_index, &ir_builder_), kernel_array.EmitReadArrayElement(kernel_index, &ir_builder_)); @@ -1288,14 +1296,14 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { PrimitiveType element_type = operand->shape().element_type(); // Used to calculate E(X). llvm::Value* sum_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_), + llvm_ir::PrimitiveTypeToIrType(element_type, module_), "sum_address", &ir_builder_, MinimumAlignmentForPrimitiveType(element_type)); // Used to calculate E(X^2). llvm::Value* sum_square_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_), + llvm_ir::PrimitiveTypeToIrType(element_type, module_), "sum_square_address", &ir_builder_, MinimumAlignmentForPrimitiveType(element_type)); @@ -1317,7 +1325,7 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); - llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm_ir::IrArray::Index input_index = FillReducedDimensionIndex(reduced_dims_index, index); llvm::Value* new_value = @@ -1361,9 +1369,7 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { mean_array, &ir_builder_) .EmitLoop(IrName(batch_norm_training, "mean_var"))); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(batch_norm_training)); - + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(batch_norm_training)); TF_ASSIGN_OR_RETURN( const BufferAllocation::Slice slice, assignment_.GetUniqueSlice(batch_norm_training, /*index=*/{0})); @@ -1393,7 +1399,7 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { llvm::Value* var = var_array.EmitReadArrayElement( feature_index_value, &ir_builder_); - llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* input = operand_array.EmitReadArrayElement(index, &ir_builder_); @@ -1405,10 +1411,10 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { ir_builder_.CreateCall(func_llvm_sqrt, {variance_with_epsilon}); llvm::Value* normalized = ir_builder_.CreateFDiv( ir_builder_.CreateFSub(input, mean), variance_sqrt); - llvm_ir::IrArray offset_array(GetIrArrayForOp(offset)); + llvm_ir::IrArray offset_array(GetIrArrayFor(offset)); llvm::Value* offset = offset_array.EmitReadArrayElement( feature_index_value, &ir_builder_); - llvm_ir::IrArray scale_array(GetIrArrayForOp(scale)); + llvm_ir::IrArray scale_array(GetIrArrayFor(scale)); llvm::Value* scale = scale_array.EmitReadArrayElement( feature_index_value, &ir_builder_); llvm::Value* result = ir_builder_.CreateFAdd( @@ -1419,11 +1425,8 @@ Status IrEmitter::HandleBatchNormTraining(HloInstruction* batch_norm_training) { target_array, &ir_builder_) .EmitLoop(IrName(batch_norm_training, "normalize"))); - llvm_ir::EmitTuple( - llvm_ir::IrArray(target_address, batch_norm_training->shape()), - {normalized, mean, var}, &ir_builder_); - emitted_value_[batch_norm_training] = target_address; - + llvm_ir::EmitTuple(GetIrArrayFor(batch_norm_training), + {normalized, mean, var}, &ir_builder_, module_); return Status::OK(); } @@ -1451,13 +1454,19 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_); llvm::LoadInst* param_address_untyped = ir_builder_.CreateLoad(param_address_offset); + param_address_untyped->setName(AsStringRef(IrName(parameter, "untyped"))); + if (hlo_module_config_.debug_options() + .xla_llvm_enable_invariant_load_metadata()) { + // We never reassign parameters, so this load is invariant. + param_address_untyped->setMetadata( + llvm::LLVMContext::MD_invariant_load, + llvm::MDNode::get(param_address_untyped->getContext(), /*MDs=*/{})); + } + llvm::Value* param_address_typed = ir_builder_.CreateBitCast( param_address_untyped, IrShapeType(param_shape)->getPointerTo()); emitted_value_[parameter] = param_address_typed; - // Parameters of different types may not alias one another. - llvm_ir::SetTbaaForInstruction(param_address_untyped, param_shape, - /*is_pointer_to=*/true); if (!ShapeUtil::IsOpaque(param_shape)) { AttachAlignmentMetadataForLoad(param_address_untyped, param_shape); AttachDereferenceableMetadataForLoad(param_address_untyped, param_shape); @@ -1480,6 +1489,14 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( } const Shape& root_shape = root_instruction->shape(); + if (ShapeUtil::ElementIsComplex(root_shape)) { + // TODO(b/65408531): Complex add could by done via bitcast to + // Complex multiply would be more challenging. We could perhaps use a + // strided load to get all reals in a vector, all imags in a vector, or use + // CreateShuffleVector on a bitcast to float x [2N]. + *failure_reason = "complex values not supported"; + return nullptr; + } bool root_is_floating_point = ShapeUtil::ElementIsFloating(root_shape); bool root_is_integral = ShapeUtil::ElementIsIntegral(root_shape); bool root_is_signed = ShapeUtil::ElementIsSigned(root_shape); @@ -1501,7 +1518,7 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( // This is visually similar to ElementalIrEmitter, though conceptually we're // doing something different here. ElementalIrEmitter emits scalar operations // while these emit scalar or vector operations depending on the type of the - // operands. + // operands. See CreateShardedVectorType for the actual types in use here. switch (root_instruction->opcode()) { default: *failure_reason = "did not recognize root instruction opcode"; @@ -1521,11 +1538,11 @@ IrEmitter::ReductionGenerator IrEmitter::MatchReductionGenerator( : ir_builder->CreateFMul(lhs, rhs); }; - case HloOpcode::kLogicalAnd: + case HloOpcode::kAnd: return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, llvm::Value* rhs) { return ir_builder->CreateAnd(lhs, rhs); }; - case HloOpcode::kLogicalOr: + case HloOpcode::kOr: return [](llvm::IRBuilder<>* ir_builder, llvm::Value* lhs, llvm::Value* rhs) { return ir_builder->CreateOr(lhs, rhs); }; @@ -1578,7 +1595,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType( ShardedVectorType sharded_vector_type; llvm::Type* element_ir_type = - llvm_ir::PrimitiveTypeToIrType(element_type, &ir_builder_); + llvm_ir::PrimitiveTypeToIrType(element_type, module_); for (int i = 0, e = 1 + tensorflow::Log2Ceiling(element_count); i < e; i++) { // For every power of two present in element_count, we generate one or more @@ -1661,7 +1678,7 @@ IrEmitter::EmitInnerLoopForVectorizedReduction( SetToFirstInsertPoint(reduction_loop_nest.GetInnerLoopBodyBasicBlock(), &ir_builder_); - llvm_ir::IrArray arg_array(GetIrArrayForOp(arg)); + llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = output_index.begin(); @@ -1774,6 +1791,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( } CHECK(!ShapeUtil::IsTuple(reduce->shape())); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(reduce)); // We know we're not reducing over the most minor dimension, which means we // can lower the reduction loop as: @@ -1836,10 +1854,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( reduction_generator, array_index, vector_type, init_value, arg, dimensions, element_alignment)); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(reduce)); - llvm_ir::IrArray target_array(target_address, reduce->shape()); - AddAliasingInformationToIrArray(*reduce, &target_array); + llvm_ir::IrArray target_array = GetIrArrayFor(reduce); llvm::Value* output_address = target_array.EmitArrayElementAddress(array_index, &ir_builder_); EmitShardedVectorStore(output_address, accumulator, element_alignment, @@ -1871,10 +1886,7 @@ StatusOr IrEmitter::EmitVectorizedReduce( reduction_generator, array_index, vector_type, init_value, arg, dimensions, element_alignment)); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(reduce)); - llvm_ir::IrArray target_array(target_address, reduce->shape()); - AddAliasingInformationToIrArray(*reduce, &target_array); + llvm_ir::IrArray target_array = GetIrArrayFor(reduce); llvm::Value* output_address = target_array.EmitArrayElementAddress(array_index, &ir_builder_); EmitShardedVectorStore(output_address, accumulator, element_alignment, @@ -1885,17 +1897,14 @@ StatusOr IrEmitter::EmitVectorizedReduce( ir_builder_.SetInsertPoint(outermost_loop_exit_block); } - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(reduce)); - - emitted_value_[reduce] = target_address; return true; } -Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, - HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, - HloComputation* function) { +Status IrEmitter::HandleReduce(HloInstruction* reduce) { + auto arg = reduce->mutable_operand(0); + auto init_value = reduce->mutable_operand(1); + tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + HloComputation* function = reduce->to_apply(); if (!options::VectorizedReduceDisabled(hlo_module_config_)) { string vectorization_failure_reason; TF_ASSIGN_OR_RETURN( @@ -1920,7 +1929,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, // Initialize an accumulator with init_value. PrimitiveType accumulator_type = reduce->shape().element_type(); llvm::AllocaInst* accumulator_addr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(accumulator_type, &ir_builder_), + llvm_ir::PrimitiveTypeToIrType(accumulator_type, module_), "accumulator", &ir_builder_, MinimumAlignmentForPrimitiveType(accumulator_type)); llvm::Value* init_value_addr = GetEmittedValueFor(init_value); @@ -1945,7 +1954,7 @@ Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, // filled in. We fill in the rest of the dimensions with induction // Value*s taken from 'index' which iterates over the target array. // See the high-level description in the XLA documentation for details. - llvm_ir::IrArray arg_array(GetIrArrayForOp(arg)); + llvm_ir::IrArray arg_array(GetIrArrayFor(arg)); llvm_ir::IrArray::Index input_index = reduced_dims_index; llvm_ir::IrArray::Index::const_iterator it = index.begin(); @@ -1974,9 +1983,9 @@ Status IrEmitter::HandleSend(HloInstruction* send) { return Unimplemented("Send is not implemented on CPU. See b/33942983."); } -Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) { +Status IrEmitter::HandleSlice(HloInstruction* slice) { VLOG(2) << "HandleSlice: " << slice->ToString(); - + auto operand = slice->operand(0); // The code below emits a sequential loop nest. For the parallel backend, use // EmitParallelTargetElementLoop() which respects dynamic loop bounds. if (ShouldEmitParallelLoopFor(*slice)) { @@ -1988,9 +1997,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) { return DefaultAction(slice); } - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(slice)); - emitted_value_[slice] = target_address; + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(slice)); if (ShapeUtil::HasZeroElements(slice->shape())) { return Status::OK(); @@ -2062,8 +2069,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) { outer_dims.push_back(memcpy_dim); } - llvm_ir::IrArray target_array(target_address, slice->shape()); - AddAliasingInformationToIrArray(*slice, &target_array); + llvm_ir::IrArray target_array = GetIrArrayFor(slice); const int64 num_outer_loops = outer_dims.size(); llvm_ir::ForLoopNest loops(IrName(slice), &ir_builder_); @@ -2081,7 +2087,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); } - llvm_ir::IrArray source_array = GetIrArrayForOp(operand); + llvm_ir::IrArray source_array = GetIrArrayFor(operand); const llvm_ir::IrArray::Index source_index = target_index.SourceIndexOfSlice( /*shape=*/slice->shape(), /*starts=*/slice->slice_starts(), /*strides=*/slice->slice_strides(), /*builder=*/&ir_builder_); @@ -2112,130 +2118,27 @@ Status IrEmitter::HandleSlice(HloInstruction* slice, HloInstruction* operand) { return Status::OK(); } -Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice, - HloInstruction* operand, - HloInstruction* /*start_indices*/) { +Status IrEmitter::HandleDynamicSlice(HloInstruction* dynamic_slice) { if (ShapeUtil::IsScalar(dynamic_slice->shape())) { - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(dynamic_slice)); - target_address->setName(AsStringRef(IrName(dynamic_slice))); - emitted_value_[dynamic_slice] = target_address; - return EmitMemcpy(*operand, *dynamic_slice); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_slice)); + return EmitMemcpy(*dynamic_slice->operand(0), *dynamic_slice); } return DefaultAction(dynamic_slice); } -namespace { - -// Returns the first non-GetTupleElement ancestor instruction of 'hlo'. -// If the first non-GTE ancestor is tuple-shaped, populates 'index' with the -// (possibly nested) tuple indices used on the path from ancestor to 'hlo'. -const HloInstruction* LatestNonGteAncestorAndIndex(const HloInstruction* hlo, - ShapeIndex* index) { - if (hlo->opcode() == HloOpcode::kGetTupleElement) { - const auto* operand = LatestNonGteAncestorAndIndex(hlo->operand(0), index); - index->push_back(hlo->tuple_index()); - return operand; - } - return hlo; -} - -// Checks if we can emit code for DynamicUpdateSlice to update data in-place. -// Returns true if operand 0 of DynamicUpdateSlice and its output buffer -// share the same buffer allocation. -// Returns false otherwise. -// TODO(b/64142684) Share code with GPU implementation. -bool CanUpdateDynamicSliceInPlace(const BufferAssignment& assignment, - HloInstruction* dynamic_update_slice) { - CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode()); - - // Walk DynamicUpdateSlice operand(0) to parameter and get its - // associated operand. See if it shares an allocation with this operand. - ShapeIndex index; - auto* operand = - LatestNonGteAncestorAndIndex(dynamic_update_slice->operand(0), &index); - if (operand->opcode() != HloOpcode::kParameter) { - return false; - } - - BufferAllocation::Slice operand_slice = - assignment.GetUniqueSlice(operand, index).ConsumeValueOrDie(); - - BufferAllocation::Slice dynamic_update_slice_slice = - assignment.GetUniqueTopLevelSlice(dynamic_update_slice) - .ConsumeValueOrDie(); - - return operand_slice == dynamic_update_slice_slice; -} - -} // namespace - -Status IrEmitter::HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, - HloInstruction* operand, - HloInstruction* update, - HloInstruction* start_indices) { +Status IrEmitter::HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) { + auto update = dynamic_update_slice->operand(1); if (ShapeUtil::IsScalar(dynamic_update_slice->shape())) { - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(dynamic_update_slice)); - target_address->setName(AsStringRef(IrName(dynamic_update_slice))); - emitted_value_[dynamic_update_slice] = target_address; + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice)); return EmitMemcpy(*update, *dynamic_update_slice); - } else if (CanUpdateDynamicSliceInPlace(assignment_, dynamic_update_slice)) { - VLOG(2) << "Emitting HandleDynamicUpdateSlice in-place."; - // DynamicUpdateSlice's operand(0) and 'fusion' output share the same - // BufferAllocation::Slice, so it is safe to emit code to update the slice - // 'in-place'. This avoids copying data outside of the slice update region. - // TODO(b/64142684) Implement in-place update for fused DynamicUpdateSlice. - - // Emit IR to read dynamic start indices from 'start_indices'. - const int64 rank = ShapeUtil::Rank(operand->shape()); - llvm_ir::IrArray::Index start_index(rank); - for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index({ir_builder_.getInt64(i)}); - llvm_ir::IrArray start_indices_array(GetIrArrayForOp(start_indices)); - start_index[i] = - start_indices_array.EmitReadArrayElement(dim_index, &ir_builder_); - } - - // Create loop body emitter which emits code to do the following: - // *) Map requested 'index' and slice 'start_index' to input/output shape - // as 'output_index'. - // *) Reads value from 'update'. - // *) Writes value to input/output array at 'output_index'. - auto loop_body_emitter = - [&](const llvm_ir::IrArray::Index& index) -> Status { - // Calculate 'output_index' at which to write value from update. - llvm_ir::IrArray::Index output_index(rank); - for (int64 i = 0; i < rank; ++i) { - // Emit IR which computes: - // output_index = (start_index + index) % dim_size - llvm::Value* dim_size = llvm::ConstantInt::get( - index[i]->getType(), operand->shape().dimensions(i)); - llvm::Value* start_index0 = ir_builder_.CreateZExtOrBitCast( - start_index[i], index[i]->getType()); - output_index[i] = ir_builder_.CreateURem( - ir_builder_.CreateAdd(start_index0, index[i]), dim_size); - } - - // Read value from 'update'. - llvm_ir::IrArray update_array(GetIrArrayForOp(update)); - llvm::Value* update_data = - update_array.EmitReadArrayElement(index, &ir_builder_); - - // Write value to output array. - GetIrArrayForOp(operand).EmitWriteArrayElement(output_index, update_data, - &ir_builder_); - return Status::OK(); - }; - - TF_RETURN_IF_ERROR( - llvm_ir::LoopEmitter(loop_body_emitter, update->shape(), &ir_builder_) - .EmitLoop(IrName(dynamic_update_slice, "in_place"))); - - TF_ASSIGN_OR_RETURN(llvm::Value * dynamic_update_slice_address, - EmitTargetAddressForOp(dynamic_update_slice)); - emitted_value_[dynamic_update_slice] = dynamic_update_slice_address; - return Status::OK(); + } else if (llvm_ir::CanUpdateDynamicSliceInPlace(dynamic_update_slice, + assignment_)) { + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(dynamic_update_slice)); + auto operands = GetIrArraysForOperandsOf(dynamic_update_slice); + return llvm_ir::EmitDynamicUpdateSliceInPlace( + operands, GetIrArrayFor(dynamic_update_slice), + IrName(dynamic_update_slice, "in_place"), &ir_builder_); } return DefaultAction(dynamic_update_slice); } @@ -2277,7 +2180,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_); // Load an element from the operand. - llvm_ir::IrArray operand_array(GetIrArrayForOp(operand)); + llvm_ir::IrArray operand_array(GetIrArrayFor(operand)); llvm::Value* operand_data = operand_array.EmitReadArrayElement(operand_index, &ir_builder_); @@ -2297,7 +2200,7 @@ Status IrEmitter::HandlePad(HloInstruction* pad) { } // Store the operand element to the computed output location. - llvm_ir::IrArray output_array(GetIrArrayForOp(pad)); + llvm_ir::IrArray output_array(GetIrArrayFor(pad)); output_array.EmitWriteArrayElement(output_index, operand_data, &ir_builder_); SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); @@ -2313,11 +2216,11 @@ static const HloInstruction* StripTranspose(const HloInstruction& hlo) { } Status IrEmitter::HandleFusion(HloInstruction* fusion) { + auto* root = fusion->fused_expression_root(); if (fusion->fusion_kind() == HloInstruction::FusionKind::kTransposeDot) { - const HloInstruction* dot = fusion->fused_expression_root(); - DCHECK(dot->opcode() == HloOpcode::kDot); - const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0)); - const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1)); + DCHECK(root->opcode() == HloOpcode::kDot); + const HloInstruction* lhs_parameter = StripTranspose(*root->operand(0)); + const HloInstruction* rhs_parameter = StripTranspose(*root->operand(1)); DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter && rhs_parameter->opcode() == HloOpcode::kParameter); const HloInstruction* lhs = @@ -2326,18 +2229,15 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { fusion->operand(rhs_parameter->parameter_number()); TF_RETURN_IF_ERROR(ElementTypesSameAndSupported( - /*instruction=*/*dot, /*operands=*/{lhs, rhs}, + /*instruction=*/*root, /*operands=*/{lhs, rhs}, /*supported_types=*/{F32})); - llvm_ir::IrArray lhs_array(GetIrArrayForOp(lhs)); - llvm_ir::IrArray rhs_array(GetIrArrayForOp(rhs)); + llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs)); + llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs)); Shape target_shape = fusion->shape(); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(fusion)); - llvm_ir::IrArray target_array(target_address, target_shape); - AddAliasingInformationToIrArray(*fusion, &target_array); - + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); + llvm_ir::IrArray target_array = GetIrArrayFor(fusion); VLOG(2) << "HandleFusion kTransposeDot: "; VLOG(2) << " lhs operand: " << llvm_ir::DumpToString(*lhs_array.GetBasePointer()); @@ -2348,19 +2248,27 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { // Dot operation is complicated so we delegate to a helper class. TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation( - *dot, dot->operand(0)->IsRank2Transpose(), - dot->operand(1)->IsRank2Transpose(), target_array, lhs_array, rhs_array, - GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_)); - - emitted_value_[fusion] = target_address; + *root, root->operand(0)->IsRank2Transpose(), + root->operand(1)->IsRank2Transpose(), target_array, lhs_array, + rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_, + hlo_module_config_)); return Status::OK(); + } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, + assignment_)) { + VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace"; + CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion)); + + // Delegate to common implementation of fused in-place dynamic-update-slice. + auto operands = GetIrArraysForOperandsOf(fusion); + return llvm_ir::EmitFusedDynamicUpdateSliceInPlace( + fusion, operands, GetIrArrayFor(fusion), &elemental_emitter, + &ir_builder_); } else if (fusion->fusion_kind() == HloInstruction::FusionKind::kLoop) { - std::vector parameter_arrays; - for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArrayForOp(operand)); - } + VLOG(3) << "HandleFusion kLoop"; CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); - FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); + auto operands = GetIrArraysForOperandsOf(fusion); + FusedIrEmitter fused_emitter(operands, &elemental_emitter); TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); return EmitTargetElementLoop(fusion, fused_emitter.GetRootGenerator()); @@ -2378,21 +2286,27 @@ Status IrEmitter::HandleCall(HloInstruction* call) { parameter_addresses.push_back(GetEmittedValueFor(operand)); } - TF_ASSIGN_OR_RETURN(llvm::Value * output_address, - EmitTargetAddressForOp(call)); - output_address->setName(AsStringRef(IrName(call))); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call)); - EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, - output_address, computation->name()); + if (!computation->root_instruction()->outer_dimension_partitions().empty() && + !parallel_cpu_backend_) { + // ParallelTaskAssignment assigned partitions, emit call to + // ParallelForkJoin. + TF_RETURN_IF_ERROR(EmitParallelForkJoin(parameter_addresses, + emitted_value_[call], computation, + call_ir_function)); + } else { + EmitArrayFunctionCallInto(call_ir_function, parameter_addresses, + emitted_value_[call], computation->name()); + } - emitted_value_[call] = output_address; return Status::OK(); } -Status IrEmitter::HandleCustomCall( - HloInstruction* custom_call, - tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target) { +Status IrEmitter::HandleCustomCall(HloInstruction* custom_call) { + tensorflow::gtl::ArraySlice operands( + custom_call->operands()); + tensorflow::StringPiece custom_call_target(custom_call->custom_call_target()); llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); llvm::AllocaInst* operands_alloca = llvm_ir::EmitAllocaAtFunctionEntryWithCount( @@ -2414,17 +2328,13 @@ Status IrEmitter::HandleCustomCall( /*Params=*/{i8_ptr_type, operands_alloca->getType()}, /*isVarArg=*/false))); - TF_ASSIGN_OR_RETURN(llvm::Value * output_address, - EmitTargetAddressForOp(custom_call)); - output_address->setName(AsStringRef(IrName(custom_call))); - - auto* output_address_arg = - ir_builder_.CreatePointerCast(output_address, i8_ptr_type); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(custom_call)); + auto* output_address_arg = ir_builder_.CreatePointerCast( + GetEmittedValueFor(custom_call), i8_ptr_type); ir_builder_.CreateCall(custom_call_ir_function, {output_address_arg, operands_alloca}); - emitted_value_[custom_call] = output_address; return Status::OK(); } @@ -2499,8 +2409,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) { {while_result}, IrName(xla_while, "cond")); llvm::Value* while_predicate = ir_builder_.CreateICmpNE( while_condition, - llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), - 0)); + llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(PRED, module_), 0)); // Branches to the body or to the while exit depending on the condition. llvm::BasicBlock* body_bb = llvm::BasicBlock::Create( @@ -2568,10 +2477,8 @@ StatusOr IrEmitter::EmitFastConcatenate( llvm::Type* i8_ptr_type = ir_builder_.getInt8PtrTy(); llvm::Type* i8_type = ir_builder_.getInt8Ty(); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(concatenate)); - - llvm_ir::IrArray target_array(target_address, output_shape); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(concatenate)); + llvm_ir::IrArray target_array = GetIrArrayFor(concatenate); llvm_ir::ForLoopNest loops(IrName(concatenate), &ir_builder_); llvm_ir::IrArray::Index outer_dims_index = @@ -2588,8 +2495,6 @@ StatusOr IrEmitter::EmitFastConcatenate( unsigned primitive_type_size = ShapeUtil::ByteSizeOfPrimitiveType(primitive_type); - AddAliasingInformationToIrArray(*concatenate, &target_array); - // Contiguous subregions from each operand to the concatenate contribute to a // contiguous subregion in the target buffer starting at target_region_begin. llvm::Value* target_region_begin = ir_builder_.CreateBitCast( @@ -2608,7 +2513,7 @@ StatusOr IrEmitter::EmitFastConcatenate( // equal to the product of inner dimensions. for (HloInstruction* operand : operands) { const Shape& input_shape = operand->shape(); - llvm_ir::IrArray source_array = GetIrArrayForOp(operand); + llvm_ir::IrArray source_array = GetIrArrayFor(operand); llvm::Value* copy_source_address = ir_builder_.CreateBitCast( source_array.EmitArrayElementAddress(outer_dims_index, &ir_builder_, "src_addr"), @@ -2632,8 +2537,6 @@ StatusOr IrEmitter::EmitFastConcatenate( SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_); } - emitted_value_[concatenate] = target_address; - return true; } @@ -2647,7 +2550,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, unsigned element_alignment = GCD( primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type)); llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual( - llvm_ir::PrimitiveTypeToIrType(primitive_type, &ir_builder_)); + llvm_ir::PrimitiveTypeToIrType(primitive_type, module_)); if (element_count == 1) { auto* load_instruction = ir_builder_.CreateAlignedLoad( @@ -2673,9 +2576,9 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source, } } -Status IrEmitter::HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands) { +Status IrEmitter::HandleConcatenate(HloInstruction* concatenate) { + tensorflow::gtl::ArraySlice operands( + concatenate->operands()); string failure_reason; TF_ASSIGN_OR_RETURN( bool successful, @@ -2705,14 +2608,14 @@ Status IrEmitter::FinishVisit(HloInstruction* root) { // For the parallel cpu backend, we record the total for each embedded // computation callee with its caller kCall HLO. HloInstruction* hlo_to_lookup = nullptr; - if (IsParallelContext()) { + if (parallel_cpu_backend_ && is_top_level_computation_) { auto* computation = root->parent(); auto* entry_computation = computation->parent()->entry_computation(); if (computation != entry_computation) { - for (auto& instruction : entry_computation->instructions()) { + for (HloInstruction* instruction : entry_computation->instructions()) { if (instruction->opcode() == HloOpcode::kCall && instruction->to_apply()->root_instruction() == root) { - hlo_to_lookup = instruction.get(); + hlo_to_lookup = instruction; break; } } @@ -2833,7 +2736,7 @@ Status IrEmitter::Postprocess(HloInstruction* hlo) { return Status::OK(); } -llvm_ir::IrArray IrEmitter::GetIrArrayForOp(const HloInstruction* hlo) { +llvm_ir::IrArray IrEmitter::GetIrArrayFor(const HloInstruction* hlo) { llvm::Value* value_for_op = GetEmittedValueFor(hlo); llvm_ir::IrArray array(value_for_op, hlo->shape()); @@ -2841,6 +2744,16 @@ llvm_ir::IrArray IrEmitter::GetIrArrayForOp(const HloInstruction* hlo) { return array; } +std::vector IrEmitter::GetIrArraysForOperandsOf( + const HloInstruction* hlo) { + std::vector arrays; + std::transform( + hlo->operands().begin(), hlo->operands().end(), + std::back_inserter(arrays), + [&](const HloInstruction* operand) { return GetIrArrayFor(operand); }); + return arrays; +} + llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) { auto it = emitted_value_.find(hlo); if (it == emitted_value_.end()) { @@ -2850,7 +2763,22 @@ llvm::Value* IrEmitter::GetEmittedValueFor(const HloInstruction* hlo) { } llvm::Type* IrEmitter::IrShapeType(const Shape& shape) { - return llvm_ir::ShapeToIrType(shape, &ir_builder_); + return llvm_ir::ShapeToIrType(shape, module_); +} + +std::vector IrEmitter::GetComputeFunctionParams() { + llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext()); + llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo(); + llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext()); + std::vector compute_function_params( + {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type}); + if (num_dynamic_loop_bounds_ > 0) { + compute_function_params.push_back(i64_ptr_type); + } + if (hlo_to_profile_idx_) { + compute_function_params.push_back(i64_ptr_type); + } + return compute_function_params; } llvm::Argument* IrEmitter::GetResultArgument() { @@ -2858,7 +2786,7 @@ llvm::Argument* IrEmitter::GetResultArgument() { } llvm::Argument* IrEmitter::GetProfileCountersArgument() { - const int64 arg_index = IsParallelContext() ? 5 : 4; + const int64 arg_index = num_dynamic_loop_bounds_ > 0 ? 5 : 4; return hlo_to_profile_idx_ ? GetArg(compute_function_, arg_index) : nullptr; } @@ -2909,14 +2837,12 @@ llvm::Value* IrEmitter::EmitTempBufferPointer( ir_builder_.CreateLoad(tempbuf_address_ptr); if (hlo_module_config_.debug_options() .xla_llvm_enable_invariant_load_metadata()) { - // Loading the address of a buffer is invariant of the point at which the - // load is executed in the program because we never reassign buffers. + // Loading the address of a buffer is invariant of the point at which the + // load is executed in the program because we never reassign buffers. tempbuf_address_base->setMetadata( llvm::LLVMContext::MD_invariant_load, llvm::MDNode::get(tempbuf_address_base->getContext(), /*MDs=*/{})); } - llvm_ir::SetTbaaForInstruction(tempbuf_address_base, target_shape, - /*is_pointer_to=*/true); AttachAlignmentMetadataForLoad(tempbuf_address_base, allocation.size()); AttachDereferenceableMetadataForLoad(tempbuf_address_base, allocation.size()); @@ -2943,18 +2869,11 @@ llvm::Value* IrEmitter::EmitElementFunctionCall( AsStringRef(tensorflow::strings::StrCat(name, "_return_value"))); } -// Emits a core function call based on the following pseudo-code. -// -// char** parameter_addresses_buffer = -// allocate buffer with a pointer for each parameter to the function -// for each parameter index, i.e. for i = 0, ..., #parameters: -// parameter_addresses_buffer[i] = parameter_addresses[i] -// call function(return_value_buffer, -// parameter_addresses_buffer, -// temps) -// return return_value_buffer -- address of the return value. -void IrEmitter::EmitArrayFunctionCallInto( - llvm::Function* function, +// Emits code to allocate an array of parameter address pointers, and store +// each address from 'parameter_addresses'. +// Returns an array of compute function call arguments (including parameter +// address buffer). +std::vector IrEmitter::GetArrayFunctionCallArguments( tensorflow::gtl::ArraySlice parameter_addresses, llvm::Value* return_value_buffer, tensorflow::StringPiece name) { llvm::Value* parameter_addresses_buffer = @@ -2983,7 +2902,26 @@ void IrEmitter::EmitArrayFunctionCallInto( if (auto* profile_counters = GetProfileCountersArgument()) { arguments.push_back(profile_counters); } - ir_builder_.CreateCall(function, arguments); + return arguments; +} + +// Emits a core function call based on the following pseudo-code. +// +// char** parameter_addresses_buffer = +// allocate buffer with a pointer for each parameter to the function +// for each parameter index, i.e. for i = 0, ..., #parameters: +// parameter_addresses_buffer[i] = parameter_addresses[i] +// call function(return_value_buffer, +// parameter_addresses_buffer, +// temps) +// return return_value_buffer -- address of the return value. +void IrEmitter::EmitArrayFunctionCallInto( + llvm::Function* function, + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::Value* return_value_buffer, tensorflow::StringPiece name) { + ir_builder_.CreateCall( + function, GetArrayFunctionCallArguments(parameter_addresses, + return_value_buffer, name)); } llvm::Value* IrEmitter::EmitArrayFunctionCall( @@ -2995,7 +2933,7 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall( PrimitiveType return_type = return_shape.element_type(); llvm::Value* return_value_buffer = llvm_ir::EmitAllocaAtFunctionEntryWithCount( - llvm_ir::PrimitiveTypeToIrType(return_type, &ir_builder_), elements, + llvm_ir::PrimitiveTypeToIrType(return_type, module_), elements, tensorflow::strings::StrCat(name, "_return_value_address"), &ir_builder_, MinimumAlignmentForPrimitiveType(return_type)); EmitArrayFunctionCallInto(function, parameter_addresses, return_value_buffer, @@ -3003,10 +2941,114 @@ llvm::Value* IrEmitter::EmitArrayFunctionCall( return return_value_buffer; } -StatusOr IrEmitter::EmitTargetAddressForOp( - const HloInstruction* op, const ShapeIndex& shape_index) { - const Shape& target_shape = ShapeUtil::GetSubshape(op->shape(), shape_index); - if (op == op->parent()->root_instruction() && shape_index.empty()) { +// Emits a call to a runtime fork/join function which dispatches parallel +// calls to 'parallel_function' (and joins threads before returning). +Status IrEmitter::EmitParallelForkJoin( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::Value* output_address, HloComputation* computation, + llvm::Function* parallel_function) { + HloInstruction* root = computation->root_instruction(); + + // Build ParallelForkJoin function type. + std::vector compute_function_params = GetComputeFunctionParams(); + // Number of parallel compute functions. + compute_function_params.push_back(ir_builder_.getInt32Ty()); + // Array of partitions. There is an array element for each + // partition x partition_dim x 2 (for dimension start and limit). + compute_function_params.push_back( + llvm::Type::getInt64PtrTy(module_->getContext())); + // Number of partitioned most-major dimensions in 'root.shape'. + compute_function_params.push_back(ir_builder_.getInt32Ty()); + // Function pointer for compute function to be dispatched in parallel. + compute_function_params.push_back( + llvm::Type::getInt8PtrTy(module_->getContext())); + + llvm::FunctionType* fork_join_type = llvm::FunctionType::get( + /*Result=*/llvm::Type::getVoidTy(module_->getContext()), + /*Params=*/compute_function_params, + /*isVarArg=*/false); + + llvm::Function* fork_join_func = + llvm::cast(module_->getOrInsertFunction( + runtime::kParallelForkJoinSymbolName, fork_join_type)); + fork_join_func->setCallingConv(llvm::CallingConv::C); + fork_join_func->setDoesNotThrow(); + + // Add common compute function arguments. + const string name = computation->name(); + std::vector arguments = + GetArrayFunctionCallArguments(parameter_addresses, output_address, name); + + // Create ShapePartitionIterator to generate all partitions of 'root.shape'. + ShapePartitionIterator partition_iterator(root->shape(), + root->outer_dimension_partitions()); + const int64 num_partitions = partition_iterator.GetTotalPartitionCount(); + // Add argument specifying the number of parallel partitions. + arguments.push_back(ir_builder_.getInt32(num_partitions)); + + // The number of partitioned most-major dimensions in 'root.shape'. + const int32 num_partitioned_dims = root->outer_dimension_partitions().size(); + // A dimension partition consists of two elements: [start_index, limit_index). + const int32 dim_partition_size = 2; + // Calculate array partition stride. + const int32 array_partition_stride = + num_partitioned_dims * dim_partition_size; + // Calculate the total number of elements in the partition array. + const int32 partition_array_size = + dim_partition_size * num_partitioned_dims * num_partitions; + + // Store dimension partition values as llvm constants in 'partitions'. + // See comments in runtime_fork_join.cc for array layout description. + std::vector partitions(partition_array_size); + for (int32 i = 0; i < num_partitions; ++i) { + std::vector> dim_partitions = + partition_iterator.GetPartition(i); + CHECK_EQ(num_partitioned_dims, dim_partitions.size()); + const int32 partition_index = i * array_partition_stride; + for (int32 j = 0; j < num_partitioned_dims; ++j) { + const std::pair& dim_partition = dim_partitions[j]; + const int32 index = partition_index + j * dim_partition_size; + // Store partition [dim_start, dim_limit) intervals for each dimension. + partitions[index] = ir_builder_.getInt64(dim_partition.first); + partitions[index + 1] = + ir_builder_.getInt64(dim_partition.first + dim_partition.second); + } + } + + // Create global variable out of dimension partitions in 'partitions'. + llvm::ArrayType* partitions_array_type = + llvm::ArrayType::get(ir_builder_.getInt64Ty(), partition_array_size); + llvm::Constant* partitions_array = + llvm::ConstantArray::get(partitions_array_type, partitions); + llvm::GlobalVariable* global_partitions_array = new llvm::GlobalVariable( + /*Module=*/*module_, + /*Type=*/partitions_array_type, + /*isConstant=*/true, + /*Linkage=*/llvm::GlobalValue::PrivateLinkage, + /*Initializer=*/partitions_array, + /*Name=*/ + AsStringRef( + tensorflow::strings::StrCat(name, "_parallel_dimension_partitions"))); + + // Add argument specifying parallel dimension partitions. + arguments.push_back(ir_builder_.CreateBitCast( + global_partitions_array, + llvm::Type::getInt64PtrTy(module_->getContext()))); + // Add argument specifying the number of partitioned most-major dimensions. + arguments.push_back(ir_builder_.getInt32(num_partitioned_dims)); + // Add argument for parallel compute function pointer. + arguments.push_back( + ir_builder_.CreateBitCast(parallel_function, ir_builder_.getInt8PtrTy())); + // Emit call to parallel fork/join. + ir_builder_.CreateCall(fork_join_func, arguments); + + return Status::OK(); +} + +Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) { + llvm::Value* addr; + const Shape& target_shape = op->shape(); + if (op == op->parent()->root_instruction()) { // For the root node, we write directly to the output buffer of the // function. llvm::Argument* retval = GetResultArgument(); @@ -3016,15 +3058,18 @@ StatusOr IrEmitter::EmitTargetAddressForOp( attr_builder.addDereferenceableAttr(ByteSizeOf(target_shape)); retval->addAttrs(attr_builder); } - return ir_builder_.CreateBitCast(retval, + addr = ir_builder_.CreateBitCast(retval, IrShapeType(target_shape)->getPointerTo()); - } - - // For other nodes, we need the temporary buffer allocated for this node to - // write the result into. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - assignment_.GetUniqueTopLevelSlice(op)); - return EmitTempBufferPointer(slice, target_shape); + } else { + // For other nodes, we need the temporary buffer allocated for this node to + // write the result into. + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + assignment_.GetUniqueTopLevelSlice(op)); + addr = EmitTempBufferPointer(slice, target_shape); + } + addr->setName(AsStringRef(IrName(op))); + emitted_value_[op] = addr; + return Status::OK(); } Status IrEmitter::EmitTargetElementLoop( @@ -3038,12 +3083,9 @@ Status IrEmitter::EmitTargetElementLoop( const llvm_ir::ElementGenerator& element_generator) { VLOG(2) << "EmitTargetElementLoop: " << target_op->ToString(); - // target_address will hold the address of the target buffer we will write the - // result of the computation into. const Shape& target_shape = target_op->shape(); - TF_ASSIGN_OR_RETURN(llvm::Value * target_address, - EmitTargetAddressForOp(target_op)); - VLOG(2) << " target address: " << llvm_ir::DumpToString(*target_address); + TF_RETURN_IF_ERROR(EmitTargetAddressForOp(target_op)); + llvm_ir::IrArray target_array = GetIrArrayFor(target_op); if (target_op->IsMultiOutputFusion()) { // For multiple outputs fusion, we need to emit each operand and the root. @@ -3066,13 +3108,9 @@ Status IrEmitter::EmitTargetElementLoop( for (int64 i = 0; i < output_arrays.size(); ++i) { tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } - llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, target_shape), - tuple_operand_ptrs, &ir_builder_); + llvm_ir::EmitTuple(target_array, tuple_operand_ptrs, &ir_builder_, module_); } else { - llvm_ir::IrArray target_array(target_address, target_shape); - AddAliasingInformationToIrArray(*target_op, &target_array); - if (ShouldEmitParallelLoopFor(*target_op)) { TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop( target_shape, element_generator, IrName(target_op), &target_array)); @@ -3082,8 +3120,6 @@ Status IrEmitter::EmitTargetElementLoop( .EmitLoop(IrName(target_op))); } } - - emitted_value_[target_op] = target_address; return Status::OK(); } @@ -3175,7 +3211,7 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : hlo->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { - return GetIrArrayForOp(operand).EmitReadArrayElement(index, &ir_builder_); + return GetIrArrayFor(operand).EmitReadArrayElement(index, &ir_builder_); }; } CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_); diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h index 8042e03e69561aeacccc5498eaf52f32bbd78b62..5d061e11e3c9e07bdcfdc749711e4369ec2bea2a 100644 --- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h @@ -29,6 +29,7 @@ limitations under the License. #include "llvm/IR/Value.h" #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -104,11 +105,15 @@ class IrEmitter : public DfsHloVisitorWithDefault { // llvm_module: the LLVM module to emit IR into. // hlo_to_profile_idx: the mapping from HLO to its index in the profiling // array. + // external_constant_pool: if non-null, points to an ExternalConstantPool + // instance into which the Ir emitter can spill + // constants. IrEmitter(const HloModule& hlo_module, const BufferAssignment& assignment, llvm::Module* llvm_module, const std::unordered_map* hlo_to_profile_idx, - llvm::TargetMachine* target_machine); + llvm::TargetMachine* target_machine, + ExternalConstantPool* external_constant_pool); ~IrEmitter() override; // Emit and return the given HLO computation as an LLVM IR @@ -146,69 +151,43 @@ class IrEmitter : public DfsHloVisitorWithDefault { // // Default action which emits code for most operations. Operations which are // special in some way are handled explicitly in HandleFoo methods. - Status DefaultAction(HloInstruction* hlo_instruction) override; + Status DefaultAction(HloInstruction* hlo) override; Status HandleBitcast(HloInstruction* bitcast) override; - Status HandleConstant(HloInstruction* constant, - const Literal& literal) override; + Status HandleConstant(HloInstruction* constant) override; Status HandleCopy(HloInstruction* copy) override; - Status HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* operand) override; - Status HandleSelect(HloInstruction* select, HloInstruction* pred, - HloInstruction* on_true, - HloInstruction* on_false) override; - Status HandleDot(HloInstruction* dot, HloInstruction* lhs, - HloInstruction* rhs) override; - Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, - HloInstruction* rhs, const Window& window) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; + Status HandleSelect(HloInstruction* select) override; + Status HandleDot(HloInstruction* dot) override; + Status HandleConvolution(HloInstruction* convolution) override; Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; - Status HandleSort(HloInstruction* sort, HloInstruction* operand) override; + Status HandleSort(HloInstruction* sort) override; Status HandleParameter(HloInstruction* parameter) override; - Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, - HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, - HloComputation* function) override; - Status HandleReduceWindow(HloInstruction* reduce_window, - HloInstruction* operand, const Window& window, - HloComputation* function) override; - Status HandleSelectAndScatter(HloInstruction* instruction) override; + Status HandleReduce(HloInstruction* reduce) override; + Status HandleReduceWindow(HloInstruction* reduce_window) override; + Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override; Status HandleSend(HloInstruction* send) override; - Status HandleSlice(HloInstruction* slice, - HloInstruction* /*operand*/) override; - Status HandleDynamicSlice(HloInstruction* dynamic_slice, - HloInstruction* /*operand*/, - HloInstruction* /*start_indices*/) override; - Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, - HloInstruction* /*operand*/, - HloInstruction* /*update*/, - HloInstruction* /*start_indices*/) override; + Status HandleSlice(HloInstruction* slice) override; + Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; + Status HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) override; Status HandleRecv(HloInstruction* recv) override; Status HandlePad(HloInstruction* pad) override; - Status HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) override; - Status HandleMap( - HloInstruction* map, - tensorflow::gtl::ArraySlice operands, - HloComputation* function, - tensorflow::gtl::ArraySlice static_operands) override; + Status HandleTuple(HloInstruction* tuple) override; + Status HandleMap(HloInstruction* map) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; - Status HandleCustomCall(HloInstruction* custom_call, - tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target) override; + Status HandleCustomCall(HloInstruction* custom_call) override; Status HandleWhile(HloInstruction* xla_while) override; - Status HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands) override; + Status HandleConcatenate(HloInstruction* concatenate) override; Status FinishVisit(HloInstruction* root) override; Status Preprocess(HloInstruction* hlo) override; - Status Postprocess(HloInstruction* visited) override; + Status Postprocess(HloInstruction* hlo) override; private: // Private helper to initialize an IR function for the computation. @@ -220,8 +199,8 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Gets the IR Value emitted previously for the given hlo. // - // Prefer calling GetIrArrayForOp if the value you're reading is a buffer, - // because GetIrArrayForOp annotates buffer's loads/stores with noalias + // Prefer calling GetIrArrayFor if the value you're reading is a buffer, + // because GetIrArrayFor annotates buffer's loads/stores with noalias // metadata. // // Make sure to call this only when you're certain a value *was* emitted - if @@ -229,7 +208,11 @@ class IrEmitter : public DfsHloVisitorWithDefault { llvm::Value* GetEmittedValueFor(const HloInstruction* hlo); // Gets an IrArray representing the given hlo. - llvm_ir::IrArray GetIrArrayForOp(const HloInstruction* hlo); + llvm_ir::IrArray GetIrArrayFor(const HloInstruction* hlo); + + // Gets a list of IrArrays, one for each of hlo's operands. + std::vector GetIrArraysForOperandsOf( + const HloInstruction* hlo); // Augments IrArray with aliasing information. void AddAliasingInformationToIrArray(const HloInstruction& hlo, @@ -240,6 +223,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { // Convenience function to get the IR type matching the given shape. llvm::Type* IrShapeType(const Shape& shape); + // Returns an array of compute function parameter types. + std::vector GetComputeFunctionParams(); + // Get the llvm::Value* that represents the "retval" argument of the // computation function being emitted by this emitter. llvm::Argument* GetResultArgument(); @@ -304,7 +290,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { void EmitArrayFunctionCallInto( llvm::Function* function, tensorflow::gtl::ArraySlice parameter_addresses, - llvm::Value* return_value, tensorflow::StringPiece name); + llvm::Value* return_value_buffer, tensorflow::StringPiece name); // Array function call emitter. Returns a Value for the function's return // value buffer address. The return value buffer is alloca'ed by this @@ -314,6 +300,18 @@ class IrEmitter : public DfsHloVisitorWithDefault { tensorflow::gtl::ArraySlice parameter_addresses, tensorflow::StringPiece name); + // Returns an array of compute function call arguments. + std::vector GetArrayFunctionCallArguments( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::Value* return_value_buffer, tensorflow::StringPiece name); + + // Emits a call to a runtime fork/join function which dispatches parallel + // calls to 'parallel_function' (and joins threads before returning). + Status EmitParallelForkJoin( + tensorflow::gtl::ArraySlice parameter_addresses, + llvm::Value* output_address, HloComputation* computation, + llvm::Function* parallel_function); + // Verifies that the element types of all of the given operand instructions // match and are of one of the given supported types. Status ElementTypesSameAndSupported( @@ -353,11 +351,10 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status EmitMemcpy(const HloInstruction& source, const HloInstruction& destination); - // Emit IR to compute the target address of the buffer for the given op. - // The returned Value is a pointer to a IR type that represents the op's - // element type. - StatusOr EmitTargetAddressForOp( - const HloInstruction* op, const ShapeIndex& shape_index = {}); + // Emits IR to compute the target address of the buffer for the given op. + // After calling this function, you can get a pointer to this buffer by + // calling GetIrArrayForOp or GetEmittedValueFor. + Status EmitTargetAddressForOp(const HloInstruction* op); // Structurizes "array_elements" into an MD array that represents "shape". // This is a recursive function, and "dimension_index" indicates the index of @@ -447,10 +444,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { const llvm_ir::IrArray& target_array, const llvm_ir::IrArray& source_array); - // Name of the computation entry function. This function serves as the - // top-level "main" of the computation and will be invoked by the JIT. - string entry_function_name_; - // Assignment of the temporary buffers needed by the computation and their // shape information. const BufferAssignment& assignment_; @@ -592,12 +585,6 @@ class IrEmitter : public DfsHloVisitorWithDefault { Status EmitXfeedTransfer(XfeedKind kind, const Shape& shape, llvm::Value* program_buffer_address); - // Returns true if the current function being emitted is called in a - // parallel context (returns false otherwise). - bool IsParallelContext() { - return parallel_cpu_backend_ && is_top_level_computation_; - } - const HloModuleConfig& hlo_module_config_; const bool parallel_cpu_backend_; @@ -606,6 +593,9 @@ class IrEmitter : public DfsHloVisitorWithDefault { TargetMachineFeatures target_machine_features_; + int64 external_global_constant_counter_ = 0; + ExternalConstantPool* external_constant_pool_; + TF_DISALLOW_COPY_AND_ASSIGN(IrEmitter); }; diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc index f85459c79cc03bca14e335b833acb3efba3a3053..c446b6b792a042da2500ea6a175fdca4c70bcab6 100644 --- a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc @@ -78,10 +78,10 @@ Status CpuLayoutAssignment::AddBackendConstraints( }; const HloComputation* computation = constraints->computation(); - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kConvolution && PotentiallyImplementedAsEigenConvolution(*instruction)) { - const HloInstruction* convolution = instruction.get(); + const HloInstruction* convolution = instruction; const HloInstruction* lhs_instruction = convolution->operand(0); const HloInstruction* rhs_instruction = convolution->operand(1); @@ -102,49 +102,52 @@ Status CpuLayoutAssignment::AddBackendConstraints( TF_RETURN_IF_ERROR( constraints->SetInstructionLayout(output_shape, convolution)); } else if (should_make_rhs_col_major(*instruction)) { - auto* dot = instruction.get(); + auto* dot = instruction; const auto& rhs_shape = dot->operand(1)->shape(); TF_RETURN_IF_ERROR( constraints->SetOperandLayout(col_major_shape(rhs_shape), dot, 1)); } else if (PotentiallyImplementedAsEigenDot(*instruction)) { - const HloInstruction* dot = instruction.get(); - const HloInstruction* lhs_instruction = dot->operand(0); - const HloInstruction* rhs_instruction = dot->operand(1); - + const HloInstruction* dot = instruction; // In order to implement `dot` with Eigen dot, the layouts of the lhs, // rhs, and output need to be row-major. // // These constraints are not hard constraints. Ideally, we should decide // which layouts to choose according to some cost model. Shape output_shape(row_major_shape(dot->shape())); + + const HloInstruction* lhs_instruction = dot->operand(0); Shape lhs_shape(row_major_shape(lhs_instruction->shape())); - Shape rhs_shape(row_major_shape(rhs_instruction->shape())); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); + + // dot is a kDot or a kTransposeDot fusion node. In the latter case, if + // it represents X @ X, it may have just one operand. + if (dot->operand_count() > 1) { + const HloInstruction* rhs_instruction = dot->operand(1); + Shape rhs_shape(row_major_shape(rhs_instruction->shape())); + TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); + } // Set layouts of the instructions' shapes. - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(lhs_shape, dot, 0)); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout(rhs_shape, dot, 1)); TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(output_shape, dot)); } else { for (int64 operand_no = 0; operand_no < instruction->operand_count(); ++operand_no) { // Skip operands which already have a constraint. - if (constraints->OperandLayout(instruction.get(), operand_no) != - nullptr) { + if (constraints->OperandLayout(instruction, operand_no) != nullptr) { continue; } // Skip over forwarded operands. - if (constraints->OperandBufferForwarded(instruction.get(), - operand_no)) { + if (constraints->OperandBufferForwarded(instruction, operand_no)) { continue; } Shape operand_shape( row_major_shape(instruction->operand(operand_no)->shape())); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - operand_shape, instruction.get(), operand_no)); + operand_shape, instruction, operand_no)); } // Skip over the root instruction for the top-level computation. if (computation->parent()->entry_computation() == computation && - computation->root_instruction() == instruction.get()) { + computation->root_instruction() == instruction) { continue; } // Skip instructions which don't produce array shapes (tuples, opaque, diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc index b49047283119fb2f10b9f68eaa37a7bdc27f63a6..81c29e4726c7be53b433be896f558f502e43c885 100644 --- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc +++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc @@ -52,7 +52,7 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module, llvm::IRBuilder<> ir_builder(vector_tanh_body); llvm::FastMathFlags fast_math_flags; - fast_math_flags.setUnsafeAlgebra(); + fast_math_flags.setFast(); ir_builder.setFastMathFlags(fast_math_flags); llvm::Value* input = &*vector_tanh_function->arg_begin(); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc index 40fa3a67bdec3953003ba8f98f2a19a9082a82c5..aff61296ced47a911ded207f611747564b5ac7eb 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc @@ -56,16 +56,16 @@ namespace cpu { ParallelCpuExecutable::ParallelCpuExecutable( std::unique_ptr jit, - std::unique_ptr assignment, - std::unique_ptr hlo_module, - std::unique_ptr> function_names, + std::unique_ptr assignment, + std::unique_ptr hlo_module, + std::unique_ptr> function_names, std::unordered_map hlo_to_profile_idx, std::unordered_map> aligned_constants) : Executable(std::move(hlo_module)), jit_(std::move(jit)), assignment_(std::move(assignment)), - functions_names_(std::move(function_names)), + function_names_(std::move(function_names)), hlo_to_profile_idx_(std::move(hlo_to_profile_idx)), aligned_constants_(std::move(aligned_constants)) {} @@ -102,11 +102,11 @@ namespace { // in 'pending' on 'thread_pool' (storing resulting data in 'results'). class Executor { public: - Executor(const std::map& functions, + Executor(const HloInstructionMap& functions, const ServiceExecutableRunOptions* run_options, std::list* pending, - std::map* results, void** temps_array, - uint64* profile_counters_array, BufferAssignment* assignment) + HloInstructionMap* results, void** temps_array, + uint64* profile_counters_array, const BufferAssignment* assignment) : functions_(functions), run_options_(run_options), pending_(pending), @@ -142,14 +142,14 @@ class Executor { const void** GetOperandBuffers(HloInstruction* instruction); // Arguments passed into Executor. - const std::map& functions_; + const HloInstructionMap& functions_; const ServiceExecutableRunOptions* run_options_; std::list* pending_; - std::map* results_; + HloInstructionMap* results_; void** temps_array_; uint64* profile_counters_array_; tensorflow::thread::ThreadPool* thread_pool_; - BufferAssignment* assignment_; + const BufferAssignment* assignment_; // Members used to manage instruction execution. tensorflow::mutex completion_queue_lock_; @@ -377,7 +377,6 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( HloExecutionProfile* hlo_execution_profile) { std::vector argument_buffers(arguments.size()); for (int i = 0; i < arguments.size(); ++i) { - TF_RET_CHECK(!ShapeUtil::IsTuple(arguments[i]->shape())); argument_buffers[i] = arguments[i]->buffer(/*index=*/{}); } return ExecuteComputeFunctions(run_options, argument_buffers, buffers, @@ -401,8 +400,8 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( } // Resolve functions for all the HLO instructions ahead of time. - std::map functions; - for (auto& entry : *functions_names_) { + HloInstructionMap functions; + for (auto& entry : *function_names_) { tensorflow::mutex_lock lock(jit_mutex_); HloInstruction* instruction = entry.first; llvm::JITSymbol sym = jit_->FindSymbol(entry.second); @@ -413,7 +412,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( } // Map containing pointers to result buffers for each instruction. - std::map results; + HloInstructionMap results; uint64 start_micros = tensorflow::Env::Default()->NowMicros(); @@ -464,7 +463,7 @@ Status ParallelCpuExecutable::ExecuteComputeFunctions( for (auto hlo_prof_idx : hlo_to_profile_idx_) { const HloInstruction* hlo = hlo_prof_idx.first; uint64 cycles_taken = profile_counters[hlo_prof_idx.second]; - hlo_execution_profile->AddProfileResult(hlo, cycles_taken); + hlo_execution_profile->SetCyclesTakenBy(hlo, cycles_taken); } } @@ -546,10 +545,9 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( DeviceMemoryAllocator* memory_allocator = run_options->allocator(); std::vector buffers(assignment_->Allocations().size()); - TF_ASSIGN_OR_RETURN(std::unique_ptr result_buffer, - ShapedBuffer::MakeShapedBuffer( - result_shape(), stream->parent()->platform(), - stream->parent()->device_ordinal())); + auto result_buffer = + MakeUnique(result_shape(), stream->parent()->platform(), + stream->parent()->device_ordinal()); TF_RETURN_IF_ERROR(AllocateBuffers( memory_allocator, stream->parent()->device_ordinal(), &buffers)); @@ -557,15 +555,14 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( TF_RETURN_IF_ERROR(ExecuteComputeFunctions(run_options, arguments, buffers, hlo_execution_profile)); - // Copy DeviceMemoryBase values which contain the array(s) of the result into - // the respective location in ShapedBuffer which is returned to the caller. + // Copy DeviceMemoryBase values which into the respective location in + // ShapedBuffer which is returned to the caller. std::vector buffers_in_result(assignment_->Allocations().size(), false); TF_RETURN_IF_ERROR( result_buffer->mutable_shape_index_to_buffer_entry() ->ForEachMutableElementWithStatus( [&buffers, &buffers_in_result, &result_buffer, this]( const ShapeIndex& index, size_t* buffer_entry) { - if (ShapeUtil::IsLeafIndex(result_buffer->shape(), index)) { const auto& sources = this->GetRootPointsToSet().element(index); // The points to set is unambiguous so the set should be a @@ -590,7 +587,6 @@ StatusOr> ParallelCpuExecutable::ExecuteOnStream( *buffer_entry = result_buffer->mutable_buffers()->size(); result_buffer->mutable_buffers()->push_back(buffer); buffers_in_result[buffer_index] = true; - } return Status::OK(); })); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h index d9200e13ed2ae8ed8afc4e4c7475e72aed4ae3c7..db16aaf48b0ef2aaa727c1bd0106bc51d1a65095 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h @@ -49,9 +49,9 @@ class ParallelCpuExecutable : public Executable { public: ParallelCpuExecutable( std::unique_ptr jit, - std::unique_ptr assignment, - std::unique_ptr hlo_module, - std::unique_ptr> instruction_functions, + std::unique_ptr assignment, + std::unique_ptr hlo_module, + std::unique_ptr> function_names, std::unordered_map hlo_to_profile_idx, std::unordered_map> @@ -129,10 +129,10 @@ class ParallelCpuExecutable : public Executable { // The JIT containing compiled modules. tensorflow::mutex jit_mutex_; - std::unique_ptr jit_ GUARDED_BY(jit_mutex_); + const std::unique_ptr jit_ GUARDED_BY(jit_mutex_); // Buffer assignment for the buffers we need to allocate. - std::unique_ptr assignment_; + const std::unique_ptr assignment_; // The LLVM IR, in string format, of the unoptimized module generated for this // ParallelCpuExecutable. We save a string instead of an llvm::Module* because @@ -141,7 +141,7 @@ class ParallelCpuExecutable : public Executable { string ir_module_string_; // Map containing the JITted function names for each HLO instruction. - std::unique_ptr> functions_names_; + const std::unique_ptr> function_names_; // Maps HLOs to their index into the profile counter array. const std::unordered_map hlo_to_profile_idx_; @@ -149,7 +149,8 @@ class ParallelCpuExecutable : public Executable { // Map from HLO Constant instructions to a pointer to their literal data. // The data stored in the protocol buffer might be insufficiently aligned, // we create a sufficiently aligned copy and store it in this map. - std::unordered_map> + const std::unordered_map> aligned_constants_; TF_DISALLOW_COPY_AND_ASSIGN(ParallelCpuExecutable); diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc new file mode 100644 index 0000000000000000000000000000000000000000..c2213c8f2ef592c537daf9abe2ffa10b83a8fa4c --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -0,0 +1,251 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h" + +#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" +#include "tensorflow/compiler/xla/service/cpu/shape_partition.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" + +namespace xla { +namespace cpu { + +class SimpleCostModel : public ParallelCostModel { + public: + SimpleCostModel(const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size) + : max_parallelism_(max_parallelism), shape_size_(shape_size) {} + ~SimpleCostModel() override {} + + int64 GetParallelTaskCount(HloInstruction* instruction) override { + // Simple cost model based on hlo size and typical L2 cache size. + const int64 instruction_cost = shape_size_(instruction->shape()); + const int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. + // Return target parallel task count in [1, max_parallelism_]. + return std::min(max_parallelism_, + std::max(1LL, instruction_cost / min_cost_per_thread)); + } + + private: + const int64 max_parallelism_; + const HloCostAnalysis::ShapeSizeFunction shape_size_; +}; + +class DefaultCostModel : public ParallelCostModel { + public: + DefaultCostModel(const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size, + std::unique_ptr cost_analysis) + : max_parallelism_(max_parallelism), + shape_size_(shape_size), + cost_analysis_(std::move(cost_analysis)) {} + ~DefaultCostModel() override {} + + int64 GetParallelTaskCount(HloInstruction* instruction) override { + // Parameters for parallel task count computation. + int64 instruction_cost; + int64 min_cost_per_thread; + int64 max_parallelism; + // Calculate flops-to-bytes-ratio for 'instruction'. + const int64 bytes_accessed = + std::max(1LL, cost_analysis_->bytes_accessed(*instruction)); + const float flops_to_bytes_ratio = + cost_analysis_->flop_count(*instruction) / + static_cast(bytes_accessed); + // Check for I/O bound instructions. + if (flops_to_bytes_ratio <= 1.0) { + // Limit max parallelism for I/O bound instructions by assuming a + // sub-linear scaling function (fit based on empirical benchmark results). + // TODO(29630486) Develop system bandwidth model. + max_parallelism = + std::ceil(std::sqrt(tensorflow::port::NumSchedulableCPUs())); + // Use shape size instruction cost and L2 cache size min per-thread cost. + instruction_cost = shape_size_(instruction->shape()); + min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size. + } else { + // Use max parallelism for compute bound instructions. + max_parallelism = max_parallelism_; + // Calculate the instruction cost in cycles. + // TODO(29630486) Improve on this linear cost model. + // Consider making 'min_cost_per_thread' be a function of the target + // bandwidth limit for instructions with low arithmetic complexity. + instruction_cost = + 1 * cost_analysis_->flop_count(*instruction) + + 2 * cost_analysis_->transcendental_count(*instruction) + + 10 * cost_analysis_->bytes_accessed(*instruction); + // Minimum per-thread cost is 100us of work on a 2GHz core. + min_cost_per_thread = 100000; + } + // Return target parallel task count in [1, max_parallelism_]. + return std::min(max_parallelism, + std::max(1LL, instruction_cost / min_cost_per_thread)); + } + + private: + const int64 max_parallelism_; + const HloCostAnalysis::ShapeSizeFunction shape_size_; + const std::unique_ptr cost_analysis_; +}; + + +ParallelTaskAssignment::ParallelTaskAssignment( + const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size, + HloModule* module) { + VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism; + // Run cost analysis on 'module'. + auto cost_analysis = MakeUnique(shape_size); + HloComputation* computation = module->entry_computation(); + Status status = computation->root_instruction()->Accept(cost_analysis.get()); + if (status.ok()) { + // Set default cost model based on 'cost_analysis'. + cost_model_.reset(new DefaultCostModel(max_parallelism, shape_size, + std::move(cost_analysis))); + } else { + // Fall back to a simple cost model based on hlo size and L2 cache size. + // Note that HloCostAnalysis can returns an error status (likely because + // HLOs like CustomCall are not yet implemented in the HloCostAnalysis). + cost_model_.reset(new SimpleCostModel(max_parallelism, shape_size)); + } +} + +int64 ParallelTaskAssignment::GetTargetParallelTaskCount( + HloInstruction* instruction) { + // Currently, we do not assign parallel tasks to instructions with at least + // one of the following properties: + // *) Internal threading (library calls to kConv, kDot, and kCustomCall). + // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot). + // *) Tuple-shaped. + // TODO(b/27458679) Parallelize instructions which are skipped here. + if (instruction->opcode() == HloOpcode::kParameter || + instruction->opcode() == HloOpcode::kConstant || + instruction->opcode() == HloOpcode::kCall || + instruction->opcode() == HloOpcode::kCustomCall || + instruction->opcode() == HloOpcode::kSelectAndScatter || + instruction->opcode() == HloOpcode::kGetTupleElement || + instruction->opcode() == HloOpcode::kBitcast || + (instruction->opcode() == HloOpcode::kConvolution && + PotentiallyImplementedAsEigenConvolution(*instruction)) || + PotentiallyImplementedAsEigenDot(*instruction) || + (instruction->opcode() == HloOpcode::kFusion && + instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) || + ShapeUtil::IsTuple(instruction->shape())) { + return 1; + } + // Consult 'cost_model_' to compute target parallel task count. + return cost_model_->GetParallelTaskCount(instruction); +} + +StatusOr ParallelTaskAssigner::Run(HloModule* module) { + XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY"); + XLA_VLOG_LINES(3, module->ToString()); + + // Compute target parallel task counts for all instructions in 'module'. + HloToParallelTasks hlo_to_parallel_tasks; + ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks); + + // Assign parallel tasks to target specific instructions in 'module'. + // TODO(b/27458679) Support inter-op parallelism. + bool changed = AssignParallelTasks(module, hlo_to_parallel_tasks); + + XLA_VLOG_LINES(2, "ParallelTaskAssigner EXIT"); + XLA_VLOG_LINES(3, module->ToString()); + return changed; +} + +bool ParallelTaskAssigner::AssignParallelTasks( + HloModule* module, const HloToParallelTasks& hlo_to_parallel_tasks) { + return AssignParallelTasksHelper(module, module->entry_computation(), + hlo_to_parallel_tasks); +} + +bool ParallelTaskAssigner::AssignParallelTasksHelper( + HloModule* module, HloComputation* computation, + const HloToParallelTasks& hlo_to_parallel_tasks) { + bool changed = false; + // Snapshot set of instructions because outlining modifies the set below. + std::vector instructions(computation->instructions().begin(), + computation->instructions().end()); + for (auto* instruction : instructions) { + // Assign parallel tasks to sub-computations for While and Call HLOs. + // TODO(b/27458679) Evaluate alternative intra-op parallelsim placement, + // and support other callable computations like reduce. + if (instruction->opcode() == HloOpcode::kWhile) { + changed |= AssignParallelTasksHelper(module, instruction->while_body(), + hlo_to_parallel_tasks); + continue; + } else if (instruction->opcode() == HloOpcode::kCall) { + changed |= AssignParallelTasksHelper(module, instruction->to_apply(), + hlo_to_parallel_tasks); + continue; + } + // Skip if no parallel tasks were computed in first pass. + auto it = hlo_to_parallel_tasks.find(instruction); + if (it == hlo_to_parallel_tasks.end()) { + continue; + } + // Get target parallel task count computed for 'instruction'. + const int64 target_parallel_task_count = (*it).second; + // Assign feasible dimension partitions (based on actual dimension sizes). + auto dim_partition_counts = ShapePartitionAssigner(instruction->shape()) + .Run(target_parallel_task_count); + const int64 total_partition_count = + ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts); + if (total_partition_count <= 1) { + // Feasible partition calculation resulting in no partitioning, so skip. + continue; + } + + // Outline 'instruction' in 'computation' for parallel task assignment. + auto* call = module->OutlineExpressionFromComputation( + {instruction}, + tensorflow::strings::StrCat("parallel_", instruction->name()), + computation); + + // Set assigned dimension partitioning to 'instruction'. + auto* new_root = call->to_apply()->root_instruction(); + new_root->set_outer_dimension_partitions(dim_partition_counts); + + VLOG(2) << "Assigned parallel task count: " << total_partition_count + << " to instruction: " << new_root->name() + << " parent: " << new_root->parent()->name(); + changed = true; + } + return changed; +} + +void ParallelTaskAssigner::ComputeTargetParallelTasks( + HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) { + // Compute parallel task counts for all instructions in 'module'. + for (auto* computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } + for (auto* instruction : computation->instructions()) { + // Query ParallelTaskAssignment for target parallel task count. + const int64 target_parallel_task_count = + parallel_task_assignment_.GetTargetParallelTaskCount(instruction); + if (target_parallel_task_count > 1) { + hlo_to_parallel_tasks->insert( + {instruction, target_parallel_task_count}); + } + } + } +} + +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h new file mode 100644 index 0000000000000000000000000000000000000000..e036da5784f6151eb3b01107ec7f3ab820071a60 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -0,0 +1,104 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ + +#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { +namespace cpu { + +// Simple interface for different parallel cost model implementations. +class ParallelCostModel { + public: + virtual ~ParallelCostModel() = default; + virtual int64 GetParallelTaskCount(HloInstruction* instruction) = 0; +}; + +// ParallelTaskAssignment computes parallel task counts for HLOs in 'module'. +class ParallelTaskAssignment { + public: + // 'max_parallelism': the maximum parallel task count per instruction. + // 'shape_size': shape size function used by HloCostAnalysis during parallel + // task assignment. + // 'module': the containing HloModule. + ParallelTaskAssignment( + const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size, + HloModule* module); + ~ParallelTaskAssignment() {} + + // Computes and returns the target parallel task count for 'instruction'. + int64 GetTargetParallelTaskCount(HloInstruction* instruction); + + private: + std::unique_ptr cost_model_; +}; + +// ParallelTaskAssigner computes target parallel task counts for all HLOs +// in the module, then assigns parallel task counts to HLOs in the entry +// computation, or to HLOs in embedded computations invoked by (potentially +// nested) kWhile or kCall instructions. +// Each HLO which is assigned parallel task counts is outlined into its +// own embedded computation, which is compiled as a parallel compute function, +// and which is invoked from a kCall instruction that is lowered in codegen to +// a runtime parallel fork/join call. +class ParallelTaskAssigner : public HloPassInterface { + public: + // 'max_parallelism': the maximum parallel task count per instruction. + // 'shape_size': shape size function used by HloCostAnalysis during parallel + // task assignment. + // 'module': the containing HloModule. + ParallelTaskAssigner(const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size, + HloModule* module) + : parallel_task_assignment_(max_parallelism, shape_size, module) {} + ~ParallelTaskAssigner() override {} + + tensorflow::StringPiece name() const override { + return "cpu-parallel-task-assigner"; + } + + // Run parallel task assigner on 'module'. + // Returns true if the computation was changed, false otherwise. + StatusOr Run(HloModule* module) override; + + private: + using HloToParallelTasks = std::unordered_map; + + // Assigns target parallel tasks from 'hlo_to_parallel_tasks' to HLOs in + // 'module'. + // Returns true if the computation was changed, false otherwise. + bool AssignParallelTasks(HloModule* module, + const HloToParallelTasks& hlo_to_parallel_tasks); + bool AssignParallelTasksHelper( + HloModule* module, HloComputation* computation, + const HloToParallelTasks& hlo_to_parallel_tasks); + + // Computes target parallel task counts (returned in 'parallel_task_counts') + // for parallelizable instructions in 'module'. + void ComputeTargetParallelTasks(HloModule* module, + HloToParallelTasks* hlo_to_parallel_tasks); + + ParallelTaskAssignment parallel_task_assignment_; +}; + +} // namespace cpu +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_ diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc new file mode 100644 index 0000000000000000000000000000000000000000..d03da46575b331de113cc5f33c2b4267504e8308 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.cc @@ -0,0 +1,97 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/executable_run_options.h" +#include "tensorflow/core/lib/core/blocking_counter.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +using tensorflow::int32; +using tensorflow::int64; +using tensorflow::uint64; + +using ComputeFunctionType = void (*)(void*, const void*, const void**, void**, + int64*, uint64*); + +// Dispatches 'num_partitions - 1' calls to 'function_ptr' in parallel. +// Calls 'function_ptr' for first partition inline. +// Uses blocking counter to synchonize threads after parallel calls complete. +// +// The 'partitions' array has a total number of elements equal to +// 'num_partitions * num_partitioned_dims * 2' (the '2' is necessary to specify +// dimension start and limit indices). +// +// The 'partitions' array layout stores array elements in memory with dimension +// start limit as the most-minor dimension, followed by dimension, then +// partition. +// +// EX: Layout of 'partitions' array with 'num_partitions = 2', and +// 'num_partitioned_dims = 3' +// +// [partition0_dim0_start] +// [partition0_dim0_limit] +// [partition0_dim1_start] +// [partition0_dim1_limit] +// [partition0_dim2_start] +// [partition0_dim2_limit] +// [partition1_dim0_start] +// [partition1_dim0_limit] +// [partition1_dim1_start] +// [partition1_dim1_limit] +// [partition1_dim2_start] +// [partition1_dim2_limit] +// +void __xla_cpu_runtime_ParallelForkJoin( + void* result_ptr, const void* run_options_ptr, const void** params, + void** temps, uint64* prof_counters, int32 num_partitions, + int64* partitions, int32 num_partitioned_dims, void* function_ptr) { + VLOG(2) << "ParallelForkJoin ENTRY" + << " num_partitions: " << num_partitions + << " num_partitioned_dims: " << num_partitioned_dims; + CHECK_GT(num_partitions, 1); + CHECK_GT(num_partitioned_dims, 0); + const xla::ExecutableRunOptions* run_options = + static_cast(run_options_ptr); + ComputeFunctionType function = + reinterpret_cast(function_ptr); + // Compute partition stride in 'partitions' array. + const int64 stride = 2 * num_partitioned_dims; + + // Dispatch 'num_partitions - 1' compute functions to run in parallel. + tensorflow::BlockingCounter bc(num_partitions - 1); + for (int32 i = 1; i < num_partitions; ++i) { + const int64 offset = i * stride; + run_options->intra_op_thread_pool()->enqueueNoNotification( + [i, function, result_ptr, run_options_ptr, params, temps, prof_counters, + partitions, offset, &bc]() { + function(result_ptr, run_options_ptr, params, temps, + &partitions[offset], prof_counters); + bc.DecrementCount(); + VLOG(3) << "ParallelForkJoin partition " << i << " done."; + }); + } + + // Call first compute function inline. + function(result_ptr, run_options_ptr, params, temps, &partitions[0], + prof_counters); + VLOG(3) << "ParallelForkJoin partition 0 done."; + bc.Wait(); + VLOG(2) << "ParallelForkJoin EXIT"; +} diff --git a/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h new file mode 100644 index 0000000000000000000000000000000000000000..fcf1cc62078d3847435a2e75e3ca9d109cf8b200 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/runtime_fork_join.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ + +#include "tensorflow/core/platform/types.h" + +extern "C" { + +// Dispatches 'num_partitions' parallel calls to 'function_ptr' and joins +// threads before returning. See comments in runtime_fork_join.cc for details. +extern void __xla_cpu_runtime_ParallelForkJoin( + void* result_ptr, const void* run_options_ptr, const void** params, + void** temps, tensorflow::uint64* prof_counters, + tensorflow::int32 num_partitions, tensorflow::int64* partitions, + tensorflow::int32 num_partitioned_dims, void* function_ptr); + +} // extern "C" + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_RUNTIME_FORK_JOIN_H_ diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc index c3c11df090e88c3c24104b66d28b3b16f03baa80..fdf02e5b422f75e256feec77470bb0d079e8ef1f 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.cc @@ -31,7 +31,9 @@ limitations under the License. #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_avx.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_neon.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime_sse4_1.h" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/cpu/runtime_conv2d.h" +#include "tensorflow/compiler/xla/service/cpu/runtime_fork_join.h" #include "tensorflow/compiler/xla/service/cpu/runtime_matmul.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_conv2d.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" @@ -42,90 +44,21 @@ namespace xla { namespace cpu { namespace { -// Converts a symbol 'name' into the form expected by dlsym(). -std::string CanonicalizeSymbol(const std::string& name) { -#if defined(__APPLE__) - // On Mac OS X, dlsym() expects names not to be prefixed with a leading - // underscore. - if (!name.empty() && name.front() == '_') { - return name.substr(1); - } -#endif - return name; -} - -class JITSymbolTable { +// A simple SymbolResolver that delegates to the host dynamic linker. +class SimpleResolver : public llvm::JITSymbolResolver { public: - JITSymbolTable() { Populate(); } - - void* Lookup(llvm::StringRef jit_symbol_name) const { - auto it = jit_symbol_table_.find(jit_symbol_name); - return it == jit_symbol_table_.end() ? nullptr : it->getValue(); - } - - static bool MustBeInTable(llvm::StringRef name) { - // In particular, names starting with - // runtime::kXlaCpuRuntimeSymbolNamePrefix should not be dlsym'ed. - return name.startswith(runtime::kXlaCpuRuntimeSymbolNamePrefix); - } - - private: - void AddJITSymbolToTable(llvm::StringRef jit_symbol_name, - llvm::StringRef cpp_symbol_name, - void* jit_symbol_value) { - // The JIT symbol name and the C++ symbol name (with an extern "C" linkage) - // need to match, otherwise AOT links will fail. - CHECK(jit_symbol_name == cpp_symbol_name); - CHECK(jit_symbol_table_.insert({jit_symbol_name, jit_symbol_value}).second); - } - - void Populate() { -#define ADD_JIT_SYMBOL_TO_TABLE(base_name) \ - do { \ - AddJITSymbolToTable( \ - xla::cpu::runtime::k##base_name##SymbolName, \ - "__xla_cpu_runtime_" #base_name, \ - reinterpret_cast(__xla_cpu_runtime_##base_name)); \ - } while (false) - - ADD_JIT_SYMBOL_TO_TABLE(AcquireInfeedBufferForDequeue); - ADD_JIT_SYMBOL_TO_TABLE(ReleaseInfeedBufferAfterDequeue); - ADD_JIT_SYMBOL_TO_TABLE(AcquireOutfeedBufferForPopulation); - ADD_JIT_SYMBOL_TO_TABLE(ReleaseOutfeedBufferAfterPopulation); - ADD_JIT_SYMBOL_TO_TABLE(ExpV8F32AVX); - ADD_JIT_SYMBOL_TO_TABLE(LogV8F32AVX); - ADD_JIT_SYMBOL_TO_TABLE(ExpV4F32SSE); - ADD_JIT_SYMBOL_TO_TABLE(LogV4F32SSE); - ADD_JIT_SYMBOL_TO_TABLE(ExpV4F32NEON); - ADD_JIT_SYMBOL_TO_TABLE(LogV4F32NEON); - ADD_JIT_SYMBOL_TO_TABLE(EigenConvF32); - ADD_JIT_SYMBOL_TO_TABLE(EigenMatMulF32); - ADD_JIT_SYMBOL_TO_TABLE(EigenMatMulF64); - ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedConvF32); - ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedMatMulF32); - ADD_JIT_SYMBOL_TO_TABLE(EigenSingleThreadedMatMulF64); - -#undef ADD_JIT_SYMBOL_TO_TABLE - } - - llvm::StringMap jit_symbol_table_; -}; - -const JITSymbolTable& GetJITSymbolTable() { - static JITSymbolTable* symbol_table = new JITSymbolTable; - return *symbol_table; -} + explicit SimpleResolver(ExternalConstantPool* external_constant_pool) + : external_constant_pool_(external_constant_pool) {} -// A simple SymbolResolver that delegates to the host dynamic linker. -struct SimpleResolver : public llvm::JITSymbolResolver { llvm::JITSymbol findSymbol(const std::string& name) override { - std::string canonical_name = CanonicalizeSymbol(name); - const JITSymbolTable& jit_symbol_table = GetJITSymbolTable(); - - void* func_addr = JITSymbolTable::MustBeInTable(canonical_name) - ? jit_symbol_table.Lookup(canonical_name) - : dlsym(RTLD_DEFAULT, canonical_name.c_str()); + if (const uint8* from_constant_pool = + external_constant_pool_->Find(string(name))) { + return llvm::JITEvaluatedSymbol( + reinterpret_cast(from_constant_pool), + llvm::JITSymbolFlags::None); + } + void* func_addr = CustomCallTargetRegistry::Global()->Lookup(name); if (func_addr == nullptr) { return nullptr; } @@ -136,6 +69,9 @@ struct SimpleResolver : public llvm::JITSymbolResolver { llvm::JITSymbol findSymbolInLogicalDylib(const std::string& name) override { return nullptr; } + + private: + ExternalConstantPool* external_constant_pool_; }; llvm::SmallVector DetectMachineAttributes() { @@ -205,7 +141,7 @@ SimpleOrcJIT::SimpleOrcJIT(const llvm::TargetOptions& target_options, SimpleOrcJIT::ModuleHandleT SimpleOrcJIT::AddModule( std::unique_ptr module) { auto handle = cantFail(compile_layer_.addModule( - std::move(module), MakeUnique())); + std::move(module), MakeUnique(external_constant_pool()))); module_handles_.push_back(handle); return handle; } @@ -238,5 +174,118 @@ llvm::JITSymbol SimpleOrcJIT::FindSymbol(const std::string& name) { return nullptr; } +namespace { +// Register some known symbols with the CustomCallTargetRegistry. +bool RegisterKnownJITSymbols() { + CustomCallTargetRegistry* registry = CustomCallTargetRegistry::Global(); + +#define REGISTER_CPU_RUNTIME_SYMBOL(base_name) \ + do { \ + auto* function_address = \ + reinterpret_cast(__xla_cpu_runtime_##base_name); \ + registry->Register(xla::cpu::runtime::k##base_name##SymbolName, \ + function_address); \ + CHECK_EQ( \ + tensorflow::StringPiece(xla::cpu::runtime::k##base_name##SymbolName), \ + "__xla_cpu_runtime_" #base_name); \ + } while (false) + + REGISTER_CPU_RUNTIME_SYMBOL(AcquireInfeedBufferForDequeue); + REGISTER_CPU_RUNTIME_SYMBOL(AcquireOutfeedBufferForPopulation); + REGISTER_CPU_RUNTIME_SYMBOL(EigenConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenMatMulF64); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedConvF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF32); + REGISTER_CPU_RUNTIME_SYMBOL(EigenSingleThreadedMatMulF64); + REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32NEON); + REGISTER_CPU_RUNTIME_SYMBOL(ExpV4F32SSE); + REGISTER_CPU_RUNTIME_SYMBOL(ExpV8F32AVX); + REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32NEON); + REGISTER_CPU_RUNTIME_SYMBOL(LogV4F32SSE); + REGISTER_CPU_RUNTIME_SYMBOL(LogV8F32AVX); + REGISTER_CPU_RUNTIME_SYMBOL(ParallelForkJoin); + REGISTER_CPU_RUNTIME_SYMBOL(ReleaseInfeedBufferAfterDequeue); + REGISTER_CPU_RUNTIME_SYMBOL(ReleaseOutfeedBufferAfterPopulation); + +#undef REGISTER_CPU_RUNTIME_SYMBOL + +#define REGISTER_LIBM_SYMBOL(name) \ + do { \ + /* Register both the F32 and F64 variants of the libm symbol. */ \ + registry->Register(#name "f", reinterpret_cast(name##f)); \ + registry->Register(#name, reinterpret_cast(name)); \ + } while (false) + + REGISTER_LIBM_SYMBOL(acos); + REGISTER_LIBM_SYMBOL(acosh); + REGISTER_LIBM_SYMBOL(asin); + REGISTER_LIBM_SYMBOL(asinh); + REGISTER_LIBM_SYMBOL(atan); + REGISTER_LIBM_SYMBOL(atan2); + REGISTER_LIBM_SYMBOL(atanh); + REGISTER_LIBM_SYMBOL(cbrt); + REGISTER_LIBM_SYMBOL(ceil); + REGISTER_LIBM_SYMBOL(copysign); + REGISTER_LIBM_SYMBOL(cos); + REGISTER_LIBM_SYMBOL(cosh); + REGISTER_LIBM_SYMBOL(erf); + REGISTER_LIBM_SYMBOL(erfc); + REGISTER_LIBM_SYMBOL(exp); + REGISTER_LIBM_SYMBOL(exp2); + REGISTER_LIBM_SYMBOL(expm1); + REGISTER_LIBM_SYMBOL(fabs); + REGISTER_LIBM_SYMBOL(fdim); + REGISTER_LIBM_SYMBOL(floor); + REGISTER_LIBM_SYMBOL(fma); + REGISTER_LIBM_SYMBOL(fmax); + REGISTER_LIBM_SYMBOL(fmin); + REGISTER_LIBM_SYMBOL(fmod); + REGISTER_LIBM_SYMBOL(frexp); + REGISTER_LIBM_SYMBOL(hypot); + REGISTER_LIBM_SYMBOL(ilogb); + REGISTER_LIBM_SYMBOL(ldexp); + REGISTER_LIBM_SYMBOL(lgamma); + REGISTER_LIBM_SYMBOL(llrint); + REGISTER_LIBM_SYMBOL(llround); + REGISTER_LIBM_SYMBOL(log); + REGISTER_LIBM_SYMBOL(log10); + REGISTER_LIBM_SYMBOL(log1p); + REGISTER_LIBM_SYMBOL(log2); + REGISTER_LIBM_SYMBOL(logb); + REGISTER_LIBM_SYMBOL(lrint); + REGISTER_LIBM_SYMBOL(lround); + REGISTER_LIBM_SYMBOL(modf); + REGISTER_LIBM_SYMBOL(nan); + REGISTER_LIBM_SYMBOL(nearbyint); + REGISTER_LIBM_SYMBOL(nextafter); + REGISTER_LIBM_SYMBOL(nexttoward); + REGISTER_LIBM_SYMBOL(pow); + REGISTER_LIBM_SYMBOL(remainder); + REGISTER_LIBM_SYMBOL(remquo); + REGISTER_LIBM_SYMBOL(rint); + REGISTER_LIBM_SYMBOL(round); + REGISTER_LIBM_SYMBOL(scalbln); + REGISTER_LIBM_SYMBOL(scalbn); + REGISTER_LIBM_SYMBOL(sin); + REGISTER_LIBM_SYMBOL(sincos); + REGISTER_LIBM_SYMBOL(sinh); + REGISTER_LIBM_SYMBOL(sqrt); + REGISTER_LIBM_SYMBOL(tan); + REGISTER_LIBM_SYMBOL(tanh); + REGISTER_LIBM_SYMBOL(tgamma); + REGISTER_LIBM_SYMBOL(trunc); + +#undef REGISTER_LIBM_SYMBOL + + registry->Register("memcpy", reinterpret_cast(memcpy)); + registry->Register("memmove", reinterpret_cast(memmove)); + registry->Register("memset", reinterpret_cast(memset)); + return true; +} + +bool unused = RegisterKnownJITSymbols(); +} // namespace + } // namespace cpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h index e476c0e3812cc0fb2a2d633832374b3165ca072a..ded01e9e4d7442296f7406dd035e6ab385458238 100644 --- a/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h +++ b/tensorflow/compiler/xla/service/cpu/simple_orc_jit.h @@ -27,6 +27,7 @@ limitations under the License. #include "llvm/Target/TargetMachine.h" #include "tensorflow/compiler/xla/service/cpu/compiler_functor.h" #include "tensorflow/compiler/xla/service/cpu/disassembler.h" +#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -90,6 +91,10 @@ class SimpleOrcJIT { llvm::TargetMachine* target_machine() const { return target_machine_.get(); } + ExternalConstantPool* external_constant_pool() { + return &external_constant_pool_; + } + private: std::vector module_handles_; std::unique_ptr target_machine_; @@ -97,6 +102,7 @@ class SimpleOrcJIT { const llvm::DataLayout data_layout_; ObjLayerT object_layer_; CompileLayerT compile_layer_; + ExternalConstantPool external_constant_pool_; }; } // namespace cpu diff --git a/tensorflow/compiler/xla/service/defuser.cc b/tensorflow/compiler/xla/service/defuser.cc new file mode 100644 index 0000000000000000000000000000000000000000..d124f74d19d83269be96ee34a6b4b2a8d00a978f --- /dev/null +++ b/tensorflow/compiler/xla/service/defuser.cc @@ -0,0 +1,115 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/defuser.h" + +#include +#include +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/service/call_graph.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +namespace { + +// Copy all the instructions in the given fusion instruction into the fusion +// instruction's parent computation and replace the use of the fusion +// instruction with the copy of the fusion expression root. +Status Defuse(HloInstruction* fusion_instruction) { + VLOG(2) << "Defusing instruction: " << fusion_instruction->ToString(); + + HloComputation* fused_computation = + fusion_instruction->fused_instructions_computation(); + + // A map from fused instruction to its defused clone. + tensorflow::gtl::FlatMap + defused_instructions; + // Initialize map to contain the fusion instruction parameters mapping + // to the operands of the fusion instruction. + for (int64 i = 0; i < fusion_instruction->operand_count(); ++i) { + defused_instructions[fused_computation->parameter_instruction(i)] = + fusion_instruction->mutable_operand(i); + } + + // Create a clone of each instruction of the fused computation in the same + // computation as the fusion instruction itself. + // TODO(b/68227302): Moving instruction to new computation rather than + // cloning and deleting. + for (HloInstruction* fused_instruction : + fused_computation->MakeInstructionPostOrder()) { + if (fused_instruction->opcode() == HloOpcode::kParameter) { + continue; + } + std::vector new_operands; + for (HloInstruction* operand : fused_instruction->operands()) { + new_operands.push_back(defused_instructions.at(operand)); + } + HloInstruction* defused_instruction = + fusion_instruction->parent()->AddInstruction( + fused_instruction->CloneWithNewOperands(fused_instruction->shape(), + new_operands)); + defused_instructions[fused_instruction] = defused_instruction; + } + + TF_RETURN_IF_ERROR(fusion_instruction->ReplaceAllUsesWith( + defused_instructions.at(fusion_instruction->fused_expression_root()))); + + HloModule* module = fusion_instruction->parent()->parent(); + TF_RETURN_IF_ERROR( + fusion_instruction->parent()->RemoveInstruction(fusion_instruction)); + return module->RemoveEmbeddedComputation(fused_computation); +} + +} // namespace + +StatusOr Defuser::Run(HloModule* module) { + VLOG(1) << "Defusing module " << module->name(); + XLA_VLOG_LINES(2, "Before defusion:\n" + module->ToString()); + + bool changed = false; + std::unique_ptr call_graph = CallGraph::Build(module); + TF_RETURN_IF_ERROR(call_graph->VisitNodes( + [&](const CallGraphNode& call_graph_node) -> Status { + if (call_graph_node.computation()->IsFusionComputation()) { + TF_RET_CHECK(call_graph_node.caller_callsites().size() == 1); + HloInstruction* fusion_instruction = + call_graph_node.caller_callsites()[0].instruction(); + TF_RETURN_IF_ERROR(Defuse(fusion_instruction)); + changed = true; + } + return Status::OK(); + }, + /*visit_unreachable_nodes=*/true)); + + XLA_VLOG_LINES(2, "After defusion:\n" + module->ToString()); + + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/defuser.h b/tensorflow/compiler/xla/service/defuser.h new file mode 100644 index 0000000000000000000000000000000000000000..56b28fd22da1ea6bc19f98e76f0f2ef4044cd3af --- /dev/null +++ b/tensorflow/compiler/xla/service/defuser.h @@ -0,0 +1,41 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DEFUSER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_DEFUSER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass which replaces all fusion instructions with the equivalent un-fused +// instructions. +class Defuser : public HloPassInterface { + public: + Defuser() {} + ~Defuser() override {} + tensorflow::StringPiece name() const override { return "defuser"; } + + // Run defusion on the given module. Returns whether the module was + // changed. + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_DEFUSER_H_ diff --git a/tensorflow/compiler/xla/service/defuser_test.cc b/tensorflow/compiler/xla/service/defuser_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..32b5c5d35fae61ae6cb17fafcada1abd6c3c088c --- /dev/null +++ b/tensorflow/compiler/xla/service/defuser_test.cc @@ -0,0 +1,214 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/defuser.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class DefuserTest : public HloVerifiedTestBase { + protected: + // Returns the number of fusion instructions in the module. + int FusionCount() { + int count = 0; + for (HloComputation* computation : module().computations()) { + if (computation->IsFusionComputation()) { + count++; + } + } + return count; + } + + Defuser defuser_; + const Shape shape_ = ShapeUtil::MakeShape(F32, {2, 2}); +}; + +TEST_F(DefuserTest, NoFusionInstruction) { + auto builder = HloComputation::Builder(TestName()); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); + builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); + + module().AddEntryComputation(builder.Build()); + EXPECT_EQ(0, FusionCount()); + + EXPECT_FALSE(defuser_.Run(&module()).ValueOrDie()); +} + +TEST_F(DefuserTest, TrivialFusionInstructionAsRoot) { + auto builder = HloComputation::Builder(TestName()); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); + + auto computation = module().AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({add}, + HloInstruction::FusionKind::kLoop); + + EXPECT_THAT(computation->root_instruction(), op::Fusion()); + + EXPECT_EQ(1, FusionCount()); + EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); + EXPECT_EQ(0, FusionCount()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Parameter(), op::Parameter())); +} + +TEST_F(DefuserTest, TrivialFusionInstructionNotAsRoot) { + auto builder = HloComputation::Builder(TestName()); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); + builder.AddInstruction( + HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); + + auto computation = module().AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({add}, + HloInstruction::FusionKind::kLoop); + + EXPECT_THAT(computation->root_instruction(), op::Negate(op::Fusion())); + + EXPECT_EQ(1, FusionCount()); + EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); + EXPECT_EQ(0, FusionCount()); + + EXPECT_THAT(computation->root_instruction(), + op::Negate(op::Add(op::Parameter(), op::Parameter()))); +} + +TEST_F(DefuserTest, NonTrivialFusionInstruction) { + auto builder = HloComputation::Builder(TestName()); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); + auto param3 = + builder.AddInstruction(HloInstruction::CreateParameter(2, shape_, "p2")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); + auto sub = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kSubtract, add, negate)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kMultiply, sub, param3)); + auto div = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3)); + auto constant = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + auto add2 = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); + + auto computation = module().AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction( + {add2, constant, div, mul, sub, negate, add}, + HloInstruction::FusionKind::kLoop); + + EXPECT_THAT(computation->root_instruction(), op::Fusion()); + + EXPECT_EQ(1, FusionCount()); + EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); + EXPECT_EQ(0, FusionCount()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Constant(), op::Divide())); +} + +TEST_F(DefuserTest, MultipleFusionInstructions) { + auto builder = HloComputation::Builder(TestName()); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); + auto param3 = + builder.AddInstruction(HloInstruction::CreateParameter(2, shape_, "p2")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); + auto sub = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kSubtract, add, negate)); + auto mul = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kMultiply, sub, param3)); + auto div = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kDivide, mul, param3)); + auto constant = builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}))); + auto add2 = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, constant, div)); + + auto computation = module().AddEntryComputation(builder.Build()); + computation->CreateFusionInstruction({add2, constant, div, mul}, + HloInstruction::FusionKind::kLoop); + computation->CreateFusionInstruction({sub, negate, add}, + HloInstruction::FusionKind::kLoop); + + EXPECT_THAT(computation->root_instruction(), op::Fusion()); + + EXPECT_EQ(2, FusionCount()); + EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); + EXPECT_EQ(0, FusionCount()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Constant(), op::Divide())); +} + +TEST_F(DefuserTest, NestedFusionInstructions) { + auto builder = HloComputation::Builder(TestName()); + auto param0 = + builder.AddInstruction(HloInstruction::CreateParameter(0, shape_, "p0")); + auto param1 = + builder.AddInstruction(HloInstruction::CreateParameter(1, shape_, "p1")); + auto add = builder.AddInstruction( + HloInstruction::CreateBinary(shape_, HloOpcode::kAdd, param0, param1)); + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(shape_, HloOpcode::kNegate, add)); + + auto computation = module().AddEntryComputation(builder.Build()); + auto outer_fusion = computation->CreateFusionInstruction( + {negate, add}, HloInstruction::FusionKind::kLoop); + HloInstruction* fused_negate = outer_fusion->fused_expression_root(); + ASSERT_EQ(fused_negate->opcode(), HloOpcode::kNegate); + outer_fusion->fused_instructions_computation()->CreateFusionInstruction( + {fused_negate}, HloInstruction::FusionKind::kLoop); + + EXPECT_THAT(computation->root_instruction(), op::Fusion()); + + EXPECT_EQ(2, FusionCount()); + EXPECT_TRUE(defuser_.Run(&module()).ValueOrDie()); + EXPECT_EQ(0, FusionCount()); + + EXPECT_THAT(computation->root_instruction(), op::Negate(op::Add())); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/device_memory_allocator.h b/tensorflow/compiler/xla/service/device_memory_allocator.h index 391585a306d209488f335fff0dff02ff982365af..00caefab667cba6abfef200050ca18f229fc0320 100644 --- a/tensorflow/compiler/xla/service/device_memory_allocator.h +++ b/tensorflow/compiler/xla/service/device_memory_allocator.h @@ -33,7 +33,7 @@ class DeviceMemoryAllocator { public: // Parameter platform indicates which platform the allocator allocates memory // on. Must be non-null. - explicit DeviceMemoryAllocator(const perftools::gputools::Platform* platform) + explicit DeviceMemoryAllocator(perftools::gputools::Platform* platform) : platform_(platform) {} virtual ~DeviceMemoryAllocator() {} @@ -49,14 +49,14 @@ class DeviceMemoryAllocator { int device_ordinal, perftools::gputools::DeviceMemoryBase* mem) = 0; // Return the platform that the allocator allocates memory on. - const perftools::gputools::Platform* platform() const { return platform_; } + perftools::gputools::Platform* platform() const { return platform_; } // Can we call Deallocate() as soon as a computation has been scheduled on // a stream, or do we have to wait for the computation to complete first? virtual bool AllowsAsynchronousDeallocation() const = 0; protected: - const perftools::gputools::Platform* platform_; + perftools::gputools::Platform* platform_; }; // Default memory allocator for a platform which uses diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc index 6efd0bcee58d19b355b6c2afa6d9497f75ef4b3c..2172ae0a29626660e8abd29a789e0baa3831519d 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.cc @@ -24,37 +24,55 @@ limitations under the License. namespace xla { -Status DfsHloVisitor::HandleElementwiseUnary(HloInstruction* hlo) { +template +Status DfsHloVisitorBase::HandleElementwiseUnary( + HloInstructionPtr hlo) { return Unimplemented("DfsHloVisitor::HandleElementwiseUnary: %s", HloOpcodeString(hlo->opcode()).c_str()); } -Status DfsHloVisitor::HandleElementwiseBinary(HloInstruction* hlo) { +template +Status DfsHloVisitorBase::HandleElementwiseBinary( + HloInstructionPtr hlo) { return Unimplemented("DfsHloVisitor::HandleElementwiseBinary: %s", HloOpcodeString(hlo->opcode()).c_str()); } -DfsHloVisitor::VisitState DfsHloVisitor::GetVisitState( +template +typename DfsHloVisitorBase::VisitState +DfsHloVisitorBase::GetVisitState( const HloInstruction& instruction) { return GetVisitState(instruction.unique_id()); } -void DfsHloVisitor::SetVisiting(const HloInstruction& instruction) { +template +void DfsHloVisitorBase::SetVisiting( + const HloInstruction& instruction) { VLOG(3) << "marking HLO " << &instruction << " as visiting: "; DCHECK(NotVisited(instruction)); visit_state_.SetState(instruction.unique_id(), VisitState::kVisiting); } -void DfsHloVisitor::SetVisited(const HloInstruction& instruction) { +template +void DfsHloVisitorBase::SetVisited( + const HloInstruction& instruction) { VLOG(3) << "marking HLO " << &instruction << " as visited: "; DCHECK(NotVisited(instruction) || IsVisiting(instruction)); visit_state_.SetState(instruction.unique_id(), VisitState::kVisited); } -Status DfsHloVisitor::Preprocess(HloInstruction* hlo) { return Status::OK(); } +template +Status DfsHloVisitorBase::Preprocess(HloInstructionPtr) { + return Status::OK(); +} -Status DfsHloVisitor::Postprocess(HloInstruction* visited) { +template +Status DfsHloVisitorBase::Postprocess(HloInstructionPtr) { return Status::OK(); } +// Explicit instantiations. +template class DfsHloVisitorBase; +template class DfsHloVisitorBase; + } // namespace xla diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index 2c16a1b9033f45742f80b91eb1695315bd13ed80..de3cd1544087686fa884fc22382aa4dff5256938 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -16,6 +16,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ #define TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ +#include #include #include "tensorflow/compiler/xla/literal_util.h" @@ -53,192 +54,176 @@ class HloInstruction; // // Note: this may change to an iterator in the future for flexibility purposes. // -// TODO(b/26548304): Stop passing in information about the visited -// instruction that is accessible from the instruction object itself. -class DfsHloVisitor { +// Users should not use this class directly, but use the type-aliases +// DfsHloVisitor/ConstDfsHloVisitor instead. +template +class DfsHloVisitorBase { + static_assert( + std::is_same::value || + std::is_same::value, + "Template argument expected to be HloInstruction* or const " + "HloInstruction*"); + public: - DfsHloVisitor() {} - virtual ~DfsHloVisitor() {} + DfsHloVisitorBase() {} + virtual ~DfsHloVisitorBase() {} // These routines are self-descriptive, see class comment for usage // information. - virtual Status HandleElementwiseUnary(HloInstruction* hlo); - virtual Status HandleElementwiseBinary(HloInstruction* hlo); - virtual Status HandleClamp(HloInstruction* clamp, HloInstruction* min, - HloInstruction* arg, HloInstruction* max) = 0; - virtual Status HandleSelect(HloInstruction* select, HloInstruction* pred, - HloInstruction* on_true, - HloInstruction* on_false) = 0; - virtual Status HandleMaximum(HloInstruction* maximum) { - return HandleElementwiseBinary(maximum); + virtual Status HandleElementwiseUnary(HloInstructionPtr hlo); + virtual Status HandleElementwiseBinary(HloInstructionPtr hlo); + + virtual Status HandleClamp(HloInstructionPtr hlo) = 0; + virtual Status HandleSelect(HloInstructionPtr hlo) = 0; + virtual Status HandleMaximum(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } + virtual Status HandleMinimum(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } + virtual Status HandleConcatenate(HloInstructionPtr hlo) = 0; + virtual Status HandleConvert(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleCopy(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleComplex(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } + virtual Status HandleMultiply(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } + virtual Status HandleDot(HloInstructionPtr hlo) = 0; + virtual Status HandlePower(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } + virtual Status HandleConvolution(HloInstructionPtr hlo) = 0; + virtual Status HandleCrossReplicaSum(HloInstructionPtr hlo) = 0; + virtual Status HandleCompare(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } + virtual Status HandleAdd(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } + virtual Status HandleDivide(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } + virtual Status HandleRemainder(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } + virtual Status HandleSubtract(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } + virtual Status HandleAbs(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleAtan2(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } + virtual Status HandleRound(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleSign(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleNegate(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleExp(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); } - virtual Status HandleMinimum(HloInstruction* minimum) { - return HandleElementwiseBinary(minimum); + virtual Status HandleFloor(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleCeil(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleLog(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleCos(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleSin(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleTanh(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleReal(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleImag(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleIsFinite(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleAnd(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } + virtual Status HandleNot(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + virtual Status HandleOr(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } + virtual Status HandleShiftLeft(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } + virtual Status HandleShiftRightArithmetic(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); + } + virtual Status HandleShiftRightLogical(HloInstructionPtr hlo) { + return HandleElementwiseBinary(hlo); } - virtual Status HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands) = 0; - virtual Status HandleConvert(HloInstruction* convert) { - return HandleElementwiseUnary(convert); - } - virtual Status HandleCopy(HloInstruction* copy) { - return HandleElementwiseUnary(copy); - } - virtual Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, - HloInstruction* rhs) { - return HandleElementwiseBinary(multiply); - } - virtual Status HandleDot(HloInstruction* dot, HloInstruction* lhs, - HloInstruction* rhs) = 0; - virtual Status HandlePower(HloInstruction* power, HloInstruction* lhs, - HloInstruction* rhs) { - return HandleElementwiseBinary(power); - } - virtual Status HandleConvolution(HloInstruction* convolution, - HloInstruction* lhs, HloInstruction* rhs, - const Window& window) = 0; - virtual Status HandleCrossReplicaSum(HloInstruction* crs) = 0; - virtual Status HandleCompare(HloInstruction* compare, HloOpcode opcode, - HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(compare); - } - virtual Status HandleAdd(HloInstruction* add, HloInstruction* lhs, - HloInstruction* rhs) { - return HandleElementwiseBinary(add); - } - virtual Status HandleDivide(HloInstruction* divide, HloInstruction* lhs, - HloInstruction* rhs) { - return HandleElementwiseBinary(divide); - } - virtual Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs, - HloInstruction* rhs) { - return HandleElementwiseBinary(remainder); - } - virtual Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs, - HloInstruction* rhs) { - return HandleElementwiseBinary(subtract); - } - virtual Status HandleAbs(HloInstruction* abs, HloInstruction* operand) { - return HandleElementwiseUnary(abs); - } - virtual Status HandleRound(HloInstruction* round) { - return HandleElementwiseUnary(round); - } - virtual Status HandleSign(HloInstruction* sign, HloInstruction* operand) { - return HandleElementwiseUnary(sign); - } - virtual Status HandleNegate(HloInstruction* negate, HloInstruction* operand) { - return HandleElementwiseUnary(negate); - } - virtual Status HandleExp(HloInstruction* exp, HloInstruction* operand) { - return HandleElementwiseUnary(exp); - } - virtual Status HandleFloor(HloInstruction* floor, HloInstruction* operand) { - return HandleElementwiseUnary(floor); - } - virtual Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) { - return HandleElementwiseUnary(ceil); - } - virtual Status HandleLog(HloInstruction* log, HloInstruction* operand) { - return HandleElementwiseUnary(log); - } - virtual Status HandleCos(HloInstruction* cos, HloInstruction* operand) { - return HandleElementwiseUnary(cos); - } - virtual Status HandleSin(HloInstruction* sin, HloInstruction* operand) { - return HandleElementwiseUnary(sin); - } - virtual Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) { - return HandleElementwiseUnary(tanh); - } - virtual Status HandleIsFinite(HloInstruction* is_finite, - HloInstruction* operand) { - return HandleElementwiseUnary(is_finite); - } - virtual Status HandleLogicalAnd(HloInstruction* logical_and, - HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(logical_and); - } - virtual Status HandleLogicalNot(HloInstruction* logical_not, - HloInstruction* operand) { - return HandleElementwiseUnary(logical_not); - } - virtual Status HandleLogicalOr(HloInstruction* logical_or, - HloInstruction* lhs, HloInstruction* rhs) { - return HandleElementwiseBinary(logical_or); - } - virtual Status HandleReducePrecision(HloInstruction* reduce_precision) { - return HandleElementwiseUnary(reduce_precision); - } - - virtual Status HandleInfeed(HloInstruction* infeed) = 0; - virtual Status HandleOutfeed(HloInstruction* outfeed) = 0; - virtual Status HandleRng(HloInstruction* random, - RandomDistribution distribution) = 0; - virtual Status HandleReverse(HloInstruction* reverse, - HloInstruction* operand) = 0; - virtual Status HandleSort(HloInstruction* sort, HloInstruction* operand) = 0; - virtual Status HandleConstant(HloInstruction* constant, - const Literal& literal) = 0; - virtual Status HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* operand) = 0; - virtual Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, - HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, - HloComputation* function) = 0; - virtual Status HandleBitcast(HloInstruction* bitcast) = 0; - virtual Status HandleBroadcast(HloInstruction* broadcast) = 0; - virtual Status HandleReshape(HloInstruction* reshape) = 0; - virtual Status HandleTranspose(HloInstruction* transpose) = 0; - virtual Status HandleParameter(HloInstruction* parameter) = 0; - virtual Status HandleFusion(HloInstruction* fusion) = 0; - virtual Status HandleCall(HloInstruction* call) = 0; - virtual Status HandleCustomCall( - HloInstruction* custom_call, - tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target) = 0; - virtual Status HandleSlice(HloInstruction* slice, - HloInstruction* operand) = 0; - virtual Status HandleDynamicSlice(HloInstruction* dynamic_slice, - HloInstruction* operand, - HloInstruction* start_indices) = 0; - virtual Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, - HloInstruction* operand, - HloInstruction* update, - HloInstruction* start_indices) = 0; - virtual Status HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) = 0; - virtual Status HandleMap( - HloInstruction* map, - tensorflow::gtl::ArraySlice operands, - HloComputation* function, - tensorflow::gtl::ArraySlice static_operands) = 0; - virtual Status HandleReduceWindow(HloInstruction* reduce_window, - HloInstruction* operand, - const Window& window, - HloComputation* function) = 0; - virtual Status HandleSelectAndScatter(HloInstruction* instruction) = 0; - virtual Status HandleWhile(HloInstruction* xla_while) = 0; - - virtual Status HandlePad(HloInstruction* pad) = 0; - - virtual Status HandleSend(HloInstruction* send) = 0; - - virtual Status HandleRecv(HloInstruction* recv) = 0; - - virtual Status HandleBatchNormTraining( - HloInstruction* batch_norm_training) = 0; - - virtual Status HandleBatchNormInference( - HloInstruction* batch_norm_inference) = 0; - - virtual Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) = 0; + + virtual Status HandleReducePrecision(HloInstructionPtr hlo) { + return HandleElementwiseUnary(hlo); + } + + virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; + virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; + virtual Status HandleRng(HloInstructionPtr hlo) = 0; + virtual Status HandleReverse(HloInstructionPtr hlo) = 0; + virtual Status HandleSort(HloInstructionPtr hlo) = 0; + virtual Status HandleConstant(HloInstructionPtr hlo) = 0; + virtual Status HandleGetTupleElement(HloInstructionPtr hlo) = 0; + virtual Status HandleReduce(HloInstructionPtr hlo) = 0; + virtual Status HandleBitcast(HloInstructionPtr hlo) = 0; + virtual Status HandleBroadcast(HloInstructionPtr hlo) = 0; + virtual Status HandleReshape(HloInstructionPtr hlo) = 0; + virtual Status HandleTranspose(HloInstructionPtr hlo) = 0; + virtual Status HandleParameter(HloInstructionPtr hlo) = 0; + virtual Status HandleFusion(HloInstructionPtr hlo) = 0; + virtual Status HandleCall(HloInstructionPtr hlo) = 0; + virtual Status HandleCustomCall(HloInstructionPtr hlo) = 0; + virtual Status HandleSlice(HloInstructionPtr hlo) = 0; + virtual Status HandleDynamicSlice(HloInstructionPtr hlo) = 0; + virtual Status HandleDynamicUpdateSlice(HloInstructionPtr hlo) = 0; + virtual Status HandleTuple(HloInstructionPtr hlo) = 0; + virtual Status HandleMap(HloInstructionPtr hlo) = 0; + virtual Status HandleReduceWindow(HloInstructionPtr hlo) = 0; + virtual Status HandleSelectAndScatter(HloInstructionPtr hlo) = 0; + virtual Status HandleWhile(HloInstructionPtr hlo) = 0; + + virtual Status HandlePad(HloInstructionPtr hlo) = 0; + + virtual Status HandleSend(HloInstructionPtr hlo) = 0; + + virtual Status HandleRecv(HloInstructionPtr hlo) = 0; + + virtual Status HandleBatchNormTraining(HloInstructionPtr hlo) = 0; + + virtual Status HandleBatchNormInference(HloInstructionPtr hlo) = 0; + + virtual Status HandleBatchNormGrad(HloInstructionPtr hlo) = 0; // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". - virtual Status FinishVisit(HloInstruction* root) = 0; + virtual Status FinishVisit(HloInstructionPtr root) = 0; // 3 possible visitation states of HLO instructions. Each instruction's // state only flows one way: kNotVisited -> kVisiting -> kVisited. @@ -296,7 +281,7 @@ class DfsHloVisitor { // // Overriding methods should call DfsHloVisitor::Preprocess before doing their // own preprocessing. - virtual Status Preprocess(HloInstruction* hlo); + virtual Status Preprocess(HloInstructionPtr hlo); // This method should be overridden by subclasses that wish to run some // operation on an op after its Handle* visitor method is called. See @@ -304,7 +289,7 @@ class DfsHloVisitor { // // Overriding methods should call DfsHloVisitor::Postprocess after doing their // own postprocessing. - virtual Status Postprocess(HloInstruction* visited); + virtual Status Postprocess(HloInstructionPtr hlo); private: class DFSVisitStates { @@ -345,9 +330,14 @@ class DfsHloVisitor { DFSVisitStates visit_state_; - TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitor); + TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorBase); }; +// Users should use one of these two type aliases, which are the only two valid +// instantiations of DfsHloVisitorBase. +using DfsHloVisitor = DfsHloVisitorBase; +using ConstDfsHloVisitor = DfsHloVisitorBase; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_H_ diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index a5fe120598416235dff2af9d8a5c0ae64ac9edcc..7ce88be89dfe0746d9d05ca3d5c788f72ca74cd8 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -33,202 +33,189 @@ class HloComputation; class HloInstruction; // DfsHloVisitor with default action based on the HloInstruction being visited. -class DfsHloVisitorWithDefault : public DfsHloVisitor { +// Users should not use this class directly, but use the type aliases +// DfsHloVisitorWithDefault/ConstDfsHloVisitorWithDefault instead. +template +class DfsHloVisitorWithDefaultBase + : public DfsHloVisitorBase { public: - DfsHloVisitorWithDefault() {} - ~DfsHloVisitorWithDefault() override {} + DfsHloVisitorWithDefaultBase() {} + ~DfsHloVisitorWithDefaultBase() override {} // Default action performed on HloInstruction. - virtual Status DefaultAction(HloInstruction* hlo_instruction) = 0; + virtual Status DefaultAction(HloInstructionPtr hlo_instruction) = 0; - Status HandleElementwiseUnary(HloInstruction* hlo) override { + Status HandleElementwiseUnary(HloInstructionPtr hlo) override { return DefaultAction(hlo); } - Status HandleElementwiseBinary(HloInstruction* hlo) override { + Status HandleElementwiseBinary(HloInstructionPtr hlo) override { return DefaultAction(hlo); } - Status HandleBatchNormTraining(HloInstruction* hlo) override { + Status HandleBatchNormTraining(HloInstructionPtr hlo) override { return DefaultAction(hlo); } - Status HandleBatchNormInference(HloInstruction* hlo) override { + Status HandleBatchNormInference(HloInstructionPtr hlo) override { return DefaultAction(hlo); } - Status HandleBatchNormGrad(HloInstruction* hlo) override { + Status HandleBatchNormGrad(HloInstructionPtr hlo) override { return DefaultAction(hlo); } - Status HandleClamp(HloInstruction* clamp, HloInstruction* /*min*/, - HloInstruction* /*arg*/, - HloInstruction* /*max*/) override { + Status HandleClamp(HloInstructionPtr clamp) override { return DefaultAction(clamp); } - Status HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice /*operands*/) override { + Status HandleConcatenate(HloInstructionPtr concatenate) override { return DefaultAction(concatenate); } - Status HandleConvert(HloInstruction* convert) override { + Status HandleConvert(HloInstructionPtr convert) override { return DefaultAction(convert); } - Status HandleCopy(HloInstruction* copy) override { + Status HandleCopy(HloInstructionPtr copy) override { return DefaultAction(copy); } - Status HandleSelect(HloInstruction* select, HloInstruction* /*pred*/, - HloInstruction* /*on_true*/, - HloInstruction* /*on_false*/) override { + Status HandleSelect(HloInstructionPtr select) override { return DefaultAction(select); } - Status HandleDot(HloInstruction* dot, HloInstruction* /*lhs*/, - HloInstruction* /*rhs*/) override { + Status HandleDot(HloInstructionPtr dot) override { return DefaultAction(dot); } - Status HandleConvolution(HloInstruction* convolution, HloInstruction* /*lhs*/, - HloInstruction* /*rhs*/, - const Window& /*window*/) override { + Status HandleConvolution(HloInstructionPtr convolution) override { return DefaultAction(convolution); } - Status HandleCrossReplicaSum(HloInstruction* crs) override { + Status HandleCrossReplicaSum(HloInstructionPtr crs) override { return DefaultAction(crs); } - Status HandleCompare(HloInstruction* compare, HloOpcode /*opcode*/, - HloInstruction* /*lhs*/, - HloInstruction* /*rhs*/) override { + Status HandleCompare(HloInstructionPtr compare) override { return DefaultAction(compare); } - Status HandleRng(HloInstruction* random, - RandomDistribution /*distribution*/) override { + Status HandleRng(HloInstructionPtr random) override { return DefaultAction(random); } - Status HandleInfeed(HloInstruction* infeed) override { + Status HandleInfeed(HloInstructionPtr infeed) override { return DefaultAction(infeed); } - Status HandleOutfeed(HloInstruction* outfeed) override { + Status HandleOutfeed(HloInstructionPtr outfeed) override { return DefaultAction(outfeed); } - Status HandleReverse(HloInstruction* reverse, - HloInstruction* /*operand*/) override { + Status HandleReverse(HloInstructionPtr reverse) override { return DefaultAction(reverse); } - Status HandleSort(HloInstruction* sort, - HloInstruction* /*operand*/) override { + Status HandleSort(HloInstructionPtr sort) override { return DefaultAction(sort); } - Status HandleConstant(HloInstruction* constant, - const Literal& /*literal*/) override { + Status HandleConstant(HloInstructionPtr constant) override { return DefaultAction(constant); } - Status HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* /*operand*/) override { + Status HandleGetTupleElement(HloInstructionPtr get_tuple_element) override { return DefaultAction(get_tuple_element); } - Status HandleParameter(HloInstruction* parameter) override { + Status HandleParameter(HloInstructionPtr parameter) override { return DefaultAction(parameter); } - Status HandleFusion(HloInstruction* fusion) override { + Status HandleFusion(HloInstructionPtr fusion) override { return DefaultAction(fusion); } - Status HandleCall(HloInstruction* call) override { + Status HandleCall(HloInstructionPtr call) override { return DefaultAction(call); } - Status HandleCustomCall( - HloInstruction* custom_call, - tensorflow::gtl::ArraySlice /*operands*/, - tensorflow::StringPiece /*custom_call_target*/) override { + Status HandleCustomCall(HloInstructionPtr custom_call) override { return DefaultAction(custom_call); } - Status HandleSlice(HloInstruction* slice, - HloInstruction* /*operand*/) override { + Status HandleSlice(HloInstructionPtr slice) override { return DefaultAction(slice); } - Status HandleDynamicSlice(HloInstruction* dynamic_slice, - HloInstruction* /*operand*/, - HloInstruction* /*start_indices*/) override { + Status HandleDynamicSlice(HloInstructionPtr dynamic_slice) override { return DefaultAction(dynamic_slice); } - Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, - HloInstruction* /*operand*/, - HloInstruction* /*update*/, - HloInstruction* /*start_indices*/) override { + Status HandleDynamicUpdateSlice( + HloInstructionPtr dynamic_update_slice) override { return DefaultAction(dynamic_update_slice); } - Status HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice /*operands*/) override { + Status HandleTuple(HloInstructionPtr tuple) override { return DefaultAction(tuple); } - Status HandleMap( - HloInstruction* map, - tensorflow::gtl::ArraySlice /*operands*/, - HloComputation* /*function*/, - tensorflow::gtl::ArraySlice /*static_operands*/) - override { + Status HandleMap(HloInstructionPtr map) override { return DefaultAction(map); } - Status HandleReduce(HloInstruction* reduce, HloInstruction* /*arg*/, - HloInstruction* /*init_value*/, - tensorflow::gtl::ArraySlice /*dimensions*/, - HloComputation* /*function*/) override { + Status HandleReduce(HloInstructionPtr reduce) override { return DefaultAction(reduce); } - Status HandleReduceWindow(HloInstruction* reduce_window, - HloInstruction* /*operand*/, - const Window& /*window*/, - HloComputation* /*function*/) override { + Status HandleReduceWindow(HloInstructionPtr reduce_window) override { return DefaultAction(reduce_window); } - Status HandleSelectAndScatter(HloInstruction* select_and_scatter) override { + Status HandleSelectAndScatter(HloInstructionPtr select_and_scatter) override { return DefaultAction(select_and_scatter); } - Status HandleBitcast(HloInstruction* bitcast) override { + Status HandleBitcast(HloInstructionPtr bitcast) override { return DefaultAction(bitcast); } - Status HandleBroadcast(HloInstruction* broadcast) override { + Status HandleBroadcast(HloInstructionPtr broadcast) override { return DefaultAction(broadcast); } - Status HandlePad(HloInstruction* pad) override { return DefaultAction(pad); } - Status HandleReshape(HloInstruction* reshape) override { + Status HandlePad(HloInstructionPtr pad) override { + return DefaultAction(pad); + } + Status HandleReshape(HloInstructionPtr reshape) override { return DefaultAction(reshape); } - Status HandleTranspose(HloInstruction* transpose) override { + Status HandleTranspose(HloInstructionPtr transpose) override { return DefaultAction(transpose); } - Status HandleWhile(HloInstruction* xla_while) override { + Status HandleWhile(HloInstructionPtr xla_while) override { return DefaultAction(xla_while); } - Status HandleSend(HloInstruction* send) override { + Status HandleSend(HloInstructionPtr send) override { return DefaultAction(send); } - Status HandleRecv(HloInstruction* recv) override { + Status HandleRecv(HloInstructionPtr recv) override { return DefaultAction(recv); } // Invoked to inform the visitor that the traversal has completed, and that // the root was "root". - Status FinishVisit(HloInstruction* /*root*/) override { return Status::OK(); } + Status FinishVisit(HloInstructionPtr /*root*/) override { + return Status::OK(); + } private: - TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorWithDefault); + TF_DISALLOW_COPY_AND_ASSIGN(DfsHloVisitorWithDefaultBase); }; -// Helper class for Accept(VisitorFunction) which visits instructions in DFS -// order calling the given function at each instruction. -class FunctionVisitor : public DfsHloVisitorWithDefault { +// Users should use these type aliases which are only two valid instantiations. +using DfsHloVisitorWithDefault = DfsHloVisitorWithDefaultBase; +using ConstDfsHloVisitorWithDefault = + DfsHloVisitorWithDefaultBase; + +// (Const)FunctionVisitor lets you transform an +// std::function into a (Const)DfsHloVisitor. +// +// This is useful if you have code that needs to handle visitors in the form of +// both std::function and DfsHloVisitor. You can wrap the function in a +// FunctionVisitor and then treat it like any other DfsHloVisitor. +template +class FunctionVisitorBase + : public DfsHloVisitorWithDefaultBase { public: - using VisitorFunction = std::function; - explicit FunctionVisitor(VisitorFunction visitor_func) + explicit FunctionVisitorBase( + std::function visitor_func) : visitor_func_(std::move(visitor_func)) {} - Status DefaultAction(HloInstruction* hlo_instruction) override { + Status DefaultAction(HloInstructionPtr hlo_instruction) override { return visitor_func_(hlo_instruction); } private: - VisitorFunction visitor_func_; + TF_DISALLOW_COPY_AND_ASSIGN(FunctionVisitorBase); + + std::function visitor_func_; }; +using FunctionVisitor = FunctionVisitorBase; +using ConstFunctionVisitor = FunctionVisitorBase; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_DFS_HLO_VISITOR_WITH_DEFAULT_H_ diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 1b1aef3cdbe087aad38194f36aff700986f4747b..fd4c332cba94513ec5b4cd88a842189e716f35d5 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -44,17 +44,22 @@ limitations under the License. namespace xla { +using llvm_ir::AsStringRef; using llvm_ir::IrArray; +using llvm_ir::IrName; using llvm_ir::SetToFirstInsertPoint; +using tensorflow::strings::StrCat; StatusOr ElementalIrEmitter::EmitUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const { if (op->opcode() == HloOpcode::kCopy) { return operand_value; + } else if (operand_value->getType()->isIntegerTy()) { + return EmitIntegerUnaryOp(op, operand_value); + } else if (ShapeUtil::ElementIsComplex(op->operand(0)->shape())) { + return EmitComplexUnaryOp(op, operand_value); } else { - return operand_value->getType()->isIntegerTy() - ? EmitIntegerUnaryOp(op, operand_value) - : EmitFloatUnaryOp(op, operand_value); + return EmitFloatUnaryOp(op, operand_value); } } @@ -70,20 +75,35 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } if (primitive_util::IsIntegralType(to_type)) { return ir_builder_->CreateIntCast( - operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_), + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_), primitive_util::IsSignedIntegralType(to_type)); } if (primitive_util::IsFloatingPointType(to_type)) { if (primitive_util::IsSignedIntegralType(from_type)) { return ir_builder_->CreateSIToFP( - operand_value, - llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsUnsignedIntegralType(from_type) || from_type == PRED) { return ir_builder_->CreateUIToFP( - operand_value, - llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); + } + } + if (primitive_util::IsComplexType(to_type)) { + auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType( + primitive_util::ComplexComponentType(to_type), module_); + if (primitive_util::IsSignedIntegralType(from_type)) { + return ComposeComplex( + op, + ir_builder_->CreateSIToFP(operand_value, to_ir_component_type), + nullptr); + } + if (primitive_util::IsUnsignedIntegralType(from_type) || + from_type == PRED) { + return ComposeComplex( + op, + ir_builder_->CreateUIToFP(operand_value, to_ir_component_type), + nullptr); } } return Unimplemented("conversion from primitive type %s to %s", @@ -94,8 +114,8 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( bool is_signed = primitive_util::IsSignedIntegralType(op->shape().element_type()); if (is_signed) { - auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), - ir_builder_); + auto type = + llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); auto zero = llvm::ConstantInt::get(type, 0); auto cmp = ir_builder_->CreateICmpSGE(operand_value, zero); return ir_builder_->CreateSelect(cmp, operand_value, @@ -107,8 +127,8 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( case HloOpcode::kSign: { bool is_signed = primitive_util::IsSignedIntegralType(op->shape().element_type()); - auto type = llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), - ir_builder_); + auto type = + llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); auto zero = llvm::ConstantInt::get(type, 0); auto cmp = ir_builder_->CreateICmpEQ(operand_value, zero); if (is_signed) { @@ -123,14 +143,21 @@ StatusOr ElementalIrEmitter::EmitIntegerUnaryOp( } case HloOpcode::kNegate: return ir_builder_->CreateNeg(operand_value); - case HloOpcode::kLogicalNot: - // It is not sufficient to just call CreateNot() here because a PRED is - // represented as an i8 and the truth value is stored only in the bottom - // bit. - return ir_builder_->CreateZExt( - ir_builder_->CreateNot(ir_builder_->CreateTrunc( - operand_value, ir_builder_->getInt1Ty())), - llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_)); + case HloOpcode::kNot: { + auto type = op->shape().element_type(); + if (type == PRED) { + // It is not sufficient to just call CreateNot() here because a PRED + // is represented as an i8 and the truth value is stored only in the + // bottom bit. + return ir_builder_->CreateZExt( + ir_builder_->CreateNot(ir_builder_->CreateTrunc( + operand_value, ir_builder_->getInt1Ty())), + llvm_ir::PrimitiveTypeToIrType(PRED, module_)); + } else if (primitive_util::IsIntegralType(type)) { + return ir_builder_->CreateNot(operand_value); + } + return Unimplemented("unary op Not is not defined for type '%d'", type); + } default: return Unimplemented("unary integer op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -147,20 +174,30 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( if (from_type == to_type) { return operand_value; } + if (primitive_util::IsComplexType(to_type)) { + PrimitiveType to_component_type = + primitive_util::ComplexComponentType(to_type); + if (from_type == to_component_type) { + return ComposeComplex(op, operand_value, nullptr); + } + return ComposeComplex( + op, + ir_builder_->CreateFPCast( + operand_value, + llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)), + nullptr); + } if (primitive_util::IsFloatingPointType(to_type)) { return ir_builder_->CreateFPCast( - operand_value, - llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsSignedIntegralType(to_type)) { return ir_builder_->CreateFPToSI( - operand_value, - llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } if (primitive_util::IsUnsignedIntegralType(to_type)) { return ir_builder_->CreateFPToUI( - operand_value, - llvm_ir::PrimitiveTypeToIrType(to_type, ir_builder_)); + operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_)); } return Unimplemented("unhandled conversion operation: %s => %s", PrimitiveType_Name(from_type).c_str(), @@ -220,7 +257,7 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity); auto result_i1 = ir_builder_->CreateAnd(equal_self, not_infinite); return ir_builder_->CreateZExt( - result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder_)); + result_i1, llvm_ir::PrimitiveTypeToIrType(PRED, module_)); } case HloOpcode::kNegate: return ir_builder_->CreateFNeg(operand_value); @@ -230,20 +267,164 @@ StatusOr ElementalIrEmitter::EmitFloatUnaryOp( } } +StatusOr ElementalIrEmitter::EmitComplexUnaryOp( + const HloInstruction* op, llvm::Value* operand_value) const { + auto real = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {0}); + }; + auto imag = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {1}); + }; + switch (op->opcode()) { + // TODO(b/65209142): Angle/Log require atan2. + // case HloOpcode::kAngle: + // case HloOpcode::kLog: // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) + case HloOpcode::kConvert: { + PrimitiveType from_type = op->operand(0)->shape().element_type(); + TF_RET_CHECK(primitive_util::IsComplexType(from_type)); + PrimitiveType to_type = op->shape().element_type(); + TF_RET_CHECK(primitive_util::IsComplexType(to_type)); + if (from_type == to_type) { + return operand_value; + } + PrimitiveType to_component_type = + primitive_util::ComplexComponentType(to_type); + auto to_ir_component_type = + llvm_ir::PrimitiveTypeToIrType(to_component_type, module_); + return ComposeComplex( + op, + ir_builder_->CreateFPCast(real(operand_value), to_ir_component_type), + ir_builder_->CreateFPCast(imag(operand_value), to_ir_component_type)); + } + case HloOpcode::kExp: { + // e^(a+bi) = e^a*(cos(b)+sin(b)i) + auto exp_a = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::exp, {real(operand_value)}, + {real(operand_value)->getType()}, ir_builder_); + auto cos_b = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::cos, {imag(operand_value)}, + {imag(operand_value)->getType()}, ir_builder_); + auto sin_b = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::sin, {imag(operand_value)}, + {imag(operand_value)->getType()}, ir_builder_); + return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), + ir_builder_->CreateFMul(exp_a, sin_b)); + } + case HloOpcode::kCos: { + // cos(z) = .5(e^(iz) + e^(-iz)) + // cos(a+bi) = .5(e^(-b+ai) + e^(b-ai)) + // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have + // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(-a)+sin(-a)i)) + // cos(-x) = cos(x) and sin(-x) = -sin(x), so + // cos(a+bi) = .5(e^-b*(cos(a)+sin(a)i) + e^b*(cos(a)-sin(a)i)) + // = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b)) + auto a = real(operand_value); + auto b = imag(operand_value); + auto type = a->getType(); + auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, + {type}, ir_builder_); + auto half_exp_b = + ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_neg_b = + ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a}, + {type}, ir_builder_); + auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, + {type}, ir_builder_); + return ComposeComplex( + op, + ir_builder_->CreateFMul( + cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), + ir_builder_->CreateFMul( + sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b))); + } + case HloOpcode::kSin: { + // sin(z) = .5i(e^(-iz) - e^(iz)) + // sin(a+bi) = .5i(e^(-i(a+bi)) - e^(i(a+bi))) + // = .5i(e^(b-ai) - e^(-b+ai)) + // now, e^(x+yi) = e^x*(cos(y)+sin(y)i), so we have + // sin(a+bi) = 0.5i(e^b*(cos(-a)+sin(-a)i) - e^-b*(cos(a)+sin(a)i)) + // = 0.5(e^b*(cos(-a)i-sin(-a)) - e^-b*(cos(a)i-sin(a))) + // cos(-x) = cos(x) and sin(-x) = -sin(x), so + // = 0.5(e^b*(cos(a)i+sin(a)) - e^-b*(cos(a)i-sin(a))) + // = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b) + auto a = real(operand_value); + auto b = imag(operand_value); + auto type = a->getType(); + auto exp_b = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {b}, + {type}, ir_builder_); + auto half_exp_b = + ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b); + auto half_exp_neg_b = + ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b); + auto cos_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {a}, + {type}, ir_builder_); + auto sin_a = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {a}, + {type}, ir_builder_); + return ComposeComplex( + op, + ir_builder_->CreateFMul( + sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), + ir_builder_->CreateFMul( + cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); + } + case HloOpcode::kAbs: { + auto sum_sq = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(real(operand_value), real(operand_value)), + ir_builder_->CreateFMul(imag(operand_value), imag(operand_value))); + return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq}, + {sum_sq->getType()}, ir_builder_); + } + case HloOpcode::kSign: { // Sign(c) = c / |c| + auto sum_sq = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(real(operand_value), real(operand_value)), + ir_builder_->CreateFMul(imag(operand_value), imag(operand_value))); + auto cplx_abs = llvm_ir::EmitCallToIntrinsic( + llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_); + auto type = cplx_abs->getType(); + auto zero = llvm::ConstantFP::get(type, 0.0); + auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero); + return ir_builder_->CreateSelect( + oeq, ComposeComplex(op, zero, zero), + ComposeComplex( + op, ir_builder_->CreateFDiv(real(operand_value), cplx_abs), + ir_builder_->CreateFDiv(imag(operand_value), cplx_abs))); + } + case HloOpcode::kNegate: + return ComposeComplex(op, ir_builder_->CreateFNeg(real(operand_value)), + ir_builder_->CreateFNeg(imag(operand_value))); + case HloOpcode::kReal: + return real(operand_value); + case HloOpcode::kImag: + return imag(operand_value); + default: + return Unimplemented("unary complex op '%s'", + HloOpcodeString(op->opcode()).c_str()); + } +} + StatusOr ElementalIrEmitter::EmitBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { - return lhs_value->getType()->isIntegerTy() - ? EmitIntegerBinaryOp(op, lhs_value, rhs_value, - primitive_util::IsSignedIntegralType( - op->operand(0)->shape().element_type())) - : EmitFloatBinaryOp(op, lhs_value, rhs_value); + PrimitiveType operand_type = op->operand(0)->shape().element_type(); + if (lhs_value->getType()->isIntegerTy()) { + return EmitIntegerBinaryOp( + op, lhs_value, rhs_value, + primitive_util::IsSignedIntegralType(operand_type)); + } else if (primitive_util::IsComplexType(operand_type)) { + return EmitComplexBinaryOp(op, lhs_value, rhs_value); + } else { + return EmitFloatBinaryOp(op, lhs_value, rhs_value); + } } StatusOr ElementalIrEmitter::EmitFloatBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const { switch (op->opcode()) { + // case HloOpcode::kAtan2: // TODO(b/65209142): CPU atan2 support + case HloOpcode::kComplex: + return ComposeComplex(op, lhs_value, rhs_value); case HloOpcode::kAdd: return ir_builder_->CreateFAdd(lhs_value, rhs_value); case HloOpcode::kSubtract: @@ -295,6 +476,88 @@ StatusOr ElementalIrEmitter::EmitFloatBinaryOp( } } +StatusOr ElementalIrEmitter::EmitComplexBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) const { + auto real = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {0}); + }; + auto imag = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {1}); + }; + switch (op->opcode()) { + case HloOpcode::kAdd: + return ComposeComplex( + op, ir_builder_->CreateFAdd(real(lhs_value), real(rhs_value)), + ir_builder_->CreateFAdd(imag(lhs_value), imag(rhs_value))); + case HloOpcode::kSubtract: + return ComposeComplex( + op, ir_builder_->CreateFSub(real(lhs_value), real(rhs_value)), + ir_builder_->CreateFSub(imag(lhs_value), imag(rhs_value))); + case HloOpcode::kMultiply: + return ComposeComplex( + op, + ir_builder_->CreateFSub( + ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), + ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value))), + ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)), + ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)))); + case HloOpcode::kDivide: { + // (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di)) + // = ((ac + bd) + (bc - ad)i) / (c^2 + d^2) + auto rhs_sum_sq = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(real(rhs_value), real(rhs_value)), + ir_builder_->CreateFMul(imag(rhs_value), imag(rhs_value))); + auto type = rhs_sum_sq->getType(); + auto zero = llvm::ConstantFP::get(type, 0.0); + auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero); + return ir_builder_->CreateSelect( + oeq, ComposeComplex(op, llvm::ConstantFP::getInfinity(type), zero), + ComposeComplex( + op, + ir_builder_->CreateFDiv( + ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), + ir_builder_->CreateFMul(imag(lhs_value), + imag(rhs_value))), + rhs_sum_sq), + ir_builder_->CreateFDiv( + ir_builder_->CreateFSub( + ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value)), + ir_builder_->CreateFMul(real(lhs_value), + imag(rhs_value))), + rhs_sum_sq))); + } + // LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered + // comparisons always return false when one of the operands is NaN, whereas + // unordered comparisons return true. + // + // We use ordered comparisons for everything except kNe, where we use an + // unordered comparison. This makes x != y equivalent to !(x == y), and + // matches C++'s semantics. + case HloOpcode::kEq: + return ir_builder_->CreateAnd( + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, real(lhs_value), + real(rhs_value), ir_builder_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, imag(lhs_value), + imag(rhs_value), ir_builder_)); + case HloOpcode::kNe: + return ir_builder_->CreateOr( + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, real(lhs_value), + real(rhs_value), ir_builder_), + llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, imag(lhs_value), + imag(rhs_value), ir_builder_)); + + // TODO(b/65209142): requires arg(z) -> requires atan|atan2 intrinsic + // case HloOpcode::kPower: + // // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(c/2+di/2) + default: + return Unimplemented("binary complex op '%s'", + HloOpcodeString(op->opcode()).c_str()); + } +} + llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value) const { return llvm_ir::EmitFloatMax(lhs_value, rhs_value, ir_builder_); @@ -386,7 +649,7 @@ StatusOr ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type, StatusOr ElementalIrEmitter::EmitErfcInv( PrimitiveType prim_type, llvm::Value* value) const { // Compute erfcinv(value) by calculating erfinv(1.0 - value). - auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, ir_builder_); + auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_); auto one = llvm::ConstantFP::get(type, 1.0); return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value)); } @@ -554,10 +817,16 @@ StatusOr ElementalIrEmitter::EmitIntegerBinaryOp( is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE, lhs_value, rhs_value), lhs_value, rhs_value); - case HloOpcode::kLogicalAnd: + case HloOpcode::kAnd: return ir_builder_->CreateAnd(lhs_value, rhs_value); - case HloOpcode::kLogicalOr: + case HloOpcode::kOr: return ir_builder_->CreateOr(lhs_value, rhs_value); + case HloOpcode::kShiftLeft: + return ir_builder_->CreateShl(lhs_value, rhs_value); + case HloOpcode::kShiftRightArithmetic: + return ir_builder_->CreateAShr(lhs_value, rhs_value); + case HloOpcode::kShiftRightLogical: + return ir_builder_->CreateLShr(lhs_value, rhs_value); default: return Unimplemented("binary integer op '%s'", HloOpcodeString(op->opcode()).c_str()); @@ -603,7 +872,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( const { PrimitiveType param_prim_type = hlo->operand(0)->shape().element_type(); llvm::Type* param_ir_type = - llvm_ir::PrimitiveTypeToIrType(param_prim_type, ir_builder_); + llvm_ir::PrimitiveTypeToIrType(param_prim_type, module_); // Same values as PCG library // https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h @@ -721,9 +990,9 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( if (ir_builder_->GetInsertPoint() == in_block->end()) { body_block = llvm_ir::CreateBasicBlock( - nullptr, llvm_ir::IrName(hlo, "rng_body"), ir_builder_); + nullptr, IrName(hlo, "rng_body"), ir_builder_); out_block = llvm_ir::CreateBasicBlock( - nullptr, llvm_ir::IrName(hlo, "rng_out"), ir_builder_); + nullptr, IrName(hlo, "rng_out"), ir_builder_); llvm::BranchInst::Create(body_block, in_block); } else { body_block = in_block->splitBasicBlock( @@ -767,7 +1036,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( return ir_builder_->CreateZExt( ir_builder_->CreateFCmpOLT(get_next_uniform_float(), p), llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - ir_builder_)); + module_)); } default: return InvalidArgument( @@ -790,13 +1059,15 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: + case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: case HloOpcode::kNegate: + case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kTanh: - case HloOpcode::kLogicalNot: + case HloOpcode::kNot: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { TF_ASSIGN_OR_RETURN(llvm::Value * operand_value, @@ -805,6 +1076,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( return EmitUnaryOp(hlo, operand_value); }; case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kComplex: case HloOpcode::kDivide: case HloOpcode::kEq: case HloOpcode::kGe: @@ -818,8 +1091,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSubtract: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: return [this, hlo, &operand_to_generator]( const IrArray::Index& index) -> StatusOr { const HloInstruction* lhs = hlo->operand(0); @@ -876,28 +1152,40 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( const int64 concat_dim = hlo->dimensions(0); auto source_index = target_index; - llvm::PHINode* output = ir_builder_->CreatePHI( - llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - ir_builder_), - hlo->operands().size()); llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock(); + + // A terminator should be present iff we're emitting code + // into the middle (as opposed to the end) of a basic block. + CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(), + init_block->getTerminator() == nullptr); + + llvm::BasicBlock* exit_block; + if (ir_builder_->GetInsertPoint() == init_block->end()) { + exit_block = llvm_ir::CreateBasicBlock( + /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_); + } else { + exit_block = init_block->splitBasicBlock( + ir_builder_->GetInsertPoint(), AsStringRef(IrName(hlo, "merge"))); + init_block->getTerminator()->eraseFromParent(); + } + + llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_); + llvm::PHINode* output = + ir_builder_->CreatePHI(llvm_ir::PrimitiveTypeToIrType( + hlo->shape().element_type(), module_), + hlo->operands().size()); auto prior_insert_point = ir_builder_->GetInsertPoint(); - llvm::BasicBlock* exit_block = - init_block->splitBasicBlock(output, "concat_merge"); ir_builder_->SetInsertPoint(init_block); - init_block->getTerminator()->eraseFromParent(); for (int64 operand_idx = 0; operand_idx < hlo->operand_count(); ++operand_idx) { const HloInstruction* operand = hlo->operand(operand_idx); auto true_block = llvm_ir::CreateBasicBlock( - exit_block, tensorflow::strings::StrCat( - "concat_index_from_operand", operand_idx), + exit_block, StrCat("concat_index_from_operand", operand_idx), ir_builder_); auto false_block = llvm_ir::CreateBasicBlock( - exit_block, tensorflow::strings::StrCat( - "concat_index_not_from_operand", operand_idx), + exit_block, StrCat("concat_index_not_from_operand", operand_idx), ir_builder_); auto concat_dim_size = llvm::ConstantInt::get(source_index[concat_dim]->getType(), @@ -972,6 +1260,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN( llvm::Value * start_index_value, operand_to_generator.at(hlo->operand(1))(dim_index)); + start_index_value->setName( + AsStringRef(IrName(hlo, StrCat("start_idx", i)))); slice_start_index[i] = start_index_value; } @@ -1004,6 +1294,8 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( llvm_ir::IrArray::Index dim_index(1, ir_builder_->getInt64(i)); TF_ASSIGN_OR_RETURN(llvm::Value * start_index_value, operand_to_generator.at(start_hlo)(dim_index)); + start_index_value->setName( + AsStringRef(IrName(hlo, StrCat("start_idx", i)))); slice_start_index[i] = ir_builder_->CreateZExtOrBitCast( start_index_value, index[i]->getType()); // Emit IR to compute: slice_limit_index = start_index + update_dim @@ -1040,7 +1332,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( // else -> return data from 'index'. llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - ir_builder_), + module_), "ret_value_addr", ir_builder_); llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse( slice_intersection, "slice_intersection", ir_builder_); @@ -1129,7 +1421,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( // } llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), - ir_builder_), + module_), "pad_result_addr", ir_builder_); llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); @@ -1163,7 +1455,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( std::unique_ptr inner_loop = llvm_ir::ForLoop::EmitForLoop( - llvm_ir::IrName(hlo, "inner"), ir_builder_->getInt64(0), + IrName(hlo, "inner"), ir_builder_->getInt64(0), ir_builder_->getInt64(contracted_dim_size), ir_builder_->getInt64(1), ir_builder_); @@ -1171,7 +1463,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( ir_builder_); PrimitiveType primitive_type = hlo->shape().element_type(); llvm::Type* primitive_type_llvm = - llvm_ir::PrimitiveTypeToIrType(primitive_type, ir_builder_); + llvm_ir::PrimitiveTypeToIrType(primitive_type, module_); llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry( primitive_type_llvm, "dot_acc", ir_builder_); ir_builder_->CreateStore( @@ -1204,7 +1496,28 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index)); TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index)); llvm::Value* next_accumulator; - if (primitive_util::IsFloatingPointType(primitive_type)) { + if (primitive_util::IsComplexType(primitive_type)) { + auto real = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {0}); + }; + auto imag = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {1}); + }; + llvm::Value* product_real = ir_builder_->CreateFSub( + ir_builder_->CreateFMul(real(lhs_value), real(rhs_value)), + ir_builder_->CreateFMul(imag(lhs_value), imag(rhs_value))); + llvm::Value* product_imag = ir_builder_->CreateFAdd( + ir_builder_->CreateFMul(real(lhs_value), imag(rhs_value)), + ir_builder_->CreateFMul(imag(lhs_value), real(rhs_value))); + next_accumulator = ir_builder_->CreateInsertValue( + current_accumulator, + ir_builder_->CreateFAdd(real(current_accumulator), product_real), + {0}); + next_accumulator = ir_builder_->CreateInsertValue( + next_accumulator, + ir_builder_->CreateFAdd(imag(current_accumulator), product_imag), + {1}); + } else if (primitive_util::IsFloatingPointType(primitive_type)) { next_accumulator = ir_builder_->CreateFAdd( current_accumulator, ir_builder_->CreateFMul(lhs_value, rhs_value)); @@ -1226,4 +1539,17 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator( } } +llvm::Value* ElementalIrEmitter::ComposeComplex(const HloInstruction* op, + llvm::Value* real, + llvm::Value* imag) const { + auto cplx_type = + llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_); + auto complex = ir_builder_->CreateInsertValue( + llvm::ConstantAggregateZero::get(cplx_type), real, {0}); + if (imag != nullptr) { + complex = ir_builder_->CreateInsertValue(complex, imag, {1}); + } + return complex; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 35dfa88e9b02e3ec7686dc7fdded8cf4e88201fb..9d32436e38fa2fb3e27d09f01b860cd2edf2c8ac 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -55,6 +55,7 @@ class ElementalIrEmitter { const HloToElementGeneratorMap& operand_to_generator) const; llvm::IRBuilder<>* ir_builder() const { return ir_builder_; } + llvm::Module* module() const { return module_; } protected: virtual StatusOr EmitIntegerUnaryOp( @@ -63,6 +64,9 @@ class ElementalIrEmitter { virtual StatusOr EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const; + virtual StatusOr EmitComplexUnaryOp( + const HloInstruction* op, llvm::Value* operand_value) const; + virtual StatusOr EmitIntegerBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, @@ -72,6 +76,10 @@ class ElementalIrEmitter { const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const; + virtual StatusOr EmitComplexBinaryOp( + const HloInstruction* op, llvm::Value* lhs_value, + llvm::Value* rhs_value) const; + virtual llvm::Value* EmitFloatMax(llvm::Value* lhs_value, llvm::Value* rhs_value) const; @@ -109,6 +117,11 @@ class ElementalIrEmitter { // compiled executable outside of the HLO code itself. const HloModuleConfig& hlo_module_config_; + protected: + // Composes a complex struct. imag may be nullptr for simple cast operations. + llvm::Value* ComposeComplex(const HloInstruction* op, llvm::Value* real, + llvm::Value* imag) const; + private: // Returns a ElementGenerator for a RNG HloInstruction. llvm_ir::ElementGenerator MakeRngElementGenerator( diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc index 79fedb61c971862fc0e3a59e01e55825f09c587d..9c96d9eb30b5f9e51b7f5d82391c6b9f366898d6 100644 --- a/tensorflow/compiler/xla/service/executable.cc +++ b/tensorflow/compiler/xla/service/executable.cc @@ -17,7 +17,9 @@ limitations under the License. #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" +#include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/lib/hash/hash.h" #include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/stringprintf.h" #include "tensorflow/core/platform/env.h" @@ -69,15 +71,6 @@ Status Executable::DumpSessionModule() { *session_module_); } -// Removes illegal characters from filenames. -static void SanitizeFilename(string* name) { - for (char& c : *name) { - if (c == '/' || c == '\\' || c == '[' || c == ']') { - c = '_'; - } - } -} - /* static */ Status Executable::DumpToDirectory( const string& directory_path, string filename, const SessionModule& session_module) { @@ -89,9 +82,13 @@ static void SanitizeFilename(string* name) { // "directory already exists" error. TF_RETURN_IF_ERROR(env->RecursivelyCreateDir(directory_path)); } - SanitizeFilename(&filename); + filename = SanitizeFileName(std::move(filename)); string file_path = tensorflow::io::JoinPath(directory_path, filename); - return tensorflow::WriteBinaryProto(env, file_path, session_module); + string result; + TF_RET_CHECK( + tensorflow::SerializeToStringDeterministic(session_module, &result)); + return tensorflow::WriteStringToFile(tensorflow::Env::Default(), file_path, + result); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h index b58dee9c20a6431968358fe90babf2fa813e7e11..7e0d182b365c35788195e70dc35c3923ed8991bb 100644 --- a/tensorflow/compiler/xla/service/executable.h +++ b/tensorflow/compiler/xla/service/executable.h @@ -44,7 +44,7 @@ namespace xla { // interface that is used for launching compiled programs across platforms. class Executable { public: - explicit Executable(std::unique_ptr hlo_module) + explicit Executable(std::unique_ptr hlo_module) : hlo_module_(std::move(hlo_module)) {} virtual ~Executable() {} @@ -88,6 +88,16 @@ class Executable { tensorflow::gtl::ArraySlice> arguments); + // Populates `hlo_execution_profile` from `executor`. This is implicit in any + // Execute* API call that takes a hlo_execution_profile argument, but must be + // called explicitly for other (async, for example) variants after the stream + // has completed. + virtual Status PopulateExecutionProfile( + HloExecutionProfile* hlo_execution_profile, + perftools::gputools::StreamExecutor* executor) { + return Status::OK(); + } + // Convenience wrapper for calling Executable::ExecuteOnStream. Sets up a // timer for the execution, sets up HLO profiling if enabled, and fills in the // given ExecutionProfile if non-null. The ExecuteOnStream overloads have @@ -163,7 +173,7 @@ class Executable { // HloModule this was compiled from. BufferAssignment keeps pointers to // HloInstructions owned by the HloModule so we need to keep the HloModule // around. - std::unique_ptr hlo_module_; + const std::unique_ptr hlo_module_; // SessionModule this was compiled from. Null if not dumping executions. std::unique_ptr session_module_; diff --git a/tensorflow/compiler/xla/service/flatten_call_graph.cc b/tensorflow/compiler/xla/service/flatten_call_graph.cc index 297a4f7599f9c127386b2f53f7ffb987befc456e..dfba22a6c4c5cf071c2cd8621643b8da6587ee3b 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph.cc @@ -80,15 +80,15 @@ Status FlattenNode(const CallGraphNode& node) { while (!worklist.empty()) { auto current = worklist.back(); worklist.pop_back(); - for (auto& instruction : current->instructions()) { - if (GetInstructionCallContext(instruction.get()) != + for (auto* instruction : current->instructions()) { + if (GetInstructionCallContext(instruction) != CallContext::kSequential) { continue; } for (auto callee : instruction->called_computations()) { HloComputation* callee_clone = module->AddEmbeddedComputation(callee->Clone()); - ReplaceCalledComputation(instruction.get(), callee, callee_clone); + ReplaceCalledComputation(instruction, callee, callee_clone); worklist.push_back(callee_clone); } } diff --git a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc index bae122765971a517b9539c473a6f2a86be443a63..a68e90b7d009890012f94baa790d911871c9c960 100644 --- a/tensorflow/compiler/xla/service/flatten_call_graph_test.cc +++ b/tensorflow/compiler/xla/service/flatten_call_graph_test.cc @@ -214,7 +214,7 @@ TEST_F(FlattenCallGraphTest, FlattenCalls) { TF_ASSERT_OK_AND_ASSIGN(bool result, RunFlattenCallGraph(module.get())); EXPECT_TRUE(result); std::unique_ptr call_graph = CallGraph::Build(module.get()); - EXPECT_EQ(7, module->computations().size()); + EXPECT_EQ(7, module->computation_count()); const CallGraphNode& c_node = call_graph->GetNode(c_computation); EXPECT_EQ(1, c_node.caller_callsites().size()); diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc index 432df46eadd146ce83de6e16597a3fa04493188d..d3c83ea72e33b959e21d0cc9c1706d92bd659a5c 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc @@ -35,8 +35,9 @@ namespace se = ::perftools::gputools; namespace xla { -GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id) - : platform_id_(platform_id) { +GenericTransferManager::GenericTransferManager(se::Platform::Id platform_id, + size_t pointer_size) + : platform_id_(platform_id), pointer_size_(pointer_size) { // We currently only support kHostPlatformId for CPU, kCudaPlatformId for // GPU and kInterpreterPlatformId for Interpreter. Before supporting other // platforms, we need to test this transfer manager on them. @@ -127,6 +128,23 @@ GenericTransferManager::ShallowCopyTupleFromDevice( return std::move(destination); } +Status GenericTransferManager::WriteTuplePointersToDevice( + perftools::gputools::StreamExecutor* executor, + tensorflow::gtl::ArraySlice elements, + const Shape& shape, perftools::gputools::DeviceMemoryBase* region) { + TF_RET_CHECK(elements.size() == ShapeUtil::TupleElementCount(shape)); + + std::vector element_pointers; + for (const se::DeviceMemoryBase& element : elements) { + element_pointers.push_back(element.opaque()); + } + int64 tuple_size = + ShapeUtil::ByteSizeOf(shape, /*pointer_size=*/sizeof(void*)); + + return TransferBufferToDevice(executor, tuple_size, element_pointers.data(), + region); +} + Status GenericTransferManager::TransferLiteralToDevice( se::StreamExecutor* executor, const Literal& literal, se::DeviceMemoryBase* destination) { diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h index 993312fef9d8c9bb8827e2f9b7fd09037d2e818e..26488d6ec651b75c753119a7ce818c692c6c03dd 100644 --- a/tensorflow/compiler/xla/service/generic_transfer_manager.h +++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h @@ -36,8 +36,8 @@ namespace xla { // infeed. class GenericTransferManager : public TransferManager { public: - explicit GenericTransferManager( - perftools::gputools::Platform::Id platform_id); + GenericTransferManager(perftools::gputools::Platform::Id platform_id, + size_t pointer_size); ~GenericTransferManager() override {} perftools::gputools::Platform::Id PlatformId() const override; @@ -71,12 +71,22 @@ class GenericTransferManager : public TransferManager { const perftools::gputools::DeviceMemoryBase& source, const Shape& shape) override; + Status WriteTuplePointersToDevice( + perftools::gputools::StreamExecutor* executor, + tensorflow::gtl::ArraySlice + elements, + const Shape& shape, + perftools::gputools::DeviceMemoryBase* region) override; + int64 GetByteSizeRequirement(const Shape& shape) override; private: // The platform this transfer manager targets. const perftools::gputools::Platform::Id platform_id_; + // The size in bytes of pointers on this platform. + const size_t pointer_size_; + TF_DISALLOW_COPY_AND_ASSIGN(GenericTransferManager); }; diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index 9939178aa3d15bd6051de5247ef5422fb514aeea..b9c4adce93a88cb48635993b6e9999528d78ec07 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -104,7 +104,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:alias_analysis", "//tensorflow/compiler/xla/service/llvm_ir:ir_array", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", - "//tensorflow/compiler/xla/service/llvm_ir:ops", + "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "@llvm//:core", ], @@ -147,6 +147,7 @@ cc_library( "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter", "//tensorflow/compiler/xla/service/llvm_ir:ops", + "//tensorflow/compiler/xla/service/llvm_ir:tuple_ops", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:core", @@ -397,6 +398,29 @@ cc_library( ], ) +cc_library( + name = "gpu_transfer_manager", + srcs = ["gpu_transfer_manager.cc"], + hdrs = ["gpu_transfer_manager.h"], + deps = [ + ":gpu_compiler", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:generic_transfer_manager", + "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/compiler/xla/service/gpu:infeed_manager", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + "@llvm//:core", + ], + alwayslink = True, # Contains per-platform transfer manager registration +) + cc_library( name = "gpu_compiler", srcs = ["gpu_compiler.cc"], @@ -440,6 +464,8 @@ cc_library( "//tensorflow/compiler/xla/service:reduce_precision_insertion", "//tensorflow/compiler/xla/service:reshape_mover", "//tensorflow/compiler/xla/service:transpose_folding", + "//tensorflow/compiler/xla/service:tuple_simplifier", + "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/compiler/xla/service/gpu/llvm_gpu_backend", "//tensorflow/compiler/xla/service/llvm_ir:llvm_util", "//tensorflow/core:cuda_libdevice_path", diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc index c598025b5e8f3ff72656ff370068bb0ff3a80f2a..5aaf072f9d2c95e2fff70a1c5337432a12a1aa48 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc @@ -72,8 +72,10 @@ MatchBackwardFilter(HloInstruction* conv) { // Step 2: match paddings and dimension numbers of the forward convolution. const ConvolutionDimensionNumbers& conv_dnums = conv->convolution_dimension_numbers(); - auto batch_dim = conv_dnums.batch_dimension(); - auto feature_dim = conv_dnums.feature_dimension(); + auto input_batch_dim = conv_dnums.input_batch_dimension(); + auto input_feature_dim = conv_dnums.input_feature_dimension(); + auto output_batch_dim = conv_dnums.output_batch_dimension(); + auto output_feature_dim = conv_dnums.output_feature_dimension(); auto spatial_dims = conv_dnums.spatial_dimensions(); for (const WindowDimension& window_dim : conv->window().dimensions()) { @@ -176,15 +178,17 @@ MatchBackwardFilter(HloInstruction* conv) { transpose = parent_computation->AddInstruction(HloInstruction::CreateTranspose( conv->shape(), conv, transpose_dimensions)); - TF_CHECK_OK(parent_computation->ReplaceUsesOfInstruction(conv, transpose)); + TF_CHECK_OK(conv->ReplaceAllUsesWith(transpose)); } // Restore the dimension numbers of the backward convolution from the forward // convolution. The two activation dimensions are reversed (batch and // feature). ConvolutionDimensionNumbers backward_conv_dnums; - backward_conv_dnums.set_batch_dimension(feature_dim); - backward_conv_dnums.set_feature_dimension(batch_dim); + backward_conv_dnums.set_input_batch_dimension(input_feature_dim); + backward_conv_dnums.set_input_feature_dimension(input_batch_dim); + backward_conv_dnums.set_output_batch_dimension(output_feature_dim); + backward_conv_dnums.set_output_feature_dimension(output_batch_dim); for (int i = 0; i < spatial_dims.size(); ++i) { backward_conv_dnums.add_spatial_dimensions(spatial_dims[i]); } @@ -198,9 +202,9 @@ MatchBackwardFilter(HloInstruction* conv) { // the dimension numbering of the weight gradients. This transposition maps // dimension i to PositionInContainer(transpose->dimensions(), i). backward_conv_dnums.set_kernel_input_feature_dimension( - PositionInContainer(transpose->dimensions(), batch_dim)); + PositionInContainer(transpose->dimensions(), output_batch_dim)); backward_conv_dnums.set_kernel_output_feature_dimension( - PositionInContainer(transpose->dimensions(), feature_dim)); + PositionInContainer(transpose->dimensions(), output_feature_dim)); for (int i = 0; i < spatial_dims.size(); ++i) { backward_conv_dnums.add_kernel_spatial_dimensions( PositionInContainer(transpose->dimensions(), spatial_dims[i])); @@ -275,7 +279,7 @@ MatchBackwardInput(HloInstruction* conv) { Window new_window = old_window; for (size_t i = 0; i < spatial_dims.size(); ++i) { // Restore backward convolution's padding config from the matched pattern. - // See the comment in tensorflow/core/kernels/conv_grad_ops.cc + // See the comment in tensorflow/core/kernels/conv_grad_tuple_ops.cc // for how we convert backward input convolution to a variant of forward // convolution. // @@ -392,9 +396,9 @@ MatchBackwardInput(HloInstruction* conv) { StatusOr ConvolutionFolding::Run(HloModule* module) { HloComputation* entry_computation = module->entry_computation(); std::vector convs; - for (const auto& hlo : entry_computation->instructions()) { + for (auto* hlo : entry_computation->instructions()) { if (hlo->opcode() == HloOpcode::kConvolution) { - convs.push_back(hlo.get()); + convs.push_back(hlo); } } diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc index 6699c8f3c4acd76ed58cccf314ca0ae1502d51d7..19b122ba0603b4ec08d73e05da4c2ae11a760553 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc @@ -45,8 +45,10 @@ class ConvolutionFoldingTest : public HloTestBase { // dimension in gradients as the input feature dimension in the filter. // // TODO(jingyue): Add more tests on NCHW input order which TF also supports. - tf_default_dnums_for_backward_filter_.set_batch_dimension(3); - tf_default_dnums_for_backward_filter_.set_feature_dimension(0); + tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3); + tf_default_dnums_for_backward_filter_.set_output_batch_dimension(3); + tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0); + tf_default_dnums_for_backward_filter_.set_output_feature_dimension(0); tf_default_dnums_for_backward_filter_.add_spatial_dimensions(1); tf_default_dnums_for_backward_filter_.add_spatial_dimensions(2); tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0); @@ -55,8 +57,10 @@ class ConvolutionFoldingTest : public HloTestBase { tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1); tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2); - tf_default_dnums_for_backward_input_.set_batch_dimension(0); - tf_default_dnums_for_backward_input_.set_feature_dimension(3); + tf_default_dnums_for_backward_input_.set_input_batch_dimension(0); + tf_default_dnums_for_backward_input_.set_output_batch_dimension(0); + tf_default_dnums_for_backward_input_.set_input_feature_dimension(3); + tf_default_dnums_for_backward_input_.set_output_feature_dimension(3); tf_default_dnums_for_backward_input_.add_spatial_dimensions(1); tf_default_dnums_for_backward_input_.add_spatial_dimensions(2); tf_default_dnums_for_backward_input_.set_kernel_input_feature_dimension(3); @@ -250,8 +254,10 @@ TEST_F(ConvolutionFoldingTest, BackwardInputConvolveEvenPadding) { conv_window.mutable_dimensions(i)->set_padding_high(3); } ConvolutionDimensionNumbers conv_dnums; - conv_dnums.set_batch_dimension(0); - conv_dnums.set_feature_dimension(1); + conv_dnums.set_input_batch_dimension(0); + conv_dnums.set_output_batch_dimension(0); + conv_dnums.set_input_feature_dimension(1); + conv_dnums.set_output_feature_dimension(1); conv_dnums.add_spatial_dimensions(2); conv_dnums.add_spatial_dimensions(3); conv_dnums.set_kernel_input_feature_dimension(0); diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc index 89145a9038c23e02b1b25140ff3711dc44185d0c..536b96dcf620e908e25a775bc2efb57ba5f5edd6 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.cc @@ -141,8 +141,8 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( BatchDescriptor input_descriptor(effective_num_dimensions); input_descriptor.set_layout(DataLayout::kBatchDepthYX) .set_feature_map_count( - input_shape_.dimensions(dim_nums_.feature_dimension())) - .set_count(input_shape_.dimensions(dim_nums_.batch_dimension())); + input_shape_.dimensions(dim_nums_.input_feature_dimension())) + .set_count(input_shape_.dimensions(dim_nums_.input_batch_dimension())); for (int dim = 0; dim < num_dimensions; ++dim) { // Note that the dimensions are reversed. The same holds below. input_descriptor.set_spatial_dim( @@ -176,8 +176,8 @@ tensorflow::Status ConvolutionThunk::ExecuteOnStream( BatchDescriptor output_descriptor(effective_num_dimensions); output_descriptor.set_layout(DataLayout::kBatchDepthYX) .set_feature_map_count( - output_shape_.dimensions(dim_nums_.feature_dimension())) - .set_count(output_shape_.dimensions(dim_nums_.batch_dimension())); + output_shape_.dimensions(dim_nums_.output_feature_dimension())) + .set_count(output_shape_.dimensions(dim_nums_.output_batch_dimension())); for (int dim = 0; dim < num_dimensions; ++dim) { output_descriptor.set_spatial_dim( static_cast(effective_num_dimensions - dim - 1), @@ -256,9 +256,9 @@ tensorflow::Status ConvolutionThunk::Convolve( algorithm_config.algorithm_no_scratch().algo_id()); } -std::vector ConvolutionThunk::GetAlgorithms( +std::vector ConvolutionThunk::GetAlgorithms( se::StreamExecutor* stream_exec) const { - std::vector algorithms; + std::vector algorithms; // TODO(yangzihao): Currently disable the use of winograd nonfused in XLA // by default. Should send in conv parameters and enable it when // ShouldIncludeWinogradNonfusedAlgo() returns true. @@ -297,32 +297,27 @@ tensorflow::Status ConvolutionThunk::ConvolveWithTune( se::dnn::ProfileResult best_result; se::dnn::ProfileResult best_result_without_scratch; - std::vector algorithms = - GetAlgorithms(stream->parent()); - for (bool use_tensor_ops : {false, true}) { - for (auto algo_index : algorithms) { - AlgorithmDesc algorithm(algo_index, use_tensor_ops); - ConvolveScratchAllocator scratch_allocator( - buffer_allocations.device_ordinal(), - buffer_allocations.memory_allocator()); - se::dnn::ProfileResult profile_result; - bool launch_ok = - Convolve(input_descriptor, input_data, filter_descriptor, - filter_data, output_descriptor, output_data, - convolution_descriptor, - se::dnn::AlgorithmConfig(algorithm, algorithm), stream, - &scratch_allocator, &profile_result) - .ok(); - if (launch_ok && profile_result.is_valid()) { - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - if (scratch_allocator.TotalAllocatedBytes() == 0 && - profile_result.elapsed_time_in_ms() < - best_result_without_scratch.elapsed_time_in_ms()) { - best_result_without_scratch = profile_result; - } + std::vector algorithms = GetAlgorithms(stream->parent()); + for (auto algorithm : algorithms) { + ConvolveScratchAllocator scratch_allocator( + buffer_allocations.device_ordinal(), + buffer_allocations.memory_allocator()); + se::dnn::ProfileResult profile_result; + bool launch_ok = + Convolve(input_descriptor, input_data, filter_descriptor, filter_data, + output_descriptor, output_data, convolution_descriptor, + se::dnn::AlgorithmConfig(algorithm, algorithm), stream, + &scratch_allocator, &profile_result) + .ok(); + if (launch_ok && profile_result.is_valid()) { + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + if (scratch_allocator.TotalAllocatedBytes() == 0 && + profile_result.elapsed_time_in_ms() < + best_result_without_scratch.elapsed_time_in_ms()) { + best_result_without_scratch = profile_result; } } } diff --git a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h index 509719c1fe555fd733484c82ca14812efca0dcf9..13432301b2af34ab4bd0864e39ce22366cc1d11d 100644 --- a/tensorflow/compiler/xla/service/gpu/convolution_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/convolution_thunk.h @@ -115,9 +115,7 @@ class ConvolutionThunk : public Thunk { perftools::gputools::dnn::ProfileResult* profile_result); // Returns the convolve algorithms that can be used for this ConvolutionThunk. - // TODO(nluehr) GetAlgorithms should return AlgorithmDesc including both - // tensor-op and non-tensor-op variants. - std::vector GetAlgorithms( + std::vector GetAlgorithms( perftools::gputools::StreamExecutor* stream_exec) const; // Fastest cuDNN convolution algorithm for this thunk learned from diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc index 87858e94090d1f7506ee09b9015b4417aee55707..f4498663b1c039b3175376baf8f27c4ecec678ec 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.cc @@ -20,15 +20,16 @@ limitations under the License. namespace xla { namespace gpu { -CopyThunk::CopyThunk(const void* source_address, - const BufferAllocation::Slice& destination_buffer, - uint64 mem_size, const HloInstruction* hlo_instruction) +HostToDeviceCopyThunk::HostToDeviceCopyThunk( + const void* source_address, + const BufferAllocation::Slice& destination_buffer, uint64 mem_size, + const HloInstruction* hlo_instruction) : Thunk(Kind::kCopy, hlo_instruction), source_address_(source_address), destination_buffer_(destination_buffer), mem_size_(mem_size) {} -tensorflow::Status CopyThunk::ExecuteOnStream( +tensorflow::Status HostToDeviceCopyThunk::ExecuteOnStream( const BufferAllocations& buffer_allocations, perftools::gputools::Stream* stream) { perftools::gputools::DeviceMemoryBase destination_data = @@ -37,5 +38,24 @@ tensorflow::Status CopyThunk::ExecuteOnStream( return tensorflow::Status::OK(); } +DeviceToDeviceCopyThunk::DeviceToDeviceCopyThunk( + const BufferAllocation::Slice& source_buffer, + const BufferAllocation::Slice& destination_buffer, uint64 mem_size, + const HloInstruction* hlo_instruction) + : Thunk(Kind::kCopy, hlo_instruction), + source_buffer_(source_buffer), + destination_buffer_(destination_buffer), + mem_size_(mem_size) {} + +tensorflow::Status DeviceToDeviceCopyThunk::ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) { + perftools::gputools::DeviceMemoryBase destination_data = + buffer_allocations.GetDeviceAddress(destination_buffer_); + perftools::gputools::DeviceMemoryBase source_data = + buffer_allocations.GetDeviceAddress(source_buffer_); + stream->ThenMemcpy(&destination_data, source_data, mem_size_); + return tensorflow::Status::OK(); +} } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/copy_thunk.h b/tensorflow/compiler/xla/service/gpu/copy_thunk.h index 6b8c432715f27fc02b13fc242db5ee6db098c47e..e2783fd255239d31edc89701ea208f33ebb8d3fb 100644 --- a/tensorflow/compiler/xla/service/gpu/copy_thunk.h +++ b/tensorflow/compiler/xla/service/gpu/copy_thunk.h @@ -26,19 +26,18 @@ limitations under the License. namespace xla { namespace gpu { -// A thunk that copies data. For now, it copies data only from host to device. -// But it can be extended to perform device-to-host or intra-device copying. -class CopyThunk : public Thunk { +// A thunk that copies data from a host buffer to a device buffer. +class HostToDeviceCopyThunk : public Thunk { public: // Constructs a CopyThunk that copies host data from `source_address` to the // device buffer `destination_buffer`. `mem_size` is the size of the data in // bytes. - CopyThunk(const void* source_address, - const BufferAllocation::Slice& destination_buffer, uint64 mem_size, - const HloInstruction* hlo_instruction); + HostToDeviceCopyThunk(const void* source_address, + const BufferAllocation::Slice& destination_buffer, + uint64 mem_size, const HloInstruction* hlo_instruction); - CopyThunk(const CopyThunk&) = delete; - CopyThunk& operator=(const CopyThunk&) = delete; + HostToDeviceCopyThunk(const HostToDeviceCopyThunk&) = delete; + HostToDeviceCopyThunk& operator=(const HostToDeviceCopyThunk&) = delete; tensorflow::Status ExecuteOnStream( const BufferAllocations& buffer_allocations, @@ -50,6 +49,30 @@ class CopyThunk : public Thunk { const uint64 mem_size_; }; +// A thunk that copies data from a device buffer to another device buffer. +class DeviceToDeviceCopyThunk : public Thunk { + public: + // Constructs a CopyThunk that copies host data from `source_buffer` to the + // device buffer `destination_buffer`. `mem_size` is the size of the data in + // bytes. + DeviceToDeviceCopyThunk(const BufferAllocation::Slice& source_buffer, + const BufferAllocation::Slice& destination_buffer, + uint64 mem_size, + const HloInstruction* hlo_instruction); + + DeviceToDeviceCopyThunk(const DeviceToDeviceCopyThunk&) = delete; + DeviceToDeviceCopyThunk& operator=(const DeviceToDeviceCopyThunk&) = delete; + + tensorflow::Status ExecuteOnStream( + const BufferAllocations& buffer_allocations, + perftools::gputools::Stream* stream) override; + + private: + const BufferAllocation::Slice source_buffer_; + const BufferAllocation::Slice destination_buffer_; + const uint64 mem_size_; +}; + } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 8810a85ceeafd8b2d9ad8d7412266847abe5b75d..1b94499bc6ef6d587cdb1fafec48bc4e5b917c51 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -135,6 +135,10 @@ StatusOr GpuElementalIrEmitter::EmitFloatBinaryOp( PrimitiveType rhs_input_type = op->operand(1)->shape().element_type(); PrimitiveType output_type = op->shape().element_type(); switch (op->opcode()) { + case HloOpcode::kAtan2: + return EmitLibdeviceMathCall("__nv_atan2", {lhs_value, rhs_value}, + {lhs_input_type, rhs_input_type}, + output_type); case HloOpcode::kRemainder: { return EmitLibdeviceMathCall("__nv_fmod", {lhs_value, rhs_value}, {lhs_input_type, rhs_input_type}, @@ -226,6 +230,112 @@ StatusOr GpuElementalIrEmitter::EmitFloatUnaryOp( } } +StatusOr GpuElementalIrEmitter::EmitComplexUnaryOp( + const HloInstruction* op, llvm::Value* operand_value) const { + PrimitiveType input_type = op->operand(0)->shape().element_type(); + PrimitiveType component_type = + primitive_util::IsComplexType(input_type) + ? primitive_util::ComplexComponentType(input_type) + : input_type; + auto real = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {0}); + }; + auto imag = [&](llvm::Value* x) { + return ir_builder_->CreateExtractValue(x, {1}); + }; + + switch (op->opcode()) { + case HloOpcode::kLog: { + // log(a+bi) = .5*log(a^2+b^2) + i*atan2(b, a) + auto a = real(operand_value); + auto b = imag(operand_value); + llvm::Type* llvm_ty = a->getType(); + auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a), + ir_builder_->CreateFMul(b, b)); + TF_ASSIGN_OR_RETURN( + auto log_sum_sq, + EmitLibdeviceMathCall("__nv_log", {sum_sq}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto angle, EmitLibdeviceMathCall("__nv_atan2", {b, a}, + {component_type, component_type}, + component_type)); + auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5); + return ComposeComplex(op, ir_builder_->CreateFMul(one_half, log_sum_sq), + angle); + } + // TODO(b/65408531): Implement kPower on GPU, where atan2 is available. + // case HloOpcode::kPower: + // // (a+bi)^(c+di) = exp(i(c+di)*arg(a+bi)) * (a*a+b*b)^(0.5(c+di)) + case HloOpcode::kExp: { + // e^(a+bi) = e^a*(cos(b)+sin(b)i) + auto b = imag(operand_value); + TF_ASSIGN_OR_RETURN( + auto exp_a, EmitLibdeviceMathCall("__nv_exp", {real(operand_value)}, + {component_type}, component_type)); + TF_ASSIGN_OR_RETURN( + auto cos_b, EmitLibdeviceMathCall("__nv_cos", {b}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto sin_b, EmitLibdeviceMathCall("__nv_sin", {b}, {component_type}, + component_type)); + return ComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b), + ir_builder_->CreateFMul(exp_a, sin_b)); + } + case HloOpcode::kCos: { + // cos(a+bi) = .5(cos(a)*(e^-b+e^b) + i*sin(a)*(e^-b-e^b)) + auto a = real(operand_value); + auto llvm_ty = a->getType(); + TF_ASSIGN_OR_RETURN( + auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)}, + {component_type}, component_type)); + TF_ASSIGN_OR_RETURN( + auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type}, + component_type)); + auto half_exp_b = + ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); + auto half_exp_neg_b = + ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); + return ComposeComplex( + op, + ir_builder_->CreateFMul( + cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)), + ir_builder_->CreateFMul( + sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b))); + } + + case HloOpcode::kSin: { + // sin(a+bi) = 0.5(sin(a)*(e^b+e^-b) + i*cos(a)*(e^b-e^-b) + auto a = real(operand_value); + auto llvm_ty = a->getType(); + TF_ASSIGN_OR_RETURN( + auto exp_b, EmitLibdeviceMathCall("__nv_exp", {imag(operand_value)}, + {component_type}, component_type)); + TF_ASSIGN_OR_RETURN( + auto cos_a, EmitLibdeviceMathCall("__nv_cos", {a}, {component_type}, + component_type)); + TF_ASSIGN_OR_RETURN( + auto sin_a, EmitLibdeviceMathCall("__nv_sin", {a}, {component_type}, + component_type)); + auto half_exp_b = + ir_builder_->CreateFMul(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); + auto half_exp_neg_b = + ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 0.5), exp_b); + return ComposeComplex( + op, + ir_builder_->CreateFMul( + sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)), + ir_builder_->CreateFMul( + cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b))); + } + default: + return ElementalIrEmitter::EmitComplexUnaryOp(op, operand_value); + } +} + llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( const string& callee_name, tensorflow::gtl::ArraySlice operands, @@ -235,13 +345,12 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( std::vector ir_input_types; for (PrimitiveType input_type : input_types) { ir_input_types.push_back( - llvm_ir::PrimitiveTypeToIrType(input_type, ir_builder_)); + llvm_ir::PrimitiveTypeToIrType(input_type, module_)); } llvm::FunctionType* callee_type = llvm::FunctionType::get( - llvm_ir::PrimitiveTypeToIrType(output_type, - ir_builder_), // The return type. - ir_input_types, // The parameter types. - false); // No variadic arguments. + llvm_ir::PrimitiveTypeToIrType(output_type, module_), // Return type. + ir_input_types, // Parameter types. + false); // No variadic arguments. // Declares the callee if it is not declared already. llvm::Function* callee = llvm::cast( @@ -315,7 +424,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( PrimitiveType operand_element_type = operand->shape().element_type(); llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, ir_builder_), + llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), "reduce_window_accum_ptr", ir_builder_); { TF_ASSIGN_OR_RETURN(llvm::Value * init_value, @@ -377,7 +486,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( const HloInstruction* operand = hlo->operand(0); llvm::Value* accum_ptr = ir_builder()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType( - hlo->shape().element_type(), ir_builder())); + hlo->shape().element_type(), module_)); TF_ASSIGN_OR_RETURN(llvm::Value * init_value, operand_to_generator.at(hlo->operand(1))({})); ir_builder()->CreateStore(init_value, accum_ptr); diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h index 6ddfc3710c56a4e129f050f862812a3d78d8dba0..3defa1b696d3addc012702e23102bb1fa140170d 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.h @@ -54,6 +54,9 @@ class GpuElementalIrEmitter : public ElementalIrEmitter { StatusOr EmitFloatUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const override; + StatusOr EmitComplexUnaryOp( + const HloInstruction* op, llvm::Value* operand_value) const override; + StatusOr EmitFloatBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value) const override; diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc index a9ef204b46facafabcf16d1d38d69c14e6aab497..c137fbc97e29e24edb3603c611a5c8f093bc62a6 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger.cc @@ -83,11 +83,11 @@ double CalculateBytesReadByFusionParameter(HloInstruction* param) { // Returns the bytes read by all fusion parameters of instruction 'fusion'. double CalculateBytesReadByFusionInstruction(HloInstruction* fusion) { double bytes = 0.0; - for (const auto& fused_instruction : fusion->fused_instructions()) { + for (auto* fused_instruction : fusion->fused_instructions()) { if (fused_instruction->opcode() != HloOpcode::kParameter) { continue; } - bytes += CalculateBytesReadByFusionParameter(fused_instruction.get()); + bytes += CalculateBytesReadByFusionParameter(fused_instruction); } return bytes; } @@ -238,7 +238,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { // re-use by the consumer), and so we honor that choice here as well. if (!std::all_of(fusion->fused_instructions().begin(), fusion->fused_instructions().end(), - [](const std::unique_ptr& instruction) { + [](const HloInstruction* instruction) { if (instruction->opcode() != HloOpcode::kParameter && GpuInstructionFusion::IsExpensive(*instruction)) { return false; @@ -293,14 +293,7 @@ Status FusionInstructionMerger::HandleFusion(HloInstruction* fusion) { StatusOr FusionMerger::Run(HloModule* module) { bool changed = false; VLOG(2) << "FusionMerger for module: " << module->name(); - std::vector computations; - for (auto& computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } - computations.push_back(computation.get()); - } - for (auto& computation : computations) { + for (auto* computation : module->MakeNonfusionComputations()) { VLOG(1) << "Before running FusionInstructionMerger for computation: " << computation->name(); XLA_VLOG_LINES(3, computation->ToString()); diff --git a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc index e68201417ba74b41978c839e31b16fb87080431e..deef5966b80d1b7f16e9982eed9ac5c7131e9d73 100644 --- a/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc +++ b/tensorflow/compiler/xla/service/gpu/fusion_merger_test.cc @@ -293,15 +293,15 @@ TEST_F(FusionMergerTest, MergeSharedFusionInstruction) { // Check operand 0 (not merged). Should have 4 instructions. auto* operand0 = root->operand(0); EXPECT_EQ(HloOpcode::kFusion, operand0->opcode()); - EXPECT_EQ(4, operand0->fused_instructions().size()); + EXPECT_EQ(4, operand0->fused_instruction_count()); // Check operand 1 (should have merged in its operand fusion instruction). auto* operand1 = root->operand(1); EXPECT_EQ(HloOpcode::kFusion, operand1->opcode()); - EXPECT_EQ(7, operand1->fused_instructions().size()); + EXPECT_EQ(7, operand1->fused_instruction_count()); // Check operand 2 (should have merged in its operand fusion instruction). auto* operand2 = root->operand(2); EXPECT_EQ(HloOpcode::kFusion, operand2->opcode()); - EXPECT_EQ(7, operand2->fused_instructions().size()); + EXPECT_EQ(7, operand2->fused_instruction_count()); } // Tests that we do not merge a fusion instruction that above flops to bytes diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index fee0fe30c6140b75d30c76e6a22a12f718afeaa8..2caa8f60517c66c1708e52481f01727f0008afd9 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -61,11 +61,14 @@ limitations under the License. #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" #include "tensorflow/compiler/xla/service/transpose_folding.h" +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/cuda_libdevice_path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" @@ -77,14 +80,13 @@ namespace se = ::perftools::gputools; namespace xla { namespace gpu { -namespace { +/* static */ const char* GpuCompiler::kTargetTriple = "nvptx64-nvidia-cuda"; +/* static */ const char* GpuCompiler::kDataLayout = + "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"; -// The triple that represents our target. -const char* kTargetTriple = "nvptx64-nvidia-cuda"; +namespace { -// The data layout of the emitted module. Copied from computeDataLayout in -// NVPTXTargetMachine.cpp. -const char* kDataLayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"; +using tensorflow::strings::StrCat; // Any address of a variable residing in global memory or returned by one of the // memory allocation routines from the driver or runtime API is always aligned @@ -93,15 +95,13 @@ const char* kDataLayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"; // http://docs.nvidia.com/cuda/cuda-c-programming-guide/#device-memory-accesses constexpr int64 kMemoryAlignment = 256; -// Returns the directory containing nvvm libdevice files. This function is -// called in GpuCompiler's constructor, so can't return an error. But -// GpuCompiler::Compile will return an error when the wanted libdevice file -// doesn't exist in the folder this function returns. -string GetLibdeviceDir(const HloModuleConfig& config) { +// Returns the directory containing nvvm libdevice files. config_cuda_data_dir +// should be equal to config().debug_options().xla_gpu_cuda_data_dir() of the +// HloModule being compiled. +string GetLibdeviceDir(const string& config_cuda_data_dir) { std::vector potential_libdevice_dirs; - const string datadir = config.debug_options().xla_gpu_cuda_data_dir(); - if (!datadir.empty()) { - potential_libdevice_dirs.push_back(datadir); + if (!config_cuda_data_dir.empty()) { + potential_libdevice_dirs.push_back(config_cuda_data_dir); } potential_libdevice_dirs.push_back(tensorflow::LibdeviceRoot()); @@ -149,6 +149,9 @@ tensorflow::Status OptimizeHloModule( pass.AddPass( /*is_layout_sensitive=*/false, [](const Shape&, const Shape&) { return false; }); + pass.AddPass(); + pass.AddPass(); + pass.AddPass(); pass.AddPass(); pass.AddPass(); } @@ -224,7 +227,7 @@ tensorflow::Status PrepareHloModuleForIrEmitting( } // Invokes the ptxas tool on the given PTX string, and dumps its output. -void DumpPtxasInfo(const string& ptx) { +void DumpPtxasInfo(const string& ptx, int cc_major, int cc_minor) { const string ptxas_path = tensorflow::io::JoinPath(tensorflow::CudaRoot(), "bin/ptxas"); // Do not log PTX stats if ptxas is not found at the given path. @@ -246,17 +249,22 @@ void DumpPtxasInfo(const string& ptx) { // Invoke ptxas and collect its output. tensorflow::SubProcess ptxas_info_dumper; - ptxas_info_dumper.SetProgram(ptxas_path, {ptxas_path, ptx_path, "-o", - "/dev/null", "-v", "-arch=sm_35"}); + ptxas_info_dumper.SetProgram(ptxas_path, + {ptxas_path, ptx_path, "-o", "/dev/null", "-v", + StrCat("-arch=sm_", cc_major, cc_minor)}); ptxas_info_dumper.SetChannelAction(tensorflow::CHAN_STDERR, tensorflow::ACTION_PIPE); - CHECK(ptxas_info_dumper.Start()); + if (!ptxas_info_dumper.Start()) { + LOG(ERROR) << "Failed to launch ptxas."; + return; + } string stderr_output; int exit_status = ptxas_info_dumper.Communicate( /*stdin_input=*/nullptr, /*stdout_output=*/nullptr, &stderr_output); XLA_LOG_LINES(tensorflow::INFO, stderr_output); if (exit_status != 0) { - LOG(FATAL) << "Invalid PTX. See the error message above for reasons."; + LOG(ERROR) << "ptxas exited with non-zero error code " << exit_status + << "."; } } @@ -311,12 +319,12 @@ StatusOr> GpuCompiler::Compile( // print one ourselves. XLA_VLOG_LINES(2, buffer_assignment->ToString()); - const string dump_debug_json_to = - module->config().debug_options().xla_dump_debug_json_to(); - if (!dump_debug_json_to.empty()) { + const string xla_dump_hlo_proto_to = + module->config().debug_options().xla_dump_hlo_proto_to(); + if (!xla_dump_hlo_proto_to.empty()) { HloProto proto = MakeHloProto(*module, *buffer_assignment); - TF_RETURN_IF_ERROR(protobuf_util::DumpJsonToDirectory( - proto, dump_debug_json_to, module->name())); + TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory( + proto, xla_dump_hlo_proto_to, module->name())); } IrEmitterContext ir_emitter_context(module.get(), buffer_assignment.get(), @@ -325,7 +333,6 @@ StatusOr> GpuCompiler::Compile( HloComputation* entry_computation = module->entry_computation(); IrEmitterUnnested ir_emitter(module->config(), entry_computation, - module->config().has_hybrid_result(), &ir_emitter_context); TF_RETURN_IF_ERROR( entry_computation->root_instruction()->Accept(&ir_emitter)); @@ -342,12 +349,36 @@ StatusOr> GpuCompiler::Compile( XLA_VLOG_LINES(2, ir_module_string_before_opt); } - // Reserve space for the PTX to be generated for this module. + const string& ir_dump_directory = + module->config().debug_options().xla_dump_ir_to(); + + if (!ir_dump_directory.empty()) { + TF_RETURN_IF_ERROR(llvm_ir::DumpIRToDirectory( + /*directory_name=*/ir_dump_directory, + /*hlo_module_name=*/module->name(), llvm_module, + /*optimized=*/false)); + } + string* ptx; + string libdevice_dir; { tensorflow::mutex_lock lock(mutex_); + + // Reserve space for the PTX to be generated for this module. generated_ptxes_.emplace_back(MakeUnique()); ptx = generated_ptxes_.back().get(); + + // Find the directory containing libdevice. To avoid searching for it every + // time, we have a one-element cache, keyed on the module's config's + // cuda_data_dir. + const auto& config_cuda_data_dir = + module->config().debug_options().xla_gpu_cuda_data_dir(); + if (cached_libdevice_dir_.empty() || + cached_cuda_data_dir_ != config_cuda_data_dir) { + cached_cuda_data_dir_ = config_cuda_data_dir; + cached_libdevice_dir_ = GetLibdeviceDir(config_cuda_data_dir); + } + libdevice_dir = cached_libdevice_dir_; } int cc_major, cc_minor; if (!stream_exec->GetDeviceDescription().cuda_compute_capability(&cc_major, @@ -357,12 +388,16 @@ StatusOr> GpuCompiler::Compile( cc_major = 2; cc_minor = 0; } - if (libdevice_dir_.empty()) { - // Compute libdevice_dir_ just once and cache it in this member. - libdevice_dir_ = GetLibdeviceDir(module->config()); - } + TF_ASSIGN_OR_RETURN(*ptx, CompileToPtx(&llvm_module, {cc_major, cc_minor}, - module->config(), libdevice_dir_)); + module->config(), libdevice_dir)); + + if (!ir_dump_directory.empty()) { + TF_RETURN_IF_ERROR(llvm_ir::DumpIRToDirectory( + /*directory_name=*/ir_dump_directory, + /*hlo_module_name=*/module->name(), llvm_module, + /*optimized=*/true)); + } if (user_post_optimization_hook_) { TF_CHECK_OK(user_post_optimization_hook_(llvm_module)); @@ -372,7 +407,7 @@ StatusOr> GpuCompiler::Compile( VLOG(2) << "PTX:"; XLA_VLOG_LINES(2, *ptx); if (VLOG_IS_ON(2)) { - DumpPtxasInfo(*ptx); + DumpPtxasInfo(*ptx, cc_major, cc_minor); } auto thunk_schedule = MakeUnique( @@ -393,7 +428,7 @@ StatusOr> GpuCompiler::Compile( StatusOr>> GpuCompiler::Compile( std::vector> modules, - std::vector stream_execs) { + std::vector> stream_execs) { return Unimplemented( "Compilation of multiple HLO modules is not yet supported on GPU."); } diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h index e8073935990938c9ecf0d835066c4c490c7cc2c4..7a4c4b00d9ad0d895d6b326d2e58f3becdac56d0 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.h @@ -46,7 +46,8 @@ class GpuCompiler : public LLVMCompiler { StatusOr>> Compile( std::vector> modules, - std::vector stream_exec) override; + std::vector> + stream_execs) override; StatusOr>> CompileAheadOfTime(std::vector> module, @@ -62,19 +63,34 @@ class GpuCompiler : public LLVMCompiler { }; } + // The triple that represents our target. + static const char* kTargetTriple; + + // The data layout of the emitted module. Copied from computeDataLayout in + // NVPTXTargetMachine.cpp. + static const char* kDataLayout; + private: - // The parent directory of libdevice IR libraries. - string libdevice_dir_; + // The size in bytes of a pointer. Used by ShapeSizeBytesFunction. + const int64 pointer_size_; + + tensorflow::mutex mutex_; + + // When compiling an HLO module, we need to find a path to the nvvm libdevice + // files. We search in the module's config.debug_options().cuda_data_dir() + // and in tensorflow::LibdeviceRoot(), the latter of which is a constant. + // + // We cache the cuda_data_dir() and the result of our search, so that if the + // next module we have to compile has the same cuda_data_dir(), we can skip + // the search. + string cached_cuda_data_dir_ GUARDED_BY(mutex_); + string cached_libdevice_dir_ GUARDED_BY(mutex_); // The list of PTX strings generated by this GpuCompiler. We let GpuCompiler // to own them because they need to be alive across the life span of the // StreamExecutor (b/24776264). - tensorflow::mutex mutex_; std::vector> generated_ptxes_ GUARDED_BY(mutex_); - // The size in bytes of a pointer. Used by ShapeSizeBytesFunction. - int64 pointer_size_; - TF_DISALLOW_COPY_AND_ASSIGN(GpuCompiler); }; diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc index db7f9826d798181b55b5fd6cef4ea749d4fe7d53..254d0d770560b32298533f04139ab2f6c9a167ce 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc @@ -88,7 +88,7 @@ class HloExecutionProfiler { if (do_profile_) { stream_->ThenStopTimer(per_op_timer_.get()); stream_->BlockHostUntilDone(); - profile_->AddProfileResult( + profile_->SetCyclesTakenBy( hlo_instruction, per_op_timer_->Nanoseconds() * clock_rate_ghz_); } } @@ -108,9 +108,10 @@ class HloExecutionProfiler { // Implementation note: HLO profiling is always enabled for GPU executables, // since we can use timers around thunks. GpuExecutable::GpuExecutable( - tensorflow::StringPiece ptx, std::unique_ptr thunk_schedule, - std::unique_ptr hlo_module, - std::unique_ptr assignment, + tensorflow::StringPiece ptx, + std::unique_ptr thunk_schedule, + std::unique_ptr hlo_module, + std::unique_ptr assignment, HloCostAnalysis::ShapeSizeFunction shape_size_function) : Executable(std::move(hlo_module)), ptx_(ptx), @@ -183,9 +184,6 @@ StatusOr GpuExecutable::ExecuteOnStream( HloExecutionProfile* hlo_execution_profile) { se::Stream* stream = run_options->stream(); DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - // This ExecuteOnStream overload should only be called if has_hybrid_result is - // false. - TF_RET_CHECK(!module_config().has_hybrid_result()); BufferAllocations::Builder buffer_allocations_builder; for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size(); @@ -263,9 +261,6 @@ StatusOr> GpuExecutable::ExecuteOnStream( tensorflow::gtl::ArraySlice arguments, HloExecutionProfile* hlo_execution_profile) { DeviceMemoryAllocator* memory_allocator = run_options->allocator(); - // This ExecuteOnStream overload should only be called by the LocalService - // which sets has_hybrid_result to true. - TF_RET_CHECK(module_config().has_hybrid_result()); if (GetRootPointsToSet().IsAmbiguous()) { return Unimplemented("Points-to set of root instruction is ambiguous"); @@ -277,9 +272,6 @@ StatusOr> GpuExecutable::ExecuteOnStream( const BufferAllocation& allocation = assignment_->GetAllocation(i); if (allocation.is_entry_computation_parameter()) { auto param_no = allocation.parameter_number(); - if (ShapeUtil::IsTuple(arguments[param_no]->shape())) { - return Unimplemented("Tuple ShapedBuffer arguments not supported"); - } buffer_allocations_builder.RegisterBuffer( i, arguments[param_no]->buffer(/*index=*/{})); } @@ -298,9 +290,8 @@ StatusOr> GpuExecutable::ExecuteOnStream( HloInstruction* root = hlo_module_->entry_computation()->root_instruction(); auto device_ordinal = executor->device_ordinal(); - TF_ASSIGN_OR_RETURN(auto shaped_buffer, - ShapedBuffer::MakeShapedBuffer( - root->shape(), executor->platform(), device_ordinal)); + auto shaped_buffer = MakeUnique( + root->shape(), executor->platform(), device_ordinal); // Copy DeviceMemoryBase values which contain the array(s) of the result into // the respective location in ShapedBuffer. @@ -310,32 +301,29 @@ StatusOr> GpuExecutable::ExecuteOnStream( ->ForEachMutableElementWithStatus( [&buffer_allocations, &buffers_in_result, &shaped_buffer, this]( const ShapeIndex& index, size_t* buffer_entry) { - if (ShapeUtil::IsLeafIndex(shaped_buffer->shape(), index)) { - const auto& sources = - this->GetRootPointsToSet().element(index); - // The points to set is unambiguous so the set should be a - // singleton. That is, we know exactly which instruction - // produced the array at this element. - CHECK_EQ(1, sources.size()); - auto src_hlo = sources[0]->instruction(); - - VLOG(4) << "Looking at: " << sources[0]; - - // The source instruction should have a non-parameter buffer - // assigned. - TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, - this->assignment_->GetUniqueSlice( - src_hlo, sources[0]->index())); - CHECK(!slice.allocation()->is_entry_computation_parameter()); - - perftools::gputools::DeviceMemoryBase src_base = - buffer_allocations->GetDeviceAddress(slice.index()); - CHECK(!src_base.is_null() || src_base.size() == 0); - shaped_buffer->mutable_buffers()->push_back(src_base); - *buffer_entry = shaped_buffer->mutable_buffers()->size() - 1; - - buffers_in_result.insert(src_base); - } + const auto& sources = this->GetRootPointsToSet().element(index); + // The points-to set is unambiguous so the set should be a + // singleton. That is, we know exactly which instruction + // produced the array at this element. + CHECK_EQ(1, sources.size()); + auto src_hlo = sources[0]->instruction(); + + VLOG(4) << "Looking at: " << sources[0]; + + // The source instruction should have a non-parameter buffer + // assigned. + TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice slice, + this->assignment_->GetUniqueSlice( + src_hlo, sources[0]->index())); + CHECK(!slice.allocation()->is_entry_computation_parameter()); + + perftools::gputools::DeviceMemoryBase src_base = + buffer_allocations->GetDeviceAddress(slice.index()); + CHECK(!src_base.is_null() || src_base.size() == 0); + shaped_buffer->mutable_buffers()->push_back(src_base); + *buffer_entry = shaped_buffer->mutable_buffers()->size() - 1; + + buffers_in_result.insert(src_base); return Status::OK(); })); TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h index bbf8549fdbcd1017c95b2a6485319f72e91df5c5..748a8f521bc5293d58de19ab52f4bdecec6cb1e5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h @@ -48,9 +48,9 @@ namespace gpu { class GpuExecutable : public Executable { public: GpuExecutable(tensorflow::StringPiece ptx, - std::unique_ptr thunk_schedule, - std::unique_ptr hlo_module, - std::unique_ptr assignment, + std::unique_ptr thunk_schedule, + std::unique_ptr hlo_module, + std::unique_ptr assignment, HloCostAnalysis::ShapeSizeFunction shape_size_function); // This should be called after set_ir_module_string. @@ -115,14 +115,14 @@ class GpuExecutable : public Executable { // The thunks to be invoked by this GpuExecutable. They are generated by the // IrEmitter. - const std::unique_ptr thunk_schedule_; + const std::unique_ptr thunk_schedule_; // Owns the buffer data at runtime. It provides information to allocate // memory for every output/temp buffers. - const std::unique_ptr assignment_; + const std::unique_ptr assignment_; // Function to compute the size of a given Shape, in bytes. - HloCostAnalysis::ShapeSizeFunction shape_size_function_; + const HloCostAnalysis::ShapeSizeFunction shape_size_function_; TF_DISALLOW_COPY_AND_ASSIGN(GpuExecutable); }; diff --git a/tensorflow/compiler/xla/service/gpu_transfer_manager.cc b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc similarity index 94% rename from tensorflow/compiler/xla/service/gpu_transfer_manager.cc rename to tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc index 74f0bdb7db1847119c5bd75cc9fd9d921c6e162a..f0f036f7f381db15b84db85d3efeec5d8141884e 100644 --- a/tensorflow/compiler/xla/service/gpu_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.cc @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/gpu_transfer_manager.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h" #include #include #include +#include "llvm/IR/DataLayout.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/gpu/gpu_compiler.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -39,7 +41,10 @@ namespace xla { // folding back the cpu and gpu infeed implementations into a generic // one if possible. GpuTransferManager::GpuTransferManager() - : GenericTransferManager(se::cuda::kCudaPlatformId) {} + : GenericTransferManager( + se::cuda::kCudaPlatformId, + /*pointer_size=*/llvm::DataLayout(gpu::GpuCompiler::kDataLayout) + .getPointerSize()) {} Status GpuTransferManager::TransferLiteralToInfeed(se::StreamExecutor* executor, const Literal& literal) { diff --git a/tensorflow/compiler/xla/service/gpu_transfer_manager.h b/tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h similarity index 100% rename from tensorflow/compiler/xla/service/gpu_transfer_manager.h rename to tensorflow/compiler/xla/service/gpu/gpu_transfer_manager.h diff --git a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc index 81e905a06665436875b17991a8635e7bb31600de..42c1539e86c2ab162fa473852b80b28b57d0e370 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_schedule.cc @@ -119,11 +119,10 @@ GpuHloOrdering::GpuHloOrdering( // postorder, so we can do better and establish the total order here. We don't // do that yet since it's hard to ensure that the order here is the order used // by IrEmitterNested. And mismatched ordering bugs would be hard to find. - for (auto& computation : module->computations()) { - if (computation.get() != module->entry_computation() && + for (auto* computation : module->computations()) { + if (computation != module->entry_computation() && !computation->IsFusionComputation()) { - predecessors_.emplace(computation.get(), - computation->ComputeReachability()); + predecessors_.emplace(computation, computation->ComputeReachability()); } } } @@ -160,9 +159,9 @@ void BFSLaunchOrder(const HloComputation* computation, std::unordered_map incoming_edge_count; for (const auto& hlo : computation->instructions()) { if (hlo->operand_count() == 0) { - queue.push_back(hlo.get()); + queue.push_back(hlo); } else { - incoming_edge_count[hlo.get()] = + incoming_edge_count[hlo] = std::set(hlo->operands().begin(), hlo->operands().end()) .size(); diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc index 373c1aa5f9582fdb5f03f17f8a90a5e640f7b54d..163a161353fdb90cee2968269d572b8414855551 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" @@ -67,7 +67,7 @@ void HloToIrBindings::EmitBasePointersForHlos( // Lookup allocation GetTupleElement operand. const BufferAllocation::Slice slice = buffer_assignment_ - ->GetUniqueTopLevelSlice(LatestNonGteAncestor(non_io_hlo)) + ->GetUniqueTopLevelSlice(non_io_hlo->LatestNonGteAncestor()) .ConsumeValueOrDie(); // We are not in a nested context, so check non-thread-local allocation. CHECK(!slice.allocation()->is_thread_local()); @@ -102,7 +102,7 @@ void HloToIrBindings::EmitBasePointersForHlos( slice_result.ConsumeValueOrDie(); if (slice.allocation()->is_thread_local()) { llvm::Type* pointee_type = - llvm_ir::ShapeToIrType(non_io_hlo->shape(), ir_builder_); + llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_); BindHloToIrValue(*non_io_hlo, ir_builder_->CreateAlloca(pointee_type), index); } else { @@ -124,18 +124,18 @@ llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte, if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) { return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, - GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_); + GetTypedIrValue(*gte->operand(0), {}, base_ptr), ir_builder_, module_); } return llvm_ir::EmitGetTupleElement( gte->shape(), gte->tuple_index(), /*alignment=*/1, - EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_); + EmitGetTupleElement(gte->operand(0), base_ptr), ir_builder_, module_); } llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo, const ShapeIndex& shape_index, llvm::Value* ir_value) { llvm::Type* pointee_type = llvm_ir::ShapeToIrType( - ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_builder_); + ShapeUtil::GetSubshape(hlo.shape(), shape_index), module_); llvm::Type* dest_type = pointee_type->getPointerTo(); llvm::Value* typed_ir_value; diff --git a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h index d43e09e8a8c5cc2efcd8e1fbf9a7c0697e24d73c..a3120f15bcbfb0f2f0bfbd806e7a4ff05316d5dd 100644 --- a/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h +++ b/tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h @@ -36,10 +36,12 @@ class HloToIrBindings { public: HloToIrBindings(const HloModule& module, const BufferAssignment* buffer_assignment, - llvm::IRBuilder<>* ir_builder, bool is_nested) + llvm::IRBuilder<>* ir_builder, llvm::Module* llvm_module, + bool is_nested) : buffer_assignment_(buffer_assignment), is_nested_(is_nested), ir_builder_(ir_builder), + module_(llvm_module), alias_analysis_(module, *buffer_assignment_, &ir_builder_->getContext()) {} @@ -93,6 +95,7 @@ class HloToIrBindings { const bool is_nested_; llvm::IRBuilder<>* ir_builder_; + llvm::Module* module_; // Stores the underlying llvm::IrArray for each HloInstruction. // For an instruction that generates multiple outputs, the root will be a diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc index 0b94594f1dc5cd040846eabaad01b4cd09520e12..9a4bfd0905bb62c02c70e7f2eea46872c07bca89 100644 --- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc +++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc @@ -152,8 +152,10 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfConvolutionUnfused) { conv_window_col->set_padding_high(1); ConvolutionDimensionNumbers conv_dnums; - conv_dnums.set_batch_dimension(0); - conv_dnums.set_feature_dimension(1); + conv_dnums.set_input_batch_dimension(0); + conv_dnums.set_output_batch_dimension(0); + conv_dnums.set_input_feature_dimension(1); + conv_dnums.set_output_feature_dimension(1); conv_dnums.add_spatial_dimensions(2); conv_dnums.add_spatial_dimensions(3); conv_dnums.set_kernel_output_feature_dimension(0); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc index 6be26dde8f957040c73db6a7e52f050e44d44c06..8fb7a6adda9dc7c36eb9aabcbcdc9d77e6c22c4a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc @@ -214,12 +214,5 @@ llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, value->getType()); } -const HloInstruction* LatestNonGteAncestor(const HloInstruction* hlo) { - while (hlo->opcode() == HloOpcode::kGetTupleElement) { - hlo = hlo->operand(0); - } - return hlo; -} - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h index 422972762ee3da793852429a71b4cee76e41e2bc..06c3205296e4546e39525ec093cc17e2fc375d0d 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.h @@ -53,10 +53,6 @@ llvm::Value* EmitPrintf(tensorflow::StringPiece fmt, llvm::Value* EmitShuffleDown(llvm::Value* value, llvm::Value* offset, llvm::IRBuilder<>* builder); -// Resolves GetTupleElement instruction operands starting with 'hlo'. -// Returns the first ancestor instruction which is not a GetTupleElement. -const HloInstruction* LatestNonGteAncestor(const HloInstruction* hlo); - } // namespace gpu } // namespace xla diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc index a76d217cac271bcda950c1c325f67810dd513383..57a3f713e35b506ad9d5caab1ced2c7b74f8efcf 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc @@ -34,7 +34,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -53,9 +53,10 @@ namespace gpu { IrEmitter::IrEmitter(const HloModuleConfig& hlo_module_config, IrEmitterContext* ir_emitter_context, bool is_nested) : ir_emitter_context_(ir_emitter_context), - ir_builder_(ir_emitter_context->llvm_module()->getContext()), + module_(ir_emitter_context->llvm_module()), + ir_builder_(module_->getContext()), bindings_(ir_emitter_context->hlo_module(), - &ir_emitter_context->buffer_assignment(), &ir_builder_, + &ir_emitter_context->buffer_assignment(), &ir_builder_, module_, is_nested), hlo_module_config_(hlo_module_config) { ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags( @@ -71,18 +72,17 @@ Status IrEmitter::DefaultAction(HloInstruction* hlo) { }; } return EmitTargetElementLoop( - *hlo, GpuElementalIrEmitter(hlo_module_config_, - ir_emitter_context_->llvm_module(), - &ir_builder_, GetNestedComputer()) + *hlo, GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_, + GetNestedComputer()) .MakeElementGenerator(hlo, operand_to_generator)); } -Status IrEmitter::HandleConstant(HloInstruction* constant, - const Literal& literal) { +Status IrEmitter::HandleConstant(HloInstruction* constant) { + const Literal& literal = constant->literal(); llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_); + llvm_ir::ConvertLiteralToIrConstant(literal, module_); llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable( - *ir_emitter_context_->llvm_module(), initializer->getType(), + *module_, initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::PrivateLinkage, initializer, /*Name=*/""); VLOG(2) << "HandleConstant: " << constant->ToString() << std::endl @@ -106,8 +106,8 @@ Status IrEmitter::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } -Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* operand) { +Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element) { + auto operand = get_tuple_element->operand(0); CHECK(bindings_.BoundToIrValue(*operand)); bindings_.BindHloToIrValue( *get_tuple_element, @@ -115,32 +115,29 @@ Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element, get_tuple_element->shape(), get_tuple_element->tuple_index(), // TODO(b/26344050): tighten the alignment here // based on the real element type. - /*alignment=*/1, GetBasePointer(*operand), &ir_builder_)); + /*alignment=*/1, GetBasePointer(*operand), &ir_builder_, module_)); return Status::OK(); } -Status IrEmitter::HandleSort(HloInstruction* sort, - HloInstruction* operand_instruction) { +Status IrEmitter::HandleSort(HloInstruction*) { // TODO(b/26783907): Implement sort on GPU. return Unimplemented("sort"); } -Status IrEmitter::HandleSend(HloInstruction* send) { +Status IrEmitter::HandleSend(HloInstruction*) { return Unimplemented("Send is not implemented on GPU"); } -Status IrEmitter::HandleRecv(HloInstruction* recv) { +Status IrEmitter::HandleRecv(HloInstruction*) { return Unimplemented("Recv is not implemented on GPU"); } -Status IrEmitter::HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) { +Status IrEmitter::HandleTuple(HloInstruction* tuple) { std::vector base_ptrs; - for (const HloInstruction* operand : operands) { + for (const HloInstruction* operand : tuple->operands()) { base_ptrs.push_back(GetBasePointer(*operand)); } - llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_); + llvm_ir::EmitTuple(GetIrArray(*tuple), base_ptrs, &ir_builder_, module_); return Status::OK(); } @@ -321,15 +318,16 @@ Status IrEmitter::EmitAtomicOperationForNestedComputation( return Status::OK(); } -Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred, - HloInstruction* on_true, - HloInstruction* on_false) { +Status IrEmitter::HandleSelect(HloInstruction* select) { + auto pred = select->operand(0); + auto on_true = select->operand(1); + auto on_false = select->operand(2); TF_RET_CHECK(pred->shape().element_type() == PRED); if (ShapeUtil::IsTuple(select->shape())) { llvm_ir::EmitTupleSelect(GetIrArray(*select), GetIrArray(*pred), GetBasePointer(*on_true), - GetBasePointer(*on_false), &ir_builder_); + GetBasePointer(*on_false), &ir_builder_, module_); return Status::OK(); } @@ -339,9 +337,9 @@ Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred, return IrEmitter::DefaultAction(select); } -Status IrEmitter::HandleDot(HloInstruction* dot, - HloInstruction* lhs_instruction, - HloInstruction* rhs_instruction) { +Status IrEmitter::HandleDot(HloInstruction* dot) { + auto lhs_instruction = dot->operand(0); + auto rhs_instruction = dot->operand(1); const llvm_ir::IrArray& target_array = GetIrArray(*dot); const llvm_ir::IrArray& lhs_array = GetIrArray(*lhs_instruction); const llvm_ir::IrArray& rhs_array = GetIrArray(*rhs_instruction); @@ -355,7 +353,26 @@ Status IrEmitter::HandleDot(HloInstruction* dot, lhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); llvm::Value* rhs_value = rhs_array.EmitReadArrayElement(/*index=*/{}, &ir_builder_); - llvm::Value* result = ir_builder_.CreateFMul(lhs_value, rhs_value); + llvm::Value* result; + if (ShapeUtil::ElementIsComplex(lhs_shape)) { + auto real = [&](llvm::Value* x) { + return ir_builder_.CreateExtractValue(x, {0}); + }; + auto imag = [&](llvm::Value* x) { + return ir_builder_.CreateExtractValue(x, {1}); + }; + llvm::Value* real_result = ir_builder_.CreateFSub( + ir_builder_.CreateFMul(real(lhs_value), real(rhs_value)), + ir_builder_.CreateFMul(imag(lhs_value), imag(rhs_value))); + llvm::Value* imag_result = ir_builder_.CreateFAdd( + ir_builder_.CreateFMul(real(lhs_value), imag(rhs_value)), + ir_builder_.CreateFMul(imag(lhs_value), real(rhs_value))); + result = llvm::ConstantAggregateZero::get(lhs_array.GetElementLlvmType()); + result = ir_builder_.CreateInsertValue(result, real_result, {0}); + result = ir_builder_.CreateInsertValue(result, imag_result, {1}); + } else { + result = ir_builder_.CreateFMul(lhs_value, rhs_value); + } target_array.EmitWriteArrayElement(/*index=*/{}, result, &ir_builder_); return Status::OK(); } @@ -411,8 +428,8 @@ Status IrEmitter::HandleDot(HloInstruction* dot, // Initialize the accumulator in the preheader to zero. new llvm::StoreInst( - llvm::ConstantFP::get(accum_type, 0.0), // The value stored. - accum_address, // The address. + llvm::Constant::getNullValue(lhs_array.GetElementLlvmType()), // init 0 + accum_address, // The address. reduction_loop->GetPreheaderBasicBlock() ->getTerminator()); // The instruction this store is inserted before. @@ -427,9 +444,27 @@ Status IrEmitter::HandleDot(HloInstruction* dot, lhs_array.EmitReadArrayElement(lhs_index, &ir_builder_); llvm::Value* rhs_element = rhs_array.EmitReadArrayElement(rhs_index, &ir_builder_); - llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element); llvm::Value* accum = ir_builder_.CreateLoad(accum_address); - llvm::Value* updated_accum = ir_builder_.CreateFAdd(accum, product); + llvm::Value* updated_accum; + if (ShapeUtil::ElementIsComplex(lhs_shape)) { +#define REAL(x) ir_builder_.CreateExtractValue(x, {0}) +#define IMAG(x) ir_builder_.CreateExtractValue(x, {1}) + llvm::Value* product_real = ir_builder_.CreateFSub( + ir_builder_.CreateFMul(REAL(lhs_element), REAL(rhs_element)), + ir_builder_.CreateFMul(IMAG(lhs_element), IMAG(rhs_element))); + llvm::Value* product_imag = ir_builder_.CreateFAdd( + ir_builder_.CreateFMul(REAL(lhs_element), IMAG(rhs_element)), + ir_builder_.CreateFMul(IMAG(lhs_element), REAL(rhs_element))); + updated_accum = ir_builder_.CreateInsertValue( + accum, ir_builder_.CreateFAdd(REAL(accum), product_real), {0}); + updated_accum = ir_builder_.CreateInsertValue( + updated_accum, ir_builder_.CreateFAdd(IMAG(accum), product_imag), {1}); +#undef IMAG +#undef REAL + } else { + llvm::Value* product = ir_builder_.CreateFMul(lhs_element, rhs_element); + updated_accum = ir_builder_.CreateFAdd(accum, product); + } ir_builder_.CreateStore(updated_accum, accum_address); // After the reduction loop exits, store the accumulator into the target @@ -461,10 +496,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot, return Status::OK(); } -Status IrEmitter::HandleConvolution(HloInstruction* convolution, - HloInstruction* lhs_instruction, - HloInstruction* rhs_instruction, - const Window& window) { +Status IrEmitter::HandleConvolution(HloInstruction* convolution) { if (ShapeUtil::HasZeroElements(convolution->shape())) { // Emit no code for an empty output. return Status::OK(); @@ -484,17 +516,18 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) { return Status::OK(); } -Status IrEmitter::HandleReduce(HloInstruction* reduce, HloInstruction* arg, - HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, - HloComputation* function) { +Status IrEmitter::HandleReduce(HloInstruction* reduce) { + auto arg = reduce->operand(0); + auto init_value = reduce->operand(1); + tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + HloComputation* function = reduce->to_apply(); return EmitTargetElementLoop( *reduce, [=](const llvm_ir::IrArray::Index& index) -> StatusOr { // Initialize an accumulator with init_value. llvm::AllocaInst* accumulator_addr = ir_builder_.CreateAlloca(llvm_ir::PrimitiveTypeToIrType( - reduce->shape().element_type(), &ir_builder_)); + reduce->shape().element_type(), module_)); ir_builder_.CreateStore( ir_builder_.CreateLoad(GetBasePointer(*init_value)), accumulator_addr); @@ -547,8 +580,7 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) { for (HloInstruction* operand : fusion->operands()) { parameter_arrays.push_back(GetIrArray(*operand)); } - GpuElementalIrEmitter elemental_emitter(hlo_module_config_, - ir_emitter_context_->llvm_module(), + GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &ir_builder_, GetNestedComputer()); FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); TF_RETURN_IF_ERROR(fusion->fused_expression_root()->Accept(&fused_emitter)); @@ -565,23 +597,19 @@ Status IrEmitter::HandleCall(HloInstruction* call) { GetBasePointer(*call)); } -Status IrEmitter::HandleCustomCall( - HloInstruction* custom_call, - tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target) { +Status IrEmitter::HandleCustomCall(HloInstruction*) { return Unimplemented("custom-call"); } -Status IrEmitter::HandleInfeed(HloInstruction* infeed) { +Status IrEmitter::HandleInfeed(HloInstruction*) { return Unimplemented("Infeed is not supported on GPU (b/30467474)."); } -Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) { +Status IrEmitter::HandleOutfeed(HloInstruction*) { return Unimplemented("Outfeed is not supported on GPU (b/34359662)."); } -Status IrEmitter::HandleRng(HloInstruction* random, - RandomDistribution /*distribution*/) { +Status IrEmitter::HandleRng(HloInstruction* random) { ElementalIrEmitter::HloToElementGeneratorMap operand_to_generator; for (const HloInstruction* operand : random->operands()) { operand_to_generator[operand] = [=](const llvm_ir::IrArray::Index& index) { @@ -591,9 +619,8 @@ Status IrEmitter::HandleRng(HloInstruction* random, // Emits a single-threaded loop because the loop body generated by the element // generator for Rng can't be parallelized (b/32333178). return llvm_ir::LoopEmitter( - GpuElementalIrEmitter(hlo_module_config_, - ir_emitter_context_->llvm_module(), - &ir_builder_, GetNestedComputer()) + GpuElementalIrEmitter(hlo_module_config_, module_, &ir_builder_, + GetNestedComputer()) .MakeElementGenerator(random, operand_to_generator), GetIrArray(*random), &ir_builder_) .EmitLoop(IrName(random)); @@ -634,7 +661,7 @@ StatusOr IrEmitter::ComputeNestedElement( tensorflow::gtl::ArraySlice parameter_elements) { llvm::Value* return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType( - computation.root_instruction()->shape().element_type(), &ir_builder_), + computation.root_instruction()->shape().element_type(), module_), "return_buffer", &ir_builder_); std::vector parameter_buffers; for (llvm::Value* parameter_element : parameter_elements) { diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.h b/tensorflow/compiler/xla/service/gpu/ir_emitter.h index 2f6b3514497bff386d9f3e6f0d6c9737e8da4871..263992d92544166c0d08a6c60b43e78f10f06aed 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter.h +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.h @@ -74,39 +74,25 @@ class IrEmitter : public DfsHloVisitorWithDefault { // The following methods implement the DfsHloVisitorWithDefault interface. Status DefaultAction(HloInstruction* hlo) override; - Status HandleConstant(HloInstruction* constant, - const Literal& literal) override; + Status HandleConstant(HloInstruction* constant) override; Status HandleBitcast(HloInstruction* bitcast) override; - Status HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* operand) override; - Status HandleDot(HloInstruction* dot, HloInstruction* lhs, - HloInstruction* rhs) override; - Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, - HloInstruction* rhs, const Window& window) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; + Status HandleDot(HloInstruction* dot) override; + Status HandleConvolution(HloInstruction* convolution) override; Status HandleCrossReplicaSum(HloInstruction* crs) override; Status HandleInfeed(HloInstruction* infeed) override; Status HandleOutfeed(HloInstruction* outfeed) override; - Status HandleSort(HloInstruction* sort, HloInstruction* operand) override; + Status HandleSort(HloInstruction* sort) override; Status HandleSend(HloInstruction* send) override; Status HandleRecv(HloInstruction* recv) override; Status HandleParameter(HloInstruction* parameter) override; - Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, - HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, - HloComputation* function) override; - Status HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) override; - Status HandleSelect(HloInstruction* select, HloInstruction* pred, - HloInstruction* on_true, - HloInstruction* on_false) override; + Status HandleReduce(HloInstruction* reduce) override; + Status HandleTuple(HloInstruction* tuple) override; + Status HandleSelect(HloInstruction* select) override; Status HandleFusion(HloInstruction* fusion) override; Status HandleCall(HloInstruction* call) override; - Status HandleCustomCall(HloInstruction* custom_call, - tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target) override; - Status HandleRng(HloInstruction* random, - RandomDistribution /*distribution*/) override; + Status HandleCustomCall(HloInstruction* custom_call) override; + Status HandleRng(HloInstruction* random) override; Status FinishVisit(HloInstruction* root) override { return Status::OK(); } @@ -162,6 +148,7 @@ class IrEmitter : public DfsHloVisitorWithDefault { } IrEmitterContext* ir_emitter_context_; + llvm::Module* module_; // The following fields track the IR emission state. According to LLVM memory // management rules, their memory is owned by the module. @@ -218,7 +205,6 @@ class IrEmitterUnnested : public IrEmitter { public: IrEmitterUnnested(const HloModuleConfig& hlo_module_config, const HloComputation* hlo_computation, - bool has_hybrid_result, IrEmitterContext* ir_emitter_context); IrEmitterUnnested(const IrEmitterUnnested&) = delete; IrEmitterUnnested& operator=(const IrEmitterUnnested&) = delete; @@ -233,28 +219,17 @@ class IrEmitterUnnested : public IrEmitter { // IrEmitterUnnested handles the following instructions differently from // IrEmitter. Status HandleCopy(HloInstruction* copy) override; - Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, - HloInstruction* rhs, const Window& window) override; - Status HandleDot(HloInstruction* dot, HloInstruction* lhs_instruction, - HloInstruction* rhs_instruction) override; + Status HandleConvolution(HloInstruction* convolution) override; + Status HandleDot(HloInstruction* dot) override; Status HandleFusion(HloInstruction* fusion) override; - Status HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* operand) override; - Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, - HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, - HloComputation* function) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; + Status HandleReduce(HloInstruction* reduce) override; Status HandleSelectAndScatter(HloInstruction* instruction) override; - Status HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) override; + Status HandleTuple(HloInstruction* tuple) override; Status HandleWhile(HloInstruction* xla_while) override; Status HandleInfeed(HloInstruction* xla_infeed) override; - Status HandleRng(HloInstruction* random, - RandomDistribution distribution) override; - Status HandleSelect(HloInstruction* select, HloInstruction* pred, - HloInstruction* on_true, - HloInstruction* on_false) override; + Status HandleRng(HloInstruction* random) override; + Status HandleSelect(HloInstruction* select) override; Status EmitTargetElementLoop( const HloInstruction& hlo, @@ -340,8 +315,12 @@ class IrEmitterUnnested : public IrEmitter { // to make sure `inst` outlives the lifetime of the returned Thunk object. std::unique_ptr BuildGemmThunk(const HloInstruction* inst); - // Returns a CopyThunk that calls host-to-device cuMemcpy to implement `inst`. - std::unique_ptr BuildCopyThunk(const HloInstruction* inst); + // Returns a thunk that calls host-to-device cuMemcpy to implement `inst`. + std::unique_ptr BuildHostToDeviceCopyThunk(const HloInstruction* inst); + + // Returns a thunk that calls device-to-device cuMemcpy to implement `inst`. + std::unique_ptr BuildDeviceToDeviceCopyThunk( + const HloInstruction* inst); // Returns an InfeedThunk that performs device-to-device memcpy to implement // `inst`. @@ -366,10 +345,6 @@ class IrEmitterUnnested : public IrEmitter { // The HloComputation that this IrEmitter emits code for. const HloComputation* hlo_computation_; - - // Whether this computation will produce a hybrid result, that is the - // computation produces a ShapedBuffer. - bool has_hybrid_result_; }; // Emits LLVM IR for a nested computation to the resultant function. diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc index 7e831e75d73bb30eafba473ae003d02f28fb6ec1..5da1a130d5654b86803396b07a6501c59a182c67 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_nested.cc @@ -52,9 +52,9 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation( io_hlos->push_back(param); const Shape& param_shape = param->shape(); argument_types.push_back( - llvm_ir::ShapeToIrType(param_shape, &ir_builder_)->getPointerTo()); - int64 param_size = llvm_ir::ByteSizeOf( - param_shape, ir_emitter_context_->llvm_module()->getDataLayout()); + llvm_ir::ShapeToIrType(param_shape, module_)->getPointerTo()); + int64 param_size = + llvm_ir::ByteSizeOf(param_shape, module_->getDataLayout()); argument_dereferenceable_bytes.push_back(param_size); } { @@ -62,7 +62,7 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation( io_hlos->push_back(root); const Shape& root_shape = root->shape(); argument_types.push_back( - llvm_ir::ShapeToIrType(root_shape, &ir_builder_)->getPointerTo()); + llvm_ir::ShapeToIrType(root_shape, module_)->getPointerTo()); int64 root_size = llvm_ir::ByteSizeOf( root_shape, ir_emitter_context_->llvm_module()->getDataLayout()); argument_dereferenceable_bytes.push_back(root_size); @@ -98,10 +98,10 @@ llvm::Function* IrEmitterNested::EmitBasePointersForNestedComputation( llvm::ReturnInst::Create(function->getContext(), entry_bb)); std::vector non_io_hlos; - for (const auto& hlo : nested_computation.instructions()) { + for (const auto* hlo : nested_computation.instructions()) { if (hlo->opcode() != HloOpcode::kParameter && - hlo.get() != nested_computation.root_instruction()) { - non_io_hlos.push_back(hlo.get()); + hlo != nested_computation.root_instruction()) { + non_io_hlos.push_back(hlo); } } bindings_.EmitBasePointersForHlos(*io_hlos, non_io_hlos); diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 958408e875fad904caccfd993e625d1c7b365fc5..7b4662fc80c5518135c827489a3724e477b2bad1 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -51,6 +51,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" @@ -132,11 +133,9 @@ void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk, IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config, const HloComputation* hlo_computation, - bool has_hybrid_result, IrEmitterContext* ir_emitter_context) : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false), - hlo_computation_(hlo_computation), - has_hybrid_result_(has_hybrid_result) { + hlo_computation_(hlo_computation) { // Initialize thunk_sequence_ to an empty list of thunks. thunk_sequence_.reset(new ThunkSequence()); } @@ -147,7 +146,7 @@ Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) { } namespace { -bool ImplementedAsMemcpy(const HloInstruction& hlo) { +bool ImplementedAsHostToDeviceMemcpy(const HloInstruction& hlo) { // `hlo` needs to satisfy three conditions to be implemented as a // host-to-device cuMemcpy. // @@ -158,6 +157,20 @@ bool ImplementedAsMemcpy(const HloInstruction& hlo) { hlo.operand(0)->opcode() == HloOpcode::kConstant && ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()); } + +bool ImplementedAsDeviceToDeviceMemcpy( + const BufferAssignment& buffer_assignment, const HloInstruction& hlo) { + // `hlo` needs to satisfy three conditions to be implemented as a + // device-to-device cuMemcpy. + // + // 1. `hlo` is a kCopy instruction. + // 2. `hlo` and its operand have the same shape (thus the same layout too). + // 3. The operand to `hlo` has a buffer assignment (constants do not, for + // instance) which means the source buffer also resides on the device. + return hlo.opcode() == HloOpcode::kCopy && + ShapeUtil::Equal(hlo.operand(0)->shape(), hlo.shape()) && + buffer_assignment.HasTopLevelAllocation(hlo.operand(0)); +} } // namespace llvm::Function* IrEmitterUnnested::BuildKernelPrototype( @@ -232,70 +245,24 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) { return IrEmitter::DefaultAction(hlo); } -Status IrEmitterUnnested::HandleDot(HloInstruction* dot, - HloInstruction* lhs_instruction, - HloInstruction* rhs_instruction) { +Status IrEmitterUnnested::HandleDot(HloInstruction* dot) { if (ImplementedAsGemm(*dot)) { thunk_sequence_->emplace_back(BuildGemmThunk(dot)); return Status::OK(); } thunk_sequence_->emplace_back(BuildKernelThunk(dot)); - return IrEmitter::HandleDot(dot, lhs_instruction, rhs_instruction); + return IrEmitter::HandleDot(dot); } -Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution, - HloInstruction* lhs_instruction, - HloInstruction* rhs_instruction, - const Window& window) { +Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) { if (ImplementedAsDnnConvolution(*convolution)) { thunk_sequence_->emplace_back(BuildConvolutionThunk(convolution)); return Status::OK(); } thunk_sequence_->emplace_back(BuildKernelThunk(convolution)); - return IrEmitter::HandleConvolution(convolution, lhs_instruction, - rhs_instruction, window); -} - -namespace { - -// Returns the first non-GetTupleElement ancestor instruction of 'hlo'. -// If the first non-GTE ancestor is tuple-shaped, populates 'index' with the -// (possibly nested) tuple indices used on the path from ancestor to 'hlo'. -const HloInstruction* LatestNonGteAncestorAndIndex(const HloInstruction* hlo, - ShapeIndex* index) { - if (hlo->opcode() == HloOpcode::kGetTupleElement) { - const auto* operand = LatestNonGteAncestorAndIndex(hlo->operand(0), index); - index->push_back(hlo->tuple_index()); - return operand; - } - return hlo; + return IrEmitter::HandleConvolution(convolution); } -// Checks if we can emit code for DynamicUpdateSlice to update data in-place. -// Returns true if operand 0 of DynamicUpdateSlice and its output buffer -// share the same buffer allocation. -// Returns false otherwise. -bool CanUpdateDynamicSliceInPlace(const BufferAssignment& assignment, - HloInstruction* fusion) { - CHECK_EQ(HloOpcode::kFusion, fusion->opcode()); - HloInstruction* fused_root = fusion->fused_expression_root(); - if (fused_root->opcode() != HloOpcode::kDynamicUpdateSlice) { - return false; - } - // Walk DynamicUpdateSlice operand(0) to fused parameter and get its - // associated operand. See if it shares an allocation with this operand. - ShapeIndex index; - auto* fusion_operand = - LatestNonGteAncestorAndIndex(fused_root->operand(0), &index); - if (fusion_operand->opcode() != HloOpcode::kParameter) { - return false; - } - auto* operand = fusion->operand(fusion_operand->parameter_number()); - return assignment.SharesSliceAtIndex(fusion, {}, operand, index); -} - -} // namespace - Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { HloInstruction* root = fusion->fused_expression_root(); // HandleFusion specializes reduction from a multi-dimensional array to a 1D @@ -366,95 +333,40 @@ Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) { LOG(FATAL) << "Bad opcode for input fusion: " << fusion->fused_expression_root()->opcode(); } - } else if (HloInstruction::FusionKind::kLoop == fusion->fusion_kind() && - root->opcode() == HloOpcode::kDynamicUpdateSlice && - CanUpdateDynamicSliceInPlace( - ir_emitter_context_->buffer_assignment(), fusion)) { - // Loop fusion instruction with DynamicUpdateSlice as fused root. - // DynamicUpdateSlice's operand(0) and 'fusion' output share the same - // BufferAllocation::Slice, so it is safe to emit code to update the slice - // 'in-place'. This avoids copying data outside of the slice update region. + } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace( + fusion, ir_emitter_context_->buffer_assignment())) { + // Fusion node with dynamic-update-slice as the root where the op's input + // (i.e. array to update) shares the same slice as its output. In this case + // we have a special algorithm that modifies the output in place without + // touching the un-updated elements. // Set up kernel thunk and fused ir emitter. thunk_sequence_->emplace_back(BuildKernelThunk(fusion)); - std::vector parameter_arrays; + std::vector operand_arrays; for (HloInstruction* operand : fusion->operands()) { - parameter_arrays.push_back(GetIrArray(*operand)); + operand_arrays.push_back(GetIrArray(*operand)); } GpuElementalIrEmitter elemental_emitter(hlo_module_config_, ir_emitter_context_->llvm_module(), &ir_builder_, GetNestedComputer()); - FusedIrEmitter fused_emitter(parameter_arrays, &elemental_emitter); - TF_RETURN_IF_ERROR(root->Accept(&fused_emitter)); - - // Recursively lookup 'fusion_operand' for DynamicUpdateSlice operand 0. - auto* fusion_operand = LatestNonGteAncestor(root->operand(0)); - CHECK_EQ(HloOpcode::kParameter, fusion_operand->opcode()); - - // Operand(0) the input array which shares an allocation with the output. - const auto* input = root->operand(0); - llvm::Value* input_base_ptr = fused_emitter.GetIrValueForGTE(input); - // Operand(1) 'update' is slice with which to update input at operand(0). - const auto* update = root->operand(1); - Shape update_shape = update->shape(); - TF_RETURN_IF_ERROR( - LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape)); - // Operand(2) the dynamic slice indices at which to write 'update'. - const auto* start_indices = root->operand(2); - - // Create element generators for 'update' and 'start_indices'. - llvm_ir::ElementGenerator element_generator = - fused_emitter.GetGenerator(update); - llvm_ir::ElementGenerator start_generator = - fused_emitter.GetGenerator(start_indices); - - // Create loop body emitter which emits code to do the following: - // *) Read dynamic slice start indices into 'start_index'. - // *) Map requested 'index' and slice 'start_index' to input/output shape - // as 'output_index'. - // *) Reads value from 'update' element generator. - // *) Writes value to input/output array at 'output_index'. - auto loop_body_emitter = - [=](const llvm_ir::IrArray::Index& index) -> Status { - // Emit IR to read dynamic start indices from hlo->operand(2). - const int64 rank = ShapeUtil::Rank(input->shape()); - llvm_ir::IrArray::Index start_index(rank); - for (int64 i = 0; i < rank; ++i) { - llvm_ir::IrArray::Index dim_index({ir_builder_.getInt64(i)}); - TF_ASSIGN_OR_RETURN(start_index[i], start_generator(dim_index)); - } - // Calculate 'output_index' at which to write value from update. - llvm_ir::IrArray::Index output_index(rank); - for (int64 i = 0; i < rank; ++i) { - // Emit IR which computes: - // output_index = (start_index + index) % dim_size - llvm::Value* dim_size = llvm::ConstantInt::get( - index[i]->getType(), input->shape().dimensions(i)); - llvm::Value* start_index0 = ir_builder_.CreateZExtOrBitCast( - start_index[i], index[i]->getType()); - output_index[i] = ir_builder_.CreateURem( - ir_builder_.CreateAdd(start_index0, index[i]), dim_size); - } + // Shape of the dynamic-update-slice's "update" operand. + Shape update_shape = root->operand(1)->shape(); - // Read value from 'update'. - TF_ASSIGN_OR_RETURN(llvm::Value * input_value, element_generator(index)); - // Write value to output array. - llvm_ir::IrArray(input_base_ptr, input->shape()) - .EmitWriteArrayElement(output_index, input_value, &ir_builder_); - return Status::OK(); - }; + // Array to write into. Because this is an in-place operation, this is the + // same as operand 0's array. + llvm_ir::IrArray output_array = GetIrArray(*fusion); - // Create loop which iterates over 'update' shape. LaunchDimensions launch_dimensions = CalculateLaunchDimensions( update_shape, ir_emitter_context_->device_description()); CHECK(Thunk::Kind::kKernel == LastThunk()->kind()); UpdateLaunchDimensions(launch_dimensions, static_cast(LastThunk()), ir_emitter_context_->llvm_module()); - return ParallelLoopEmitter(loop_body_emitter, update_shape, - launch_dimensions, &ir_builder_) - .EmitLoop(IrName(fusion)); + + return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace( + fusion, operand_arrays, output_array, &elemental_emitter, + launch_dimensions, &ir_builder_); } if (ImplementedAsGemm(*fusion)) { thunk_sequence_->emplace_back(BuildGemmThunk(fusion)); @@ -760,8 +672,13 @@ int64 EmitTranspose021Tiled(llvm_ir::IrArray input, llvm_ir::IrArray output, } // namespace Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) { - if (ImplementedAsMemcpy(*copy)) { - thunk_sequence_->emplace_back(BuildCopyThunk(copy)); + if (ImplementedAsHostToDeviceMemcpy(*copy)) { + thunk_sequence_->emplace_back(BuildHostToDeviceCopyThunk(copy)); + return Status::OK(); + } + if (ImplementedAsDeviceToDeviceMemcpy( + ir_emitter_context_->buffer_assignment(), *copy)) { + thunk_sequence_->emplace_back(BuildDeviceToDeviceCopyThunk(copy)); return Status::OK(); } bool is_transpose_021; @@ -834,8 +751,8 @@ Status IrEmitterUnnested::EmitColumnReduction( auto loop_body_emitter = [=](const llvm_ir::IrArray::Index& tile_index) -> Status { // Emit the loop body that reduces one tile. - llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( - input_shape.element_type(), &ir_builder_); + llvm::Type* element_ir_type = + llvm_ir::PrimitiveTypeToIrType(input_shape.element_type(), module_); llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); { @@ -1050,7 +967,7 @@ Status IrEmitterUnnested::EmitRowReduction( [=](const llvm_ir::IrArray::Index& tile_index) -> Status { // Emit the loop body that reduces one tile. llvm::Type* element_ir_type = llvm_ir::PrimitiveTypeToIrType( - input_shape.element_type(), &ir_builder_); + input_shape.element_type(), ir_emitter_context_->llvm_module()); llvm::Value* partial_reduction_result_address = ir_builder_.CreateAlloca( element_ir_type, /*ArraySize=*/nullptr, "partial_reduction_result"); { @@ -1311,10 +1228,11 @@ Status IrEmitterUnnested::EmitReductionToVector( } } -Status IrEmitterUnnested::HandleReduce( - HloInstruction* reduce, HloInstruction* input, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions_to_reduce, - HloComputation* reducer) { +Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) { + auto input = reduce->operand(0); + auto init_value = reduce->operand(1); + tensorflow::gtl::ArraySlice dimensions_to_reduce(reduce->dimensions()); + HloComputation* reducer = reduce->to_apply(); // HandleReduce specializes reduction from a multi-dimensional array to a 1D // array. The specialized version requires an initializer thunk that // initializes the output array to the initial value of the reduce. @@ -1342,13 +1260,11 @@ Status IrEmitterUnnested::HandleReduce( } thunk_sequence_->emplace_back(BuildKernelThunk(reduce)); - return IrEmitter::HandleReduce(reduce, input, init_value, - dimensions_to_reduce, reducer); + return IrEmitter::HandleReduce(reduce); } -Status IrEmitterUnnested::HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) { +Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) { + tensorflow::gtl::ArraySlice operands(tuple->operands()); bool all_tuple_elements_have_buffer = std::all_of( operands.begin(), operands.end(), [this](HloInstruction* tuple_element) { return ir_emitter_context_->buffer_assignment().HasTopLevelAllocation( @@ -1372,19 +1288,11 @@ Status IrEmitterUnnested::HandleTuple( tuple_element_buffers, GetAllocationSlice(*tuple), tuple)); return Status::OK(); } - // If `inst` is a nested thunk that can be disassembled from the result tuple, - // GpuExecutable will disassemble it and return it as part of the resultant - // ShapedBuffer. - if (has_hybrid_result_ && - ReachRootViaOnlyTuples(*tuple, *hlo_computation_->root_instruction())) { - return Status::OK(); - } thunk_sequence_->emplace_back(BuildKernelThunk(tuple)); - return IrEmitter::HandleTuple(tuple, operands); + return IrEmitter::HandleTuple(tuple); } -Status IrEmitterUnnested::HandleGetTupleElement( - HloInstruction* get_tuple_element, HloInstruction* operand) { +Status IrEmitterUnnested::HandleGetTupleElement(HloInstruction*) { // GetTupleElement IR is emitted in the IR context of the user instruction, // and so we do not build a kernel for GetTupleElement instructions. return Status::OK(); @@ -1444,7 +1352,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // boolean flag if the value is initialized. The initialized_flag is set // false. llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_), + llvm_ir::PrimitiveTypeToIrType(operand_element_type, + ir_emitter_context_->llvm_module()), "selected_value_address", &ir_builder_); llvm::Value* selected_index_address = llvm_ir::EmitAllocaAtFunctionEntryWithCount( @@ -1524,7 +1433,8 @@ Status IrEmitterUnnested::HandleSelectAndScatter( llvm::Value* operand_address = operand_array.EmitArrayElementAddress(operand_index, &ir_builder_); llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry( - llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), + llvm_ir::PrimitiveTypeToIrType(PRED, + ir_emitter_context_->llvm_module()), "select_return_buffer", &ir_builder_); TF_RETURN_IF_ERROR(EmitCallToNestedComputation( *select_and_scatter->select(), @@ -1534,8 +1444,10 @@ Status IrEmitterUnnested::HandleSelectAndScatter( // If the 'select' function returns false, update the selected value and the // index to the currently visiting operand. llvm::Value* cond = ir_builder_.CreateICmpNE( - result, llvm::ConstantInt::get( - llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), 0), + result, + llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType( + PRED, ir_emitter_context_->llvm_module()), + 0), "boolean_predicate"); llvm_ir::LlvmIfData if_select_lhs = llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_); @@ -1605,18 +1517,14 @@ Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) { return Status::OK(); } -Status IrEmitterUnnested::HandleRng(HloInstruction* random, - RandomDistribution distribution) { +Status IrEmitterUnnested::HandleRng(HloInstruction* random) { thunk_sequence_->push_back(BuildKernelThunk(random)); - return IrEmitter::HandleRng(random, distribution); + return IrEmitter::HandleRng(random); } -Status IrEmitterUnnested::HandleSelect(HloInstruction* select, - HloInstruction* pred, - HloInstruction* on_true, - HloInstruction* on_false) { +Status IrEmitterUnnested::HandleSelect(HloInstruction* select) { thunk_sequence_->push_back(BuildKernelThunk(select)); - return IrEmitter::HandleSelect(select, pred, on_true, on_false); + return IrEmitter::HandleSelect(select); } Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { @@ -1634,7 +1542,7 @@ llvm::Function* IrEmitterUnnested::EmitBasePointersForHloAndItsOperands( // with their operand buffer in 'io_hlos' and 'non_io_hlos' below. std::vector non_io_hlos; for (const HloInstruction* operand : hlo.operands()) { - const HloInstruction* to_lookup = LatestNonGteAncestor(operand); + const HloInstruction* to_lookup = operand->LatestNonGteAncestor(); if (buffer_assignment.HasTopLevelAllocation(to_lookup) && buffer_assignment.GetUniqueTopLevelSlice(to_lookup) .ConsumeValueOrDie() @@ -1674,7 +1582,7 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( std::vector io_buffers; io_buffers.reserve(io_hlos.size()); for (const HloInstruction* io_hlo : io_hlos) { - io_buffers.push_back(GetAllocationSlice(*LatestNonGteAncestor(io_hlo))); + io_buffers.push_back(GetAllocationSlice(*io_hlo->LatestNonGteAncestor())); } // Create a KernelThunk that launches the kernel that implements "inst". @@ -1682,11 +1590,11 @@ std::unique_ptr IrEmitterUnnested::BuildKernelThunk( llvm_ir::AsString(kernel->getName()), inst); } -std::unique_ptr IrEmitterUnnested::BuildCopyThunk( +std::unique_ptr IrEmitterUnnested::BuildHostToDeviceCopyThunk( const HloInstruction* inst) { const HloInstruction* operand = inst->operand(0); CHECK_EQ(HloOpcode::kConstant, operand->opcode()); - return MakeUnique( + return MakeUnique( /*source_address=*/operand->literal().InternalData(), /*destination_buffer=*/GetAllocationSlice(*inst), /*mem_size=*/ @@ -1695,6 +1603,18 @@ std::unique_ptr IrEmitterUnnested::BuildCopyThunk( inst); } +std::unique_ptr IrEmitterUnnested::BuildDeviceToDeviceCopyThunk( + const HloInstruction* inst) { + const HloInstruction* operand = inst->operand(0); + return MakeUnique( + /*source_address=*/GetAllocationSlice(*operand), + /*destination_buffer=*/GetAllocationSlice(*inst), + /*mem_size=*/ + llvm_ir::ByteSizeOf(operand->shape(), + ir_emitter_context_->llvm_module()->getDataLayout()), + inst); +} + std::unique_ptr IrEmitterUnnested::BuildInfeedThunk( const HloInstruction* inst) { CHECK_EQ(HloOpcode::kInfeed, inst->opcode()); @@ -1888,14 +1808,12 @@ std::unique_ptr IrEmitterUnnested::BuildWhileThunk( // Generate thunk sequence for while 'condition'. HloComputation* condition = hlo->while_condition(); IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition, - /*has_hybrid_result=*/false, ir_emitter_context_); TF_CHECK_OK(condition->root_instruction()->Accept(&ir_emitter_condition)); // Generate thunk sequence for while 'body'. HloComputation* body = hlo->while_body(); IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, - false /* has_hybrid_result */, ir_emitter_context_); TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body)); @@ -1914,7 +1832,6 @@ std::unique_ptr IrEmitterUnnested::BuildForThunk( // Generate thunk sequence for while 'body' (will be used a For loop body). HloComputation* body = hlo->while_body(); IrEmitterUnnested ir_emitter_body(hlo_module_config_, body, - false /* has_hybrid_result */, ir_emitter_context_); TF_CHECK_OK(body->root_instruction()->Accept(&ir_emitter_body)); @@ -1952,7 +1869,8 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk( tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer()); } ir_builder_.SetInsertPoint(ir_builder_.GetInsertBlock()->getTerminator()); - llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_); + llvm_ir::EmitTuple(GetIrArray(hlo), tuple_operand_ptrs, &ir_builder_, + module_); return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc index 66cc7b3e40d7cd7f71d1fb72305e105a86c438ad..0bbd63fb7bfc657cb7bb1de673253c198f5bd25f 100644 --- a/tensorflow/compiler/xla/service/gpu/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/layout_assignment.cc @@ -30,7 +30,7 @@ namespace gpu { Status GpuLayoutAssignment::AddBackendConstraints( LayoutConstraints* constraints) { - for (auto& instruction : constraints->computation()->instructions()) { + for (auto* instruction : constraints->computation()->instructions()) { // cuDNN is called with specific layouts on the input, output, and filter: // // input: DataLayout::kBatchDepthYX @@ -51,19 +51,19 @@ Status GpuLayoutAssignment::AddBackendConstraints( if (instruction->opcode() == HloOpcode::kConvolution) { input = instruction->mutable_operand(0); filter = instruction->mutable_operand(1); - output = instruction.get(); + output = instruction; } else { CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); switch (instruction->fusion_kind()) { case HloInstruction::FusionKind::kConvBackwardFilter: // filter = BackwardFilterConvolve(input, output) input = instruction->mutable_operand(0); - filter = instruction.get(); + filter = instruction; output = instruction->mutable_operand(1); break; case HloInstruction::FusionKind::kConvBackwardInput: // input = BackwardInputConvolve(output, filter) - input = instruction.get(); + input = instruction; filter = instruction->mutable_operand(1); output = instruction->mutable_operand(0); break; @@ -84,8 +84,8 @@ Status GpuLayoutAssignment::AddBackendConstraints( --i) { input_layout.push_back(dimension_numbers.spatial_dimensions(i)); } - input_layout.push_back(dimension_numbers.feature_dimension()); - input_layout.push_back(dimension_numbers.batch_dimension()); + input_layout.push_back(dimension_numbers.input_feature_dimension()); + input_layout.push_back(dimension_numbers.input_batch_dimension()); Shape input_shape(input->shape()); *input_shape.mutable_layout() = LayoutUtil::MakeLayout(input_layout); @@ -106,8 +106,8 @@ Status GpuLayoutAssignment::AddBackendConstraints( --i) { output_layout.push_back(dimension_numbers.spatial_dimensions(i)); } - output_layout.push_back(dimension_numbers.feature_dimension()); - output_layout.push_back(dimension_numbers.batch_dimension()); + output_layout.push_back(dimension_numbers.output_feature_dimension()); + output_layout.push_back(dimension_numbers.output_batch_dimension()); Shape output_shape(output->shape()); *output_shape.mutable_layout() = LayoutUtil::MakeLayout(output_layout); diff --git a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc index 81cca312982a3a5ee98b3914447f2d878354c3a5..817e95a31c546076364674fad63cdb54c3d0e147 100644 --- a/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc +++ b/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/gpu_backend_lib.cc @@ -342,6 +342,13 @@ StatusOr CompileModuleToPtx(llvm::Module* module, std::pair compute_capability, const HloModuleConfig& hlo_module_config, const string& libdevice_dir_path) { + // If the module has no functions or globals, there's nothing to compile. Just + // return an empty string. + if (module->empty() && module->global_empty()) { + VLOG(2) << "Module '" << llvm_ir::AsString(module->getName()) + << "' is empty. Skipping compilation."; + return string(); + } // Link the input module with libdevice, to pull in implementations of some // builtins. TF_RETURN_IF_ERROR( diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index af853385d634b06d31cef94216fb4059dfcadc3d..79493c4112804f8454d200f3f83aa85d718f0d0a 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -39,6 +39,8 @@ message HloInstructionProto { string name = 1; string opcode = 2; xla.Shape shape = 3; + + // TODO(b/67782397): Replace instruction names with HloInstruction ids. repeated string operand_names = 4; repeated string control_predecessor_names = 5; repeated string called_computation_names = 6; @@ -58,6 +60,64 @@ message HloInstructionProto { // Index for kGetTupleElement. int64 tuple_index = 13; + + // Dimensions present for some operations that require reshaping or + // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. + repeated int64 dimensions = 14; + + // Describes the window in a windowed operation such as convolution. + xla.Window window = 15; + + // Describes the dimension numbers used for a convolution. + xla.ConvolutionDimensionNumbers convolution_dimension_numbers = 16; + + // Describes the [begin, end) index range and stride for slices. + message SliceDimensions { + int64 start = 1; + int64 limit = 2; + int64 stride = 3; + } + repeated SliceDimensions slice_dimensions = 17; + + // The bit sizes for a reduce-precision operation. + int32 exponent_bits = 18; + int32 mantissa_bits = 19; + + // Describes the [start, start + size) range size for a dynamic slice + // ('start' is specified dynamically in the second operand of the operation). + repeated int64 dynamic_slice_sizes = 20; + + // The padding configuration that describes the edge padding and interior + // padding of this pad instruction. Only set for pad instructions. + xla.PaddingConfig padding_config = 21; + + // Outfeed configuration information, only present for kOutfeed. + bytes outfeed_config = 22; + + // The distribution requested for random number generation. + // Only present for kRng. + xla.RandomDistribution distribution = 23; + + // A small float number added to the variance to avoid divide-by-zero error. + // Only present for kBatchNormTraining. + float epsilon = 24; + + // An integer value representing the index of the feature dimension. + // Only present for kBatchNormTraining. + int64 feature_index = 25; + + // Represents a unique identifier for each Send/Recv instruction pair. + // Only present for kSend or kRecv. + int64 channel_id = 26; + + // The string representation of the infeed configuration. + bytes infeed_config = 27; + + // Name of a global symbol to call, only present for kCustomCall. + string custom_call_target = 28; + + // Shape of outfeed request. + xla.Shape outfeed_shape = 29; } // Serialization of HloComputation. @@ -67,6 +127,9 @@ message HloComputationProto { // The array of instructions is always in a valid dependency order, where // operands appear before their users. repeated HloInstructionProto instructions = 2; + + // The name of the root of the computation. + string root_name = 3; } // Serialization of HloModule. @@ -187,3 +250,7 @@ message HloProto { HloOrderingProto hlo_ordering = 2; BufferAssignmentProto buffer_assignment = 3; } + +message HloProtos { + repeated HloProto hlo_protos = 1; +} diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc index 3dd8ac6dc5fa46b80328e080e6d1b4e8c402e8b0..6f8099475146e6bbcfb61d2e5a91a7a6f9e63e58 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc @@ -34,6 +34,7 @@ limitations under the License. namespace xla { +using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; @@ -373,10 +374,8 @@ Status HloAliasAnalysis::Verify() const { string HloAliasAnalysis::ToString() const { string out = StrCat("HloAliasAnalysis, module ", module_->name(), "\n"); StrAppend(&out, " Buffers at each position:\n"); - for (const std::unique_ptr& computation : - module_->computations()) { - for (const std::unique_ptr& instruction : - computation->instructions()) { + for (const HloComputation* computation : module_->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { StrAppend(&out, " ", instruction->name(), ":\n"); if (ShapeUtil::IsTuple(instruction->shape())) { ShapeUtil::ForEachSubshape( @@ -384,13 +383,13 @@ string HloAliasAnalysis::ToString() const { [&out, &instruction, this](const Shape&, const ShapeIndex& index) { StrAppend(&out, " tuple index ", index.ToString(), ":\n"); for (const HloBuffer* buffer : - ComputeBuffersAt(instruction.get(), index)) { + ComputeBuffersAt(instruction, index)) { StrAppend(&out, " ", buffer->ToString(), "\n"); } }); } else { for (const HloBuffer* buffer : - ComputeBuffersAt(instruction.get(), /*index=*/{})) { + ComputeBuffersAt(instruction, /*index=*/{})) { StrAppend(&out, " ", buffer->ToString(), "\n"); } } @@ -449,4 +448,56 @@ StatusOr> HloAliasAnalysis::Run( return std::move(alias_analysis); } +bool HloAliasAnalysis::HasLiveRangeInterference( + const HloOrdering& ordering) const { + for (const HloBuffer& buffer : buffers()) { + // Check that the values in the buffer are totally ordered with respect to + // 'ordering'. Begin by sorting the values with respect to 'ordering' with a + // tie-break using value ID. The tie-break is necessary because we need a + // strict weak order for std::sort. + std::vector values = buffer.values(); + std::sort(values.begin(), values.end(), + [&ordering](const HloValue* a, const HloValue* b) { + if (ordering.IsDefinedBefore(*a, *b)) { + return true; + } else if (ordering.IsDefinedBefore(*b, *a)) { + return false; + } else { + return a->id() < b->id(); + } + }); + + // Walk through the ordered vector of values. First verify that the values + // are totally ordered with respect to 'ordering', then check that no + // adjacent values have overlapping live ranges. Only adjacent values must + // be checked because of the property of live range interference. For + // example, if you have values A, B, and C (in program order) contained in + // a buffer and A interferes with C, then necessarily A also interferes + // with B. So to check interference you only need to check interference + // between A and B, and between B and C. + CHECK(!values.empty()); + for (int i = 1; i < values.size(); ++i) { + if (!ordering.IsDefinedBefore(*values[i - 1], *values[i])) { + VLOG(1) << values[i - 1]->ToShortString() << " and " + << values[i]->ToShortString() << " are not ordered"; + return true; + } + if (ordering.MayInterfere(*values[i - 1], *values[i], + dataflow_analysis())) { + VLOG(1) << "In buffer " << buffer.id() << " containing values:\n " + << Join(values, ", ", + [](string* out, const HloValue* value) { + StrAppend(out, value->ToShortString()); + }) + + << "\nValue " << values[i - 1]->ToShortString() + << " may interfere with value " << values[i]->ToShortString(); + return true; + } + } + } + + return false; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h index 39554e466488007bfca666b5453ebaa555f598bf..67dfd4301b3a027a496911ecf6f06841dfd6423a 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h @@ -24,6 +24,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/status.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -90,10 +91,9 @@ class HloAliasAnalysis { // output of the given instruction. bool InstructionBuffersAreDistinct(const HloInstruction* instruction) const; - // Compare the dataflow analysis against a clean recomputation of the - // analysis. Returns an error status if there is a mismatch. Useful for - // verifying the correctness after updates to the analysis. - Status VerifyAgainstReference() const; + // Returns true if any HLO values in the module have interfering live ranges + // assuming the given ordering. + bool HasLiveRangeInterference(const HloOrdering& ordering) const; protected: explicit HloAliasAnalysis(HloModule* module); diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc index a275628779bfd737812b81a8d254d8f2a144c9d2..8f18d50f6e033fab1c01f42017b951c224c22799 100644 --- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc @@ -820,5 +820,83 @@ TEST_F(HloAliasAnalysisTest, Bitcast) { analysis.GetUniqueBufferAt(bitcast)); } +TEST_F(HloAliasAnalysisTest, BitcastInterference) { + // A bitcast value simultaneously live with its operand should not cause + // interference. + auto builder = HloComputation::Builder(TestName()); + auto constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + auto bitcast = builder.AddInstruction(HloInstruction::CreateUnary( + scalar_shape_, HloOpcode::kBitcast, constant)); + builder.AddInstruction(HloInstruction::CreateTuple({constant, bitcast})); + + module_->AddEntryComputation(builder.Build()); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + DependencyHloOrdering ordering(module_.get()); + EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); +} + +TEST_F(HloAliasAnalysisTest, WhileInterference) { + // Build a while loop which has a parallel use of the init value. Depending on + // ordering there may be interference between the update-in-place while and + // the other use of the init. + auto builder = HloComputation::Builder(TestName()); + auto init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1.0))); + + auto cond_builder = HloComputation::Builder("condition"); + auto cond_param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, init->shape(), "param")); + auto cond_root = cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(false))); + HloComputation* condition = + module_->AddEmbeddedComputation(cond_builder.Build()); + + auto body_builder = HloComputation::Builder("body"); + auto body_param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, init->shape(), "param")); + auto body_root = body_builder.AddInstruction( + HloInstruction::CreateUnary(init->shape(), HloOpcode::kExp, body_param)); + HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build()); + + auto xla_while = builder.AddInstruction( + HloInstruction::CreateWhile(init->shape(), condition, body, init)); + + auto negate = builder.AddInstruction( + HloInstruction::CreateUnary(init->shape(), HloOpcode::kNegate, init)); + auto entry_root = + builder.AddInstruction(HloInstruction::CreateTuple({negate, xla_while})); + + HloComputation* entry = module_->AddEntryComputation(builder.Build()); + + const HloAliasAnalysis& analysis = RunAnalysis(); + + { + // Dependency ordering should interfere because the negate and while are + // unordered. + DependencyHloOrdering ordering(module_.get()); + EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); + } + + // For a sequential order, if there is interference iff the negate is after + // the while. + SequentialHloOrdering::HloModuleSequence sequence; + sequence[body] = {body_param, body_root}; + sequence[condition] = {cond_param, cond_root}; + { + sequence[entry] = {init, xla_while, negate, entry_root}; + SequentialHloOrdering ordering(module_.get(), sequence); + EXPECT_TRUE(analysis.HasLiveRangeInterference(ordering)); + } + + { + sequence[entry] = {init, negate, xla_while, entry_root}; + SequentialHloOrdering ordering(module_.get(), sequence); + EXPECT_FALSE(analysis.HasLiveRangeInterference(ordering)); + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 2d077846196bdaf5183f6ee43ab582ede4ef4f52..8f595b45e9832376c4ef881065207f70d2501bee 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -56,7 +56,6 @@ std::unique_ptr HloComputation::Builder::Build( HloInstruction* root = root_instruction ? root_instruction : last_added_instruction_; CHECK_NE(nullptr, root); - return WrapUnique(new HloComputation(name_, parameter_count, &instructions_, root, fusion_instruction_)); } @@ -185,7 +184,7 @@ bool HloComputation::IsRemovable(const HloInstruction* instruction) { } bool HloComputation::HasSideEffect() const { - for (auto& instruction : instructions()) { + for (auto* instruction : instructions()) { if (instruction->HasSideEffect()) { return true; } @@ -198,7 +197,8 @@ Status HloComputation::RemoveInstructionAndUnusedOperands( TF_RET_CHECK(root_instruction() != instruction); TF_RET_CHECK(instruction->user_count() == 0); - TF_RET_CHECK(IsRemovable(instruction)); + TF_RET_CHECK(IsRemovable(instruction)) + << "Cannot remove instruction: " << instruction->ToString(); std::unordered_set removed; std::queue worklist; worklist.push(instruction); @@ -245,15 +245,6 @@ Status HloComputation::RemoveInstruction(HloInstruction* instruction) { return Status::OK(); } -Status HloComputation::ReplaceUsesOfInstruction( - HloInstruction* instruction_to_replace, HloInstruction* instruction) { - TF_RETURN_IF_ERROR(instruction_to_replace->ReplaceAllUsesWith(instruction)); - if (instruction_to_replace == root_instruction()) { - set_root_instruction(instruction); - } - return Status::OK(); -} - void HloComputation::set_root_instruction( HloInstruction* new_root_instruction) { // The shape of the root (ignoring layout) is an invariant of the computation @@ -323,7 +314,7 @@ void ComputeComputationPostOrder( return; } - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : instruction->called_computations()) { ComputeComputationPostOrder(called_computation, visited, post_order); @@ -376,21 +367,27 @@ std::list HloComputation::MakeEmbeddedComputationsList() return post_order; } -string HloComputation::ToString(int nested_level) const { +string HloComputation::ToString(int nested_level, + bool include_large_constants) const { std::ostringstream s; for (int i = 0; i < nested_level; i++) { s << " "; } - s << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) - << " { \n"; + s << "%" << name() << " " << ShapeUtil::HumanString(ComputeProgramShape()) + << " {\n"; for (const HloInstruction* instruction : MakeInstructionPostOrder()) { for (int i = 0; i < nested_level; i++) { s << " "; } - s << " " << instruction->ToString() << "\n"; + s << " " << (instruction == root_instruction_ ? "ROOT " : "") + << instruction->ToString( + /*compact_operands=*/false, + /*include_metadata=*/true, + /*include_large_constants=*/include_large_constants) + << "\n"; if (instruction->opcode() == HloOpcode::kFusion) { s << instruction->fused_instructions_computation()->ToString( - nested_level + 1) + nested_level + 1, include_large_constants) << "\n"; } } @@ -408,9 +405,38 @@ HloComputationProto HloComputation::ToProto() const { HloInstructionProto instruction_proto = instruction->ToProto(); proto.add_instructions()->Swap(&instruction_proto); } + proto.set_root_name(root_instruction()->name()); return proto; } +/* static */ StatusOr> +HloComputation::CreateFromProto( + HloModule* module, const HloComputationProto& proto, + tensorflow::gtl::FlatMap* computation_map, + HloInstruction* fusion_instruction) { + std::vector> instructions; + tensorflow::gtl::FlatMap instruction_map; + int64 parameter_count = 0; + for (const HloInstructionProto& instruction_proto : proto.instructions()) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr instruction, + HloInstruction::CreateFromProto(module, instruction_proto, + instruction_map, computation_map)); + if (instruction->opcode() == HloOpcode::kParameter) { + parameter_count++; + } + TF_RET_CHECK(!ContainsKey(instruction_map, instruction->name())); + instruction_map[instruction->name()] = instruction.get(); + instructions.push_back(std::move(instruction)); + } + + TF_RET_CHECK(!proto.root_name().empty()); + TF_RET_CHECK(ContainsKey(instruction_map, proto.root_name())); + HloInstruction* root = instruction_map.at(proto.root_name()); + return WrapUnique(new HloComputation( + proto.name(), parameter_count, &instructions, root, fusion_instruction)); +} + void HloComputation::FuseInstructionsInto( tensorflow::gtl::ArraySlice instructions_to_fuse, HloInstruction* fusion_instruction) { @@ -569,8 +595,7 @@ Status HloComputation::ReplaceInstruction(HloInstruction* old_instruction, if (new_instruction->metadata().op_name().empty()) { new_instruction->set_metadata(old_instruction->metadata()); } - TF_RETURN_IF_ERROR( - ReplaceUsesOfInstruction(old_instruction, new_instruction)); + TF_RETURN_IF_ERROR(old_instruction->ReplaceAllUsesWith(new_instruction)); return RemoveInstructionAndUnusedOperands(old_instruction); } @@ -618,11 +643,11 @@ void HloComputation::UpdateReachabilityThroughInstruction( std::vector HloComputation::CollectUnreachableRoots() const { std::vector unreachable_roots; - for (auto& instruction : instructions()) { + for (auto* instruction : instructions()) { if (instruction->user_count() == 0 && instruction->control_successors().empty() && - instruction.get() != root_instruction()) { - unreachable_roots.push_back(instruction.get()); + instruction != root_instruction()) { + unreachable_roots.push_back(instruction); } } VLOG(3) << "Unreachable roots:" @@ -634,7 +659,9 @@ std::vector HloComputation::CollectUnreachableRoots() const { return unreachable_roots; } -Status HloComputation::Accept(DfsHloVisitor* visitor) const { +template +Status HloComputation::Accept( + DfsHloVisitorBase* visitor) const { // Visit unreachable roots. Beware that the visitor might delete the currently // visited root, which would invalidate iterators if the unreachable roots // weren't computed ahead of time. @@ -647,6 +674,10 @@ Status HloComputation::Accept(DfsHloVisitor* visitor) const { return root_instruction()->Accept(visitor, /*call_finish_visit=*/true); } +// Explicit instantiations. +template Status HloComputation::Accept(DfsHloVisitor* visitor) const; +template Status HloComputation::Accept(ConstDfsHloVisitor* visitor) const; + Status HloComputation::AcceptWithOperandOrder( DfsHloVisitor* visitor, const HloInstruction::CompareFunction& operand_order) const { @@ -663,8 +694,9 @@ Status HloComputation::AcceptWithOperandOrder( /*call_finish_visit=*/true); } +template Status HloComputation::AcceptOrdered( - DfsHloVisitor* visitor, + DfsHloVisitorBase* visitor, const std::vector& order) const { VLOG(3) << "Accepting visitor with order."; for (HloInstruction* root : CollectUnreachableRoots()) { @@ -693,45 +725,111 @@ Status HloComputation::AcceptOrdered( return Status::OK(); } +// Explicit instantiations. +template Status HloComputation::AcceptOrdered( + DfsHloVisitor*, const std::vector&) const; +template Status HloComputation::AcceptOrdered( + ConstDfsHloVisitor*, const std::vector&) const; + Status HloComputation::Accept( - const FunctionVisitor::VisitorFunction& visitor_func) const { + const std::function& visitor_func) { FunctionVisitor visitor(visitor_func); return this->Accept(&visitor); } -std::unique_ptr HloComputation::Clone(const string& suffix) { +Status HloComputation::Accept( + const std::function& visitor_func) const { + ConstFunctionVisitor visitor(visitor_func); + return this->Accept(&visitor); +} + +std::unique_ptr HloComputation::Clone(const string& suffix, + HloModule* module) { + return CloneWithReplacements( + /*replacements=*/std::unordered_map>(), + module, suffix); +} + +std::unique_ptr HloComputation::CloneWithReplacements( + std::unordered_map> + replacements, + HloModule* module, const string& suffix) { + // Look up instr in the replacements map, and return either the replacement, + // or instr, if the replacement isn't present. + // + // Note: This can return null, indicating that instr should not be present in + // the new computation. + auto replace = [&](HloInstruction* instr) { + auto it = replacements.find(instr); + if (it == replacements.end()) { + return instr; + } + return it->second.get(); + }; + VLOG(1) << "Cloning " << name() << " --> " << suffix << "\n"; - auto postorder = MakeInstructionPostOrder(); + std::vector postorder; + for (HloInstruction* instr : MakeInstructionPostOrder()) { + if (HloInstruction* replacement = replace(instr)) { + postorder.push_back(replacement); + } + } + std::unordered_map clone_map; std::vector> instructions; std::unique_ptr new_instr = nullptr; for (auto instr : postorder) { std::vector new_operands; for (auto operand : instr->operands()) { - HloInstruction* new_operand = FindOrDie(clone_map, operand); - CHECK(new_operand != nullptr); - new_operands.push_back(new_operand); + auto replaced_operand = replace(operand); + // If replaced_operand is null, that means 'replacements' asked us not to + // include operand in the new computation. But we can't do that, because + // operand is used by instr. + CHECK_NE(replaced_operand, nullptr) + << "replacements map tried to eliminate a used instruction " + << operand->ToString() << ", used by " << instr->ToString(); + new_operands.push_back(FindOrDie(clone_map, replaced_operand)); } - - new_instr = instr->CloneWithNewOperands(instr->shape(), new_operands); + new_instr = + instr->CloneWithNewOperands(instr->shape(), new_operands, module); InsertOrDie(&clone_map, instr, new_instr.get()); instructions.push_back(std::move(new_instr)); } - Builder builder(name() + suffix); + Builder builder(name() + "." + suffix); for (auto& instr : instructions) { builder.AddInstruction(std::move(instr)); } auto result = builder.Build( - /*root_instruction=*/FindOrDie(clone_map, root_instruction())); + /*root_instruction=*/FindOrDie(clone_map, replace(root_instruction()))); // Clone control dependencies. for (auto instr : postorder) { HloInstruction* new_instr = FindOrDie(clone_map, instr); for (auto successor : instr->control_successors()) { - TF_CHECK_OK( - new_instr->AddControlDependencyTo(FindOrDie(clone_map, successor))); + auto replaced_successor = replace(successor); + + // successor may not be in clone_map, because it might have been + // removed by the replacements map. + if (replaced_successor == nullptr) { + continue; + } + + TF_CHECK_OK(new_instr->AddControlDependencyTo( + FindOrDie(clone_map, replaced_successor))); } } + + // We cloned the elements of 'replacements', so they're all going to be + // destroyed. HloInstructions need to be detached from their operands before + // they're destroyed, otherwise they stick around in the operands' users lists + // and cause use-after-frees. + for (auto& kv : replacements) { + if (std::unique_ptr& new_instr = kv.second) { + new_instr->DetachFromOperands(); + } + } + return result; } diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h index 576c44a9f344160fd6184bf2bd590044676a27d6..c9782cc981ef067058a5b14d3d1fffdd3eb6b49b 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.h +++ b/tensorflow/compiler/xla/service/hlo_computation.h @@ -24,6 +24,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" @@ -106,12 +107,6 @@ class HloComputation { // must have no users. Instruction is deallocated with this call. Status RemoveInstructionAndUnusedOperands(HloInstruction* instruction); - // Replace all uses of "instruction_to_replace" with "instruction". Also, if - // instruction_to_replace is the root of this computation then the root is set - // to "instruction". Does not remove "instruction_to_replace". - Status ReplaceUsesOfInstruction(HloInstruction* instruction_to_replace, - HloInstruction* instruction); - // Set the root of the computation to the given instruction. The instruction // must have already been added to the computation and have the same shape as // the result of the computation for non fusion computations. @@ -143,13 +138,46 @@ class HloComputation { void UniquifyName(NameUniquer* name_uniquer); // Return a string representation of the computation. - string ToString(int nested_level = 0) const; + string ToString(int nested_level = 0, + bool include_large_constants = false) const; // Returns a serialized representation of this computation. HloComputationProto ToProto() const; - const std::list>& instructions() const { - return instructions_; + // Creates a computation from the given proto. Arguments: + // + // module: the module which will contain the computation. The newly created + // computation is *not* added to the module, however. + // proto: the proto to convert from. + // computation_map: a map from computation name to HloComputation*. This map + // must contain all computations which the newly constructed computation + // calls. + // fusion_instruction: if non-null then the newly created computation will be + // constructed as a fused computation with this instruction as its fusion + // parent. + static StatusOr> CreateFromProto( + HloModule* module, const HloComputationProto& proto, + tensorflow::gtl::FlatMap* computation_map, + HloInstruction* fusion_instruction = nullptr); + + // Gets the instructions in this computation. + // + // The returned type is a range of HloInstruction*s, so you can iterate over + // it using a range-based for loop in the natural way: + // + // for (HloInstruction* instr : computation->instructions()) { ... } + // + tensorflow::gtl::iterator_range>::const_iterator>> + instructions() const { + return {MakeUnwrappingIterator(instructions_.begin()), + MakeUnwrappingIterator(instructions_.end())}; + } + tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> + instructions() { + return {MakeUnwrappingIterator(instructions_.begin()), + MakeUnwrappingIterator(instructions_.end())}; } // Compute and return a post-order of the instructions in the computation. In @@ -243,7 +271,8 @@ class HloComputation { // via the root. The root instruction of the computation is visited last, and // the visitor's FinishVisit method is called once upon completion (with the // root instruction as the argument). - Status Accept(DfsHloVisitor* visitor) const; + template + Status Accept(DfsHloVisitorBase* visitor) const; // Same as Accept() above, but the order of operand and control predecessor // visitation is determined by the given operand order; if compare(A, B) == @@ -254,14 +283,31 @@ class HloComputation { // Visit every node in the computation in the given order. 'order' must // be a topological sort of all instructions in the computation. - Status AcceptOrdered(DfsHloVisitor* visitor, + template + Status AcceptOrdered(DfsHloVisitorBase* visitor, const std::vector& order) const; // Same as Accept() above, but the visitor is given as a function. - Status Accept(const FunctionVisitor::VisitorFunction& visitor_func) const; + Status Accept(const std::function& visitor_func); + Status Accept( + const std::function& visitor_func) const; // Returns a deep copy of this computation including all instructions. - std::unique_ptr Clone(const string& suffix = "clone"); + // If the module pointer is not nullptr, it will be the module where + // the cloned computations will be added to (in order to support deep + // cloning). + std::unique_ptr Clone(const string& suffix = "clone", + HloModule* module = nullptr); + + // Like Clone(), but if an instruction is present in replacement_map, we use + // the map's value to replace that instruction in the cloned computation. + // + // If replacements maps a key to nullptr, we remove that instruction from the + // new computation. + std::unique_ptr CloneWithReplacements( + std::unordered_map> + replacements, + HloModule* module = nullptr, const string& suffix = "clone"); // Returns true if the given instruction can be removed from the // computation. Instructions such as parameters and send/receive instructions @@ -285,8 +331,7 @@ class HloComputation { explicit HloComputation( const string& name, int parameter_count, std::vector>* instructions, - HloInstruction* root_instruction, - HloInstruction* fusion_instruction = nullptr); + HloInstruction* root_instruction, HloInstruction* fusion_instruction); // Internal helper for adding instructions. HloInstruction* AddInstructionInternal( @@ -332,11 +377,6 @@ class HloComputation { std::vector param_instructions_; - // Unique name generator for instruction identifiers. Instruction names should - // be unique per computation and this is enforced when instructions are added - // to the computation. - NameUniquer instruction_name_uniquer_; - TF_DISALLOW_COPY_AND_ASSIGN(HloComputation); }; diff --git a/tensorflow/compiler/xla/service/hlo_computation_test.cc b/tensorflow/compiler/xla/service/hlo_computation_test.cc index ccab7bf34862f3303db1331a87b5c70fdc3283ba..7b7588f4ba9aa622677db6f9d5022cc8cc029e04 100644 --- a/tensorflow/compiler/xla/service/hlo_computation_test.cc +++ b/tensorflow/compiler/xla/service/hlo_computation_test.cc @@ -310,7 +310,7 @@ TEST_F(HloComputationTest, DeepCopyArrayAtIndices) { } TEST_F(HloComputationTest, DeepCopyTupleAtIndices) { - // Test that DeepCopyInstruction properly copies elements of a a tuple as + // Test that DeepCopyInstruction properly copies elements of a tuple as // specified by the given indices. auto builder = HloComputation::Builder(TestName()); auto constant1 = builder.AddInstruction(HloInstruction::CreateConstant( diff --git a/tensorflow/compiler/xla/service/hlo_constant_folding.cc b/tensorflow/compiler/xla/service/hlo_constant_folding.cc index 58761cb4a487ef12b0cbeefd6820b415d724733c..53450991b6fad5b9651d9d23b55c908e6b68e5dd 100644 --- a/tensorflow/compiler/xla/service/hlo_constant_folding.cc +++ b/tensorflow/compiler/xla/service/hlo_constant_folding.cc @@ -41,10 +41,7 @@ StatusOr HloConstantFolding::Run(HloModule* module) { "HloConstantFolding::Run(), before:\n" + module->ToString()); bool changed = false; - for (auto& computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } + for (auto* computation : module->MakeNonfusionComputations()) { for (auto instruction : computation->MakeInstructionPostOrder()) { // Skip dead code. if (instruction->user_count() == 0 && @@ -52,8 +49,8 @@ StatusOr HloConstantFolding::Run(HloModule* module) { continue; } // Skip Constant, Parameter, Reduce operation. - // TODO(b/35975797): Enable Reduce operation once arbitary computation are - // supported by the evaluator. + // TODO(b/35975797): Enable Reduce operation once arbitrary computation + // are supported by the evaluator. // TODO(b/64407269): Enable Tuple once the timeout issue is resolved. if (instruction->opcode() == HloOpcode::kParameter || instruction->opcode() == HloOpcode::kConstant || @@ -66,8 +63,8 @@ StatusOr HloConstantFolding::Run(HloModule* module) { continue; } - // Broadcasts dramatically increase the size of constants with is often - // detrimental to performance and memory capacity so do not fold + // Broadcasts dramatically increase the size of constants, which is often + // detrimental to performance and memory capacity, so do not fold // broadcasts. if (instruction->opcode() == HloOpcode::kBroadcast) { continue; diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 65725ca692fb3429106f5ed50f4a2c11bd46f54c..17ba2b673ac2db2060f720139bdc52ef1e72c98a 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ -37,7 +37,7 @@ HloCostAnalysis::HloCostAnalysis(const ShapeSizeFunction& shape_size, const Properties& per_second_rates) : shape_size_(shape_size), per_second_rates_(per_second_rates) {} -Status HloCostAnalysis::Preprocess(HloInstruction* hlo) { +Status HloCostAnalysis::Preprocess(const HloInstruction* hlo) { // Set current instruction cost values to reasonable default values. Each // handler can overwrite these values. In Postprocess, these values are // accumulated and written to the per-instruction maps. @@ -56,7 +56,7 @@ Status HloCostAnalysis::Preprocess(HloInstruction* hlo) { return Status::OK(); } -Status HloCostAnalysis::Postprocess(HloInstruction* hlo) { +Status HloCostAnalysis::Postprocess(const HloInstruction* hlo) { if (current_should_compute_bottleneck_time_) { // Compute the time as the time of the bottleneck, i.e. the slowest property // given the per-second rate of each property. @@ -80,7 +80,8 @@ Status HloCostAnalysis::Postprocess(HloInstruction* hlo) { return Status::OK(); } -Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) { +Status HloCostAnalysis::HandleElementwiseOp( + const HloInstruction* hlo_instruction) { const auto& shape = hlo_instruction->shape(); // For element-wise operations, the number of computations is the same as the // number of elements in the output shape. @@ -118,82 +119,64 @@ Status HloCostAnalysis::HandleElementwiseOp(HloInstruction* hlo_instruction) { } } -Status HloCostAnalysis::HandleElementwiseUnary(HloInstruction* hlo) { +Status HloCostAnalysis::HandleElementwiseUnary(const HloInstruction* hlo) { return HandleElementwiseOp(hlo); } -Status HloCostAnalysis::HandleElementwiseBinary(HloInstruction* hlo) { +Status HloCostAnalysis::HandleElementwiseBinary(const HloInstruction* hlo) { return HandleElementwiseOp(hlo); } -Status HloCostAnalysis::HandleCompare(HloInstruction* compare, HloOpcode opcode, - HloInstruction* lhs, - HloInstruction* rhs) { +Status HloCostAnalysis::HandleCompare(const HloInstruction* compare) { return HandleElementwiseOp(compare); } -Status HloCostAnalysis::HandleClamp(HloInstruction* clamp, - HloInstruction* min_instruction, - HloInstruction* arg_instruction, - HloInstruction* max_instruction) { +Status HloCostAnalysis::HandleClamp(const HloInstruction* clamp) { return HandleElementwiseOp(clamp); } -Status HloCostAnalysis::HandleReducePrecision(HloInstruction* hlo) { +Status HloCostAnalysis::HandleReducePrecision(const HloInstruction* hlo) { return HandleElementwiseOp(hlo); } -Status HloCostAnalysis::HandleParameter(HloInstruction* parameter) { +Status HloCostAnalysis::HandleParameter(const HloInstruction*) { current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } -Status HloCostAnalysis::HandleConstant(HloInstruction* constant, - const Literal& literal) { +Status HloCostAnalysis::HandleConstant(const HloInstruction*) { current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } -Status HloCostAnalysis::HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* operand) { +Status HloCostAnalysis::HandleGetTupleElement(const HloInstruction*) { // GetTupleElement forwards a pointer and does not touch each element in the // output. current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } -Status HloCostAnalysis::HandleSelect(HloInstruction* select, - HloInstruction* pred, - HloInstruction* on_true, - HloInstruction* on_false) { +Status HloCostAnalysis::HandleSelect(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleReverse(HloInstruction* reverse, - HloInstruction* operand_instruction) { +Status HloCostAnalysis::HandleReverse(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleSlice(HloInstruction* slice, - HloInstruction* operand_instruction) { +Status HloCostAnalysis::HandleSlice(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleDynamicSlice(HloInstruction* dynamic_slice, - HloInstruction* operand, - HloInstruction* start_indices) { +Status HloCostAnalysis::HandleDynamicSlice(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleDynamicUpdateSlice( - HloInstruction* dynamic_update, HloInstruction* operand, - HloInstruction* update, HloInstruction* start_indices) { +Status HloCostAnalysis::HandleDynamicUpdateSlice(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) { +Status HloCostAnalysis::HandleTuple(const HloInstruction* tuple) { // The tuple instruction only gathers pointers from inputs (it doesn't iterate // through them). The memory touched is then only the size of the output // index table of the tuple. @@ -202,25 +185,21 @@ Status HloCostAnalysis::HandleTuple( return Status::OK(); } -Status HloCostAnalysis::HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands) { +Status HloCostAnalysis::HandleConcatenate(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleConvert(HloInstruction* convert) { +Status HloCostAnalysis::HandleConvert(const HloInstruction* convert) { return HandleElementwiseOp(convert); } -Status HloCostAnalysis::HandleCopy(HloInstruction* copy) { +Status HloCostAnalysis::HandleCopy(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleDot(HloInstruction* dot, - HloInstruction* lhs_instruction, - HloInstruction* rhs_instruction) { - const Shape& lhs_shape = lhs_instruction->shape(); - const Shape& rhs_shape = rhs_instruction->shape(); +Status HloCostAnalysis::HandleDot(const HloInstruction* dot) { + const Shape& lhs_shape = dot->operand(0)->shape(); + const Shape& rhs_shape = dot->operand(1)->shape(); // Count of elements along the reduction dimension (last dimension for the // rhs). int64 reduction_width = lhs_shape.dimensions(ShapeUtil::Rank(lhs_shape) - 1); @@ -240,21 +219,18 @@ Status HloCostAnalysis::HandleDot(HloInstruction* dot, return Status::OK(); } -Status HloCostAnalysis::HandleInfeed(HloInstruction* infeed) { +Status HloCostAnalysis::HandleInfeed(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleOutfeed(HloInstruction* outfeed) { +Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleMap( - HloInstruction* map, tensorflow::gtl::ArraySlice operands, - HloComputation* function, - tensorflow::gtl::ArraySlice /*static_operands*/) { +Status HloCostAnalysis::HandleMap(const HloInstruction* map) { // Compute properties of the mapped function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, - ProcessSubcomputation(function)); + ProcessSubcomputation(map->to_apply())); // Compute the cost of all elements for this Map operation. const int64 element_count = ShapeUtil::ElementsIn(map->shape()); @@ -266,9 +242,9 @@ Status HloCostAnalysis::HandleMap( return Status::OK(); } -Status HloCostAnalysis::HandleReduce( - HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, HloComputation* function) { +Status HloCostAnalysis::HandleReduce(const HloInstruction* reduce) { + auto arg = reduce->operand(0); + HloComputation* function = reduce->to_apply(); // Compute the cost of the user function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, ProcessSubcomputation(function)); @@ -284,10 +260,10 @@ Status HloCostAnalysis::HandleReduce( return Status::OK(); } -Status HloCostAnalysis::HandleReduceWindow(HloInstruction* reduce_window, - HloInstruction* operand, - const Window& window, - HloComputation* function) { +Status HloCostAnalysis::HandleReduceWindow( + const HloInstruction* reduce_window) { + const Window& window = reduce_window->window(); + auto function = reduce_window->to_apply(); // Compute the properties of the reduction function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, ProcessSubcomputation(function)); @@ -310,7 +286,8 @@ Status HloCostAnalysis::HandleReduceWindow(HloInstruction* reduce_window, return Status::OK(); } -Status HloCostAnalysis::HandleSelectAndScatter(HloInstruction* instruction) { +Status HloCostAnalysis::HandleSelectAndScatter( + const HloInstruction* instruction) { // Compute the properties of the select and scatter function. // Compute the properties of the reduction function. TF_ASSIGN_OR_RETURN(const Properties select_properties, @@ -342,70 +319,70 @@ Status HloCostAnalysis::HandleSelectAndScatter(HloInstruction* instruction) { return Status::OK(); } -Status HloCostAnalysis::HandleBitcast(HloInstruction* bitcast) { +Status HloCostAnalysis::HandleBitcast(const HloInstruction*) { // A bitcast does no computation and touches no memory. current_properties_[kBytesAccessedKey] = 0; return Status::OK(); } -Status HloCostAnalysis::HandleBroadcast(HloInstruction* broadcast) { +Status HloCostAnalysis::HandleBroadcast(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandlePad(HloInstruction* pad) { return Status::OK(); } +Status HloCostAnalysis::HandlePad(const HloInstruction*) { + return Status::OK(); +} -Status HloCostAnalysis::HandleSend(HloInstruction* send) { +Status HloCostAnalysis::HandleSend(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleRecv(HloInstruction* recv) { +Status HloCostAnalysis::HandleRecv(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleReshape(HloInstruction* reshape) { +Status HloCostAnalysis::HandleReshape(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleBatchNormTraining( - HloInstruction* batch_norm_training) { +Status HloCostAnalysis::HandleBatchNormTraining(const HloInstruction*) { // TODO(b/62294698): Implement cost analysis for batch-norm-training. return Status::OK(); } -Status HloCostAnalysis::HandleBatchNormInference( - HloInstruction* batch_norm_inference) { +Status HloCostAnalysis::HandleBatchNormInference(const HloInstruction*) { // TODO(b/62294698): Implement cost analysis for batch-norm-inference. return Status::OK(); } -Status HloCostAnalysis::HandleBatchNormGrad(HloInstruction* batch_norm_grad) { +Status HloCostAnalysis::HandleBatchNormGrad(const HloInstruction*) { // TODO(b/62294698): Implement cost analysis for batch-norm-grad. return Status::OK(); } -Status HloCostAnalysis::HandleTranspose(HloInstruction* transpose) { +Status HloCostAnalysis::HandleTranspose(const HloInstruction*) { return Status::OK(); } -Status HloCostAnalysis::HandleConvolution(HloInstruction* convolution, - HloInstruction* lhs_instruction, - HloInstruction* rhs_instruction, - const Window& window) { +Status HloCostAnalysis::HandleConvolution(const HloInstruction* convolution) { + auto rhs_instruction = convolution->operand(1); const auto& dnums = convolution->convolution_dimension_numbers(); const int64 output_features = - convolution->shape().dimensions(dnums.feature_dimension()); + convolution->shape().dimensions(dnums.output_feature_dimension()); // For each output element, we do one fma per element in the kernel at some // given output feature index. const int64 fmas_per_output_element = - ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features; + output_features > 0 + ? ShapeUtil::ElementsIn(rhs_instruction->shape()) / output_features + : 0; const int64 output_elements = ShapeUtil::ElementsIn(convolution->shape()); current_properties_[kFlopsKey] = output_elements * fmas_per_output_element * kFmaFlops; return Status::OK(); } -Status HloCostAnalysis::HandleCrossReplicaSum(HloInstruction* crs) { +Status HloCostAnalysis::HandleCrossReplicaSum(const HloInstruction* crs) { // We assume 2 replicas, so that each output element is the sum of two input // elements. // @@ -415,8 +392,7 @@ Status HloCostAnalysis::HandleCrossReplicaSum(HloInstruction* crs) { return Status::OK(); } -Status HloCostAnalysis::HandleRng(HloInstruction* random, - RandomDistribution distribution) { +Status HloCostAnalysis::HandleRng(const HloInstruction* random) { // TODO(b/26346211): Implement better estimates for the RNG cost, since the // cost changes with the implementation and the distribution. For now, assume // the cost of each RNG is same as a transcendental operation. @@ -425,7 +401,7 @@ Status HloCostAnalysis::HandleRng(HloInstruction* random, return Status::OK(); } -Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) { +Status HloCostAnalysis::HandleFusion(const HloInstruction* fusion) { // Compute the properties of the fused expression and attribute them to the // fusion node. Use a dummy shape_size to avoid any errors from trying to // calculate the size of a shape that does not have a layout, since nodes @@ -453,30 +429,26 @@ Status HloCostAnalysis::HandleFusion(HloInstruction* fusion) { return Status::OK(); } -Status HloCostAnalysis::HandleCall(HloInstruction* call) { +Status HloCostAnalysis::HandleCall(const HloInstruction* call) { TF_ASSIGN_OR_RETURN(current_properties_, ProcessSubcomputation(call->to_apply())); current_should_compute_bottleneck_time_ = false; return Status::OK(); } -Status HloCostAnalysis::HandleCustomCall( - HloInstruction* custom_call, - tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target) { +Status HloCostAnalysis::HandleCustomCall(const HloInstruction*) { return Unimplemented("Custom-call is not implemented for HLO cost analysis."); } -Status HloCostAnalysis::HandleSort(HloInstruction* sort, - HloInstruction* operand_instruction) { +Status HloCostAnalysis::HandleSort(const HloInstruction* sort) { // This assumes a comparison based N*log(N) algorithm. As for all ops, the // actual properties of the op depend on the backend implementation. - int64 elements = ShapeUtil::ElementsIn(operand_instruction->shape()); + int64 elements = ShapeUtil::ElementsIn(sort->operand(0)->shape()); current_properties_[kFlopsKey] = elements * tensorflow::Log2Ceiling(elements); return Status::OK(); } -Status HloCostAnalysis::HandleWhile(HloInstruction* xla_while) { +Status HloCostAnalysis::HandleWhile(const HloInstruction* xla_while) { // Since the number of iterations of the while node will not always be // something that we can statically analyze, we cannot precisely compute the // cost of a while node. For now compute the cost of a single iteration. @@ -500,7 +472,7 @@ Status HloCostAnalysis::HandleWhile(HloInstruction* xla_while) { return Status::OK(); } -Status HloCostAnalysis::FinishVisit(HloInstruction* root) { +Status HloCostAnalysis::FinishVisit(const HloInstruction*) { return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index d71c2eccee349835c2f998e1774a4d292181c2e2..8074868e375541e424dbe17de8a3038880e41927 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -34,7 +34,7 @@ namespace xla { // the computation cost of the instruction, and the values are accumulated // during the traversal for the entire graph. We treat normal floating point // operations separately from transcendental operations. -class HloCostAnalysis : public DfsHloVisitor { +class HloCostAnalysis : public ConstDfsHloVisitor { public: // Each HLO is associated to a vector of properties with the indices given // below. Sub-classes can add further properties. @@ -49,83 +49,56 @@ class HloCostAnalysis : public DfsHloVisitor { using ShapeSizeFunction = std::function; explicit HloCostAnalysis(const ShapeSizeFunction& shape_size); - Status HandleElementwiseUnary(HloInstruction* hlo) override; - Status HandleElementwiseBinary(HloInstruction* hlo) override; - Status HandleConstant(HloInstruction* constant, - const Literal& literal) override; - Status HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* operand) override; - Status HandleSelect(HloInstruction* select, HloInstruction* pred, - HloInstruction* on_true, - HloInstruction* on_false) override; - Status HandleCompare(HloInstruction* compare, HloOpcode opcode, - HloInstruction* lhs, HloInstruction* rhs) override; - Status HandleClamp(HloInstruction* clamp, HloInstruction* min, - HloInstruction* arg, HloInstruction* max) override; - Status HandleReducePrecision(HloInstruction* hlo) override; - Status HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands) override; - Status HandleSend(HloInstruction* send) override; - Status HandleRecv(HloInstruction* recv) override; - Status HandleConvert(HloInstruction* convert) override; - Status HandleCopy(HloInstruction* copy) override; - Status HandleDot(HloInstruction* dot, HloInstruction* lhs, - HloInstruction* rhs) override; - Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, - HloInstruction* rhs, const Window& window) override; - Status HandleCrossReplicaSum(HloInstruction* crs) override; - Status HandleInfeed(HloInstruction* infeed) override; - Status HandleOutfeed(HloInstruction* outfeed) override; - Status HandleRng(HloInstruction* random, - RandomDistribution distribution) override; - Status HandleReverse(HloInstruction* reverse, - HloInstruction* operand) override; - Status HandleSort(HloInstruction* sort, HloInstruction* operand) override; - Status HandleParameter(HloInstruction* parameter) override; - Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, - HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, - HloComputation* function_handle) override; - Status HandleBatchNormTraining(HloInstruction* batch_norm_training) override; + Status HandleElementwiseUnary(const HloInstruction* hlo) override; + Status HandleElementwiseBinary(const HloInstruction* hlo) override; + Status HandleConstant(const HloInstruction* constant) override; + Status HandleGetTupleElement( + const HloInstruction* get_tuple_element) override; + Status HandleSelect(const HloInstruction* select) override; + Status HandleCompare(const HloInstruction* compare) override; + Status HandleClamp(const HloInstruction* clamp) override; + Status HandleReducePrecision(const HloInstruction* hlo) override; + Status HandleConcatenate(const HloInstruction* concatenate) override; + Status HandleSend(const HloInstruction* send) override; + Status HandleRecv(const HloInstruction* recv) override; + Status HandleConvert(const HloInstruction* convert) override; + Status HandleCopy(const HloInstruction* copy) override; + Status HandleDot(const HloInstruction* dot) override; + Status HandleConvolution(const HloInstruction* convolution) override; + Status HandleCrossReplicaSum(const HloInstruction* crs) override; + Status HandleInfeed(const HloInstruction* infeed) override; + Status HandleOutfeed(const HloInstruction* outfeed) override; + Status HandleRng(const HloInstruction* random) override; + Status HandleReverse(const HloInstruction* reverse) override; + Status HandleSort(const HloInstruction* sort) override; + Status HandleParameter(const HloInstruction* parameter) override; + Status HandleReduce(const HloInstruction* reduce) override; + Status HandleBatchNormTraining( + const HloInstruction* batch_norm_training) override; Status HandleBatchNormInference( - HloInstruction* batch_norm_inference) override; - Status HandleBatchNormGrad(HloInstruction* batch_norm_grad) override; - Status HandleFusion(HloInstruction* fusion) override; - Status HandleCall(HloInstruction* call) override; - Status HandleCustomCall(HloInstruction* custom_call, - tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target) override; - Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override; - Status HandleDynamicSlice(HloInstruction* dynamic_slice, - HloInstruction* operand, - HloInstruction* start_indices) override; - Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, - HloInstruction* operand, - HloInstruction* update, - HloInstruction* start_indices) override; - Status HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) override; - Status HandleMap( - HloInstruction* map, - tensorflow::gtl::ArraySlice operands, - HloComputation* function, - tensorflow::gtl::ArraySlice static_operands) override; - Status HandleReduceWindow(HloInstruction* reduce_window, - HloInstruction* operand, const Window& window, - HloComputation* function) override; - Status HandleSelectAndScatter(HloInstruction* instruction) override; - Status HandleBitcast(HloInstruction* bitcast) override; - Status HandleBroadcast(HloInstruction* broadcast) override; - Status HandlePad(HloInstruction* pad) override; - Status HandleReshape(HloInstruction* reshape) override; - Status HandleTranspose(HloInstruction* transpose) override; - Status HandleWhile(HloInstruction* xla_while) override; - Status FinishVisit(HloInstruction* root) override; - - Status Preprocess(HloInstruction* hlo) override; - Status Postprocess(HloInstruction* hlo) override; + const HloInstruction* batch_norm_inference) override; + Status HandleBatchNormGrad(const HloInstruction* batch_norm_grad) override; + Status HandleFusion(const HloInstruction* fusion) override; + Status HandleCall(const HloInstruction* call) override; + Status HandleCustomCall(const HloInstruction* custom_call) override; + Status HandleSlice(const HloInstruction* slice) override; + Status HandleDynamicSlice(const HloInstruction* dynamic_slice) override; + Status HandleDynamicUpdateSlice( + const HloInstruction* dynamic_update_slice) override; + Status HandleTuple(const HloInstruction* tuple) override; + Status HandleMap(const HloInstruction* map) override; + Status HandleReduceWindow(const HloInstruction* reduce_window) override; + Status HandleSelectAndScatter(const HloInstruction* instruction) override; + Status HandleBitcast(const HloInstruction* bitcast) override; + Status HandleBroadcast(const HloInstruction* broadcast) override; + Status HandlePad(const HloInstruction* pad) override; + Status HandleReshape(const HloInstruction* reshape) override; + Status HandleTranspose(const HloInstruction* transpose) override; + Status HandleWhile(const HloInstruction* xla_while) override; + Status FinishVisit(const HloInstruction* root) override; + + Status Preprocess(const HloInstruction* hlo) override; + Status Postprocess(const HloInstruction* hlo) override; // Set the rates used to calculate the time taken by the computation. These // need to be set before visiting starts. @@ -174,7 +147,7 @@ class HloCostAnalysis : public DfsHloVisitor { const ShapeSizeFunction* shape_size = nullptr); // Utility function to handle all element-wise operations. - Status HandleElementwiseOp(HloInstruction* hlo_instruction); + Status HandleElementwiseOp(const HloInstruction* hlo_instruction); // Returns the default value if the key is not present in the // properties. Otherwise, returns the value that the key maps to from the diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc index 0a288a77ada840451915561b4b0865785b39ade7..0eaa21ef254e3461baaaca57503ab24ce35ac929 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis_test.cc @@ -169,7 +169,7 @@ TEST_F(HloCostAnalysisTest, MatrixMultiply) { TEST_F(HloCostAnalysisTest, Map) { ComputationBuilder builder(client_, "map"); auto input = builder.Parameter(0, ShapeUtil::MakeShape(F32, {10}), "in"); - auto result = builder.Map({input}, add_and_exp_); + auto result = builder.Map({input}, add_and_exp_, {0}); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); @@ -286,7 +286,7 @@ TEST_F(HloCostAnalysisTest, FullyConnectedForward) { auto bias = builder.Parameter(2, ShapeUtil::MakeShape(F32, {20}), "bias"); // sigmoid(input * weight + bias) auto result = builder.Map( - {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_); + {builder.Add(builder.Dot(input, weight), bias, {1})}, sigmoid_, {0, 1}); // Run HLO cost analysis. auto hlo_module = BuildHloGraph(&builder); diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc index cdccacdd2d50c7e7d2a056c31f98aa72f10d4239..d35ba19a730555433099072c51ca5cf3774d4b99 100644 --- a/tensorflow/compiler/xla/service/hlo_cse.cc +++ b/tensorflow/compiler/xla/service/hlo_cse.cc @@ -51,7 +51,7 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { auto inst_it = computation->instructions().begin(); while (inst_it != computation->instructions().end()) { - HloInstruction* instruction = inst_it->get(); + HloInstruction* instruction = *inst_it; // Advance list iterator before loop body because iterator may be // invalidated due to deletion. @@ -77,7 +77,7 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { constants.emplace(shape_string, instruction); } else { // Match found, replace this instruction with the one in the multimap. - TF_CHECK_OK(computation->ReplaceUsesOfInstruction(instruction, match)); + TF_CHECK_OK(instruction->ReplaceAllUsesWith(match)); TF_CHECK_OK(computation->RemoveInstruction(instruction)); changed = true; } @@ -91,8 +91,8 @@ bool CombineConstants(HloComputation* computation, bool is_layout_sensitive) { StatusOr HloCSE::Run(HloModule* module) { bool changed = false; - for (auto& computation : module->computations()) { - changed |= CombineConstants(computation.get(), is_layout_sensitive_); + for (auto* computation : module->computations()) { + changed |= CombineConstants(computation, is_layout_sensitive_); std::list post_order = computation->MakeInstructionPostOrder(); @@ -121,8 +121,8 @@ StatusOr HloCSE::Run(HloModule* module) { // Replace all equivalent instructions with this instruction. for (HloInstruction* equivalent_instruction : equivalent_instructions) { - TF_RETURN_IF_ERROR(computation->ReplaceUsesOfInstruction( - equivalent_instruction, instruction)); + TF_RETURN_IF_ERROR( + equivalent_instruction->ReplaceAllUsesWith(instruction)); TF_RETURN_IF_ERROR( computation->RemoveInstruction(equivalent_instruction)); removed_instructions.insert(equivalent_instruction); diff --git a/tensorflow/compiler/xla/service/hlo_cse_test.cc b/tensorflow/compiler/xla/service/hlo_cse_test.cc index 417b7e82c32e0350994c9883ed26d8f972794396..7c4626e78a3e84c9723a9f8e39d56614c4fa25ce 100644 --- a/tensorflow/compiler/xla/service/hlo_cse_test.cc +++ b/tensorflow/compiler/xla/service/hlo_cse_test.cc @@ -67,7 +67,7 @@ TEST_F(HloCseTest, CombineTwoConstants) { EXPECT_TRUE(cse.Run(module.get()).ValueOrDie()); EXPECT_EQ(2, computation->instruction_count()); - HloInstruction* constant = computation->instructions().begin()->get(); + HloInstruction* constant = *computation->instructions().begin(); EXPECT_EQ(42.0f, constant->literal().Get({})); auto result = ExecuteAndTransfer(std::move(module), {}); diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index 213ff07b071574acea41b28f20a004016ee8f697..92261bce6270e3c37165c10ed804d036d2abb984 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -85,30 +85,27 @@ void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) { string HloDataflowAnalysis::ToString() const { string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n"); StrAppend(&out, " Instruction value sets:\n"); - for (const std::unique_ptr& computation : - module_->computations()) { - for (const std::unique_ptr& instruction : - computation->instructions()) { + for (const HloComputation* computation : module_->computations()) { + for (const HloInstruction* instruction : computation->instructions()) { StrAppend(&out, " ", instruction->name(), ":\n"); if (ShapeUtil::IsTuple(instruction->shape())) { - GetInstructionValueSet(instruction.get()) + GetInstructionValueSet(instruction) .ForEachElement([this, &instruction, &out]( const ShapeIndex& index, const HloValueSet& value_set) { StrAppend(&out, " tuple index ", index.ToString(), ":\n"); for (const HloValue* value : value_set.values()) { - StrAppend( - &out, " ", value->ToShortString(), - ValueIsDefinedAt(instruction.get(), index) ? " (def)" : "", - "\n"); + StrAppend(&out, " ", value->ToShortString(), + ValueIsDefinedAt(instruction, index) ? " (def)" : "", + "\n"); } }); } else { const HloValueSet& top_level_value_set = - GetValueSet(instruction.get(), /*index=*/{}); + GetValueSet(instruction, /*index=*/{}); for (const HloValue* value : top_level_value_set.values()) { StrAppend(&out, " ", value->ToShortString(), - ValueIsDefinedAt(instruction.get()) ? " (def)" : "", "\n"); + ValueIsDefinedAt(instruction) ? " (def)" : "", "\n"); } } } @@ -513,26 +510,21 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet( } Status HloDataflowAnalysis::InitializeInstructionValueSets() { - for (const std::unique_ptr& computation : - module_->computations()) { - const CallGraphNode& call_graph_node = - call_graph_->GetNode(computation.get()); - - for (const std::unique_ptr& instruction : - computation->instructions()) { + for (const HloComputation* computation : module_->computations()) { + const CallGraphNode& call_graph_node = call_graph_->GetNode(computation); + for (HloInstruction* instruction : computation->instructions()) { // Create an empty shape tree. value_sets_.emplace(std::piecewise_construct, - std::forward_as_tuple(instruction.get()), + std::forward_as_tuple(instruction), std::forward_as_tuple(instruction->shape())); // Lambda to set the value set to define all values in the output of the // instruction. auto define_all_values = [this, &instruction](bool is_phi = false) { - for (auto& pair : GetInstructionValueSet(instruction.get())) { + for (auto& pair : GetInstructionValueSet(instruction)) { const ShapeIndex& index = pair.first; - HloValue* value = - NewHloValue(instruction.get(), index, /*is_phi=*/false); - GetValueSet(instruction.get(), index).AddValue(value); + HloValue* value = NewHloValue(instruction, index, /*is_phi=*/false); + GetValueSet(instruction, index).AddValue(value); } }; @@ -541,8 +533,8 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { // the instruction (or from cross-computation dataflow). auto define_top_level_only = [this, &instruction]() { HloValue* value = - NewHloValue(instruction.get(), /*index=*/{}, /*is_phi=*/false); - GetValueSet(instruction.get(), /*index=*/{}).AddValue(value); + NewHloValue(instruction, /*index=*/{}, /*is_phi=*/false); + GetValueSet(instruction, /*index=*/{}).AddValue(value); }; switch (instruction->opcode()) { @@ -619,18 +611,16 @@ StatusOr> HloDataflowAnalysis::Run( dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions); // Add in positions to all values. - for (const std::unique_ptr& computation : - module->computations()) { - for (const std::unique_ptr& instruction : - computation->instructions()) { + for (const HloComputation* computation : module->computations()) { + for (HloInstruction* instruction : computation->instructions()) { for (const auto& pair : - dataflow_analysis->GetInstructionValueSet(instruction.get())) { + dataflow_analysis->GetInstructionValueSet(instruction)) { const ShapeIndex& index = pair.first; const HloValueSet& value_set = pair.second; for (const HloValue* value : value_set.values()) { - if (value->defining_instruction() != instruction.get()) { + if (value->defining_instruction() != instruction) { dataflow_analysis->GetValue(value->id()) - .AddPosition(instruction.get(), index); + .AddPosition(instruction, index); } } } @@ -670,10 +660,10 @@ Status HloDataflowAnalysis::Verify() const { // appears in the value's positions(). for (const auto& computation : module_->computations()) { for (const auto& instruction : computation->instructions()) { - for (const auto& pair : GetInstructionValueSet(instruction.get())) { + for (const auto& pair : GetInstructionValueSet(instruction)) { const ShapeIndex& index = pair.first; const HloValueSet& value_set = pair.second; - const HloPosition position{instruction.get(), index}; + const HloPosition position{instruction, index}; for (const HloValue* value : value_set.values()) { TF_RET_CHECK(std::find(value->positions().begin(), value->positions().end(), diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc index 5b2c57da4ff3a1f887f777c3304893d950b3d3a9..a4921232f5848dbe1789c4c641e2b0ba3c1848bb 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.cc +++ b/tensorflow/compiler/xla/service/hlo_dce.cc @@ -37,10 +37,7 @@ namespace xla { StatusOr HloDCE::Run(HloModule* module) { bool changed = false; - for (auto& computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } + for (auto* computation : module->MakeNonfusionComputations()) { std::unordered_set live_instructions; TF_RETURN_IF_ERROR(computation->root_instruction()->Accept( [&live_instructions](HloInstruction* instruction) { @@ -52,11 +49,11 @@ StatusOr HloDCE::Run(HloModule* module) { // into a separate list first to avoid problems with iterating through the // computation's instruction while simultaneously removing instructions. std::vector dead_roots; - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { if (instruction->user_count() == 0 && - live_instructions.count(instruction.get()) == 0 && - computation->IsRemovable(instruction.get())) { - dead_roots.push_back(instruction.get()); + live_instructions.count(instruction) == 0 && + computation->IsRemovable(instruction)) { + dead_roots.push_back(instruction); } } @@ -67,6 +64,29 @@ StatusOr HloDCE::Run(HloModule* module) { } } + // Now DCE HloComputations. First, collect the computations that are + // referenced by some remaining instruction. + std::unordered_set live_computations; + if (HloComputation* entry_computation = module->entry_computation()) { + live_computations.insert(entry_computation); + } + for (auto* computation : module->MakeComputationPostOrder()) { + for (auto* instruction : computation->instructions()) { + for (auto* subcomp : instruction->called_computations()) { + live_computations.insert(subcomp); + } + } + } + + // Remove dead computations. + std::list computations = module->MakeComputationPostOrder(); + for (auto* computation : computations) { + if (live_computations.count(computation) == 0) { + TF_RETURN_IF_ERROR(module->RemoveEmbeddedComputation(computation)); + changed = true; + } + } + return changed; } diff --git a/tensorflow/compiler/xla/service/hlo_dce.h b/tensorflow/compiler/xla/service/hlo_dce.h index fca3fa0f58b7c5929c6ffa6c2d8ae6f76660b380..4e244494d6f98c48f4376bd762f116b9a9c2084d 100644 --- a/tensorflow/compiler/xla/service/hlo_dce.h +++ b/tensorflow/compiler/xla/service/hlo_dce.h @@ -24,10 +24,15 @@ limitations under the License. namespace xla { -// HLO pass which removes all dead instructions from each computation in the -// module. An instruction is dead if it is not reachable from the root. This -// pass does not remove dead parameter instructions as parameter instructions -// cannot be deleted, nor does the pass remove dead computations. +// HLO pass which removes dead instructions from each computation in the module +// and removes dead computations from the module. +// +// An instruction is dead if it is not reachable from the root. A computation is +// dead if it is not the entry computation of the module and it is not reachable +// from the entry computation. +// +// This pass does not remove dead parameter instructions, as parameter +// instructions cannot be deleted. class HloDCE : public HloPassInterface { public: ~HloDCE() override {} diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc index 8fdc2fe2c51b5b7386cf99c4e148f53e25b9590d..d54b9a27087a42fd23eab0bd06e8deaca567312b 100644 --- a/tensorflow/compiler/xla/service/hlo_dce_test.cc +++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc @@ -43,12 +43,9 @@ class HloDceTest : public HloTestBase { // Returns whether the given instruction exists in the given computation. bool HasInstruction(const HloComputation& computation, const HloInstruction* instruction) { - for (auto& inst : computation.instructions()) { - if (inst.get() == instruction) { - return true; - } - } - return false; + return std::find(computation.instructions().begin(), + computation.instructions().end(), + instruction) != computation.instructions().end(); } }; @@ -302,5 +299,93 @@ TEST_F(HloDceTest, CalledComputationWithNestedSideEffect) { EXPECT_TRUE(HasInstruction(*computation, live_call)); } +TEST_F(HloDceTest, RemoveDeadSubcomputation) { + auto module = CreateNewModule(); + HloComputation::Builder builder(TestName()); + + HloComputation::Builder subcomp_builder("reduction_subcomp"); + { + auto* param0 = + subcomp_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "param0")); + auto* param1 = + subcomp_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "param1")); + subcomp_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, param0, param1)); + } + auto reduce_subcomp = module->AddEmbeddedComputation(subcomp_builder.Build()); + + // Create a dead reduce instruction. + builder.AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(F32, {1}), + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + /*dimensions_to_reduce=*/{0}, reduce_subcomp)); + + // Add another instruction as the root of the computation. + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))); + + module->AddEntryComputation(builder.Build()); + EXPECT_EQ(module->MakeComputationPostOrder().size(), 2); + + HloDCE dce; + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + + // We should have DCE'ed the reduction computation along with the reduction + // instruction. + EXPECT_EQ(module->MakeComputationPostOrder().size(), 1); +} + +TEST_F(HloDceTest, KeepUsedSubcomputation) { + auto module = CreateNewModule(); + HloComputation::Builder builder(TestName()); + + HloComputation::Builder subcomp_builder("reduction_subcomp"); + { + auto* param0 = + subcomp_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "param0")); + auto* param1 = + subcomp_builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "param1")); + subcomp_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, param0, param1)); + } + auto reduce_subcomp = module->AddEmbeddedComputation(subcomp_builder.Build()); + + // Create a dead reduce instruction. + builder.AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(F32, {1}), + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {100}), "param0")), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + /*dimensions_to_reduce=*/{0}, reduce_subcomp)); + + // Add another instruction as the root of the computation that also uses + // reduce_subcomp. + builder.AddInstruction(HloInstruction::CreateReduce( + ShapeUtil::MakeShape(F32, {1}), + builder.AddInstruction(HloInstruction::CreateParameter( + /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {100}), "param1")), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + /*dimensions_to_reduce=*/{0}, reduce_subcomp)); + + module->AddEntryComputation(builder.Build()); + EXPECT_EQ(module->MakeComputationPostOrder().size(), 2); + + HloDCE dce; + EXPECT_TRUE(dce.Run(module.get()).ValueOrDie()); + + // We shouldn't have DCE'ed reduce_subcomp, even though we removed one of + // its users. + EXPECT_EQ(module->MakeComputationPostOrder().size(), 2); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc index e1e43ec60f910dc979f0a20ad450b3bbf38b8deb..88b77ccdd03eb129f81cfa1da430e882ea569df4 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc @@ -50,6 +50,12 @@ namespace xla { namespace { +template +struct is_complex_t : public std::false_type {}; + +template <> +struct is_complex_t : public std::true_type {}; + template StatusOr> Compare(const Shape& shape, HloOpcode opcode, const Literal& lhs_literal, @@ -101,6 +107,37 @@ StatusOr> Compare(const Shape& shape, HloOpcode opcode, return std::move(result); } +template <> +StatusOr> Compare( + const Shape& shape, HloOpcode opcode, const Literal& lhs_literal, + const Literal& rhs_literal) { + std::function compare_op; + switch (opcode) { + case HloOpcode::kEq: + compare_op = [](complex64 lhs_el, complex64 rhs_el) { + return lhs_el == rhs_el; + }; + break; + case HloOpcode::kNe: + compare_op = [](complex64 lhs_el, complex64 rhs_el) { + return lhs_el != rhs_el; + }; + break; + default: + LOG(FATAL) << "unhandled HLO opcode for conversion to Comparison: " + << HloOpcodeString(opcode); + } + + auto result = Literal::CreateFromShape(shape); + TF_RETURN_IF_ERROR(result->Populate( + [&](tensorflow::gtl::ArraySlice multi_index) { + return compare_op(lhs_literal.Get(multi_index), + rhs_literal.Get(multi_index)); + })); + + return std::move(result); +} + template StatusOr> ElementWiseUnaryOpImpl( HloInstruction* instruction, @@ -138,7 +175,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo_instruction) override { return Unimplemented("unhandled HLO ops for HloEvaluator: %s.", HloOpcodeString(hlo_instruction->opcode()).c_str()); - }; + } // TODO(b/35950897): many of the stl functions used in the handlers are not // overloaded for every XLA primitive types. @@ -146,7 +183,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { template ::value>::type* = nullptr> - Status HandleAbs(HloInstruction* abs, HloInstruction* operand) { + Status HandleAbs(HloInstruction* abs) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], ElementWiseUnaryOp(abs, [](NativeT elem_operand) { return elem_operand; @@ -156,8 +193,9 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { template < typename NativeT, - typename std::enable_if::value>::type* = nullptr> - Status HandleAbs(HloInstruction* abs, HloInstruction* operand) { + typename std::enable_if::value || + is_complex_t::value>::type* = nullptr> + Status HandleAbs(HloInstruction* abs) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[abs], ElementWiseUnaryOp(abs, [](NativeT elem_operand) { return std::abs(elem_operand); @@ -165,11 +203,14 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleAbs(HloInstruction* abs, HloInstruction* operand) override { - return HandleAbs(abs, operand); + Status HandleAbs(HloInstruction* abs) override { + return HandleAbs(abs); } - Status HandleRound(HloInstruction* round) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRound(HloInstruction* round) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[round], ElementWiseUnaryOp(round, [](ReturnT elem_operand) { return std::round(elem_operand); @@ -177,6 +218,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRound(HloInstruction* round) { + return InvalidArgument("Unsupported type for Round"); + } + + Status HandleRound(HloInstruction* round) override { + return HandleRound(round); + } + Status HandleBroadcast(HloInstruction* broadcast) override { parent_->evaluated_[broadcast] = Literal::CreateFromShape(broadcast->shape()); @@ -205,15 +257,29 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } return operand_to_broadcast.Get(broadcast_indices); }); - }; + } - Status HandleCeil(HloInstruction* ceil, HloInstruction* operand) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCeil(HloInstruction* ceil) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[ceil], ElementWiseUnaryOp(ceil, [](ReturnT elem_operand) { return std::ceil(elem_operand); })); return Status::OK(); - }; + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleCeil(HloInstruction* ceil) { + return InvalidArgument("Unsupported type for Ceil"); + } + + Status HandleCeil(HloInstruction* ceil) override { + return HandleCeil(ceil); + } Status HandleConvert(HloInstruction* convert) override { const HloInstruction* operand = convert->operand(0); @@ -231,165 +297,353 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleExp(HloInstruction* exp, HloInstruction* operand) override { + Status HandleExp(HloInstruction* exp) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[exp], ElementWiseUnaryOp(exp, [](ReturnT elem_operand) { return std::exp(elem_operand); })); return Status::OK(); - }; + } - Status HandleFloor(HloInstruction* floor, HloInstruction* operand) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleFloor(HloInstruction* floor) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[floor], ElementWiseUnaryOp(floor, [](ReturnT elem_operand) { return std::floor(elem_operand); })); return Status::OK(); - }; + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleFloor(HloInstruction* floor) { + return InvalidArgument("Unsupported type for Floor"); + } + + Status HandleFloor(HloInstruction* floor) override { + return HandleFloor(floor); + } - Status HandleLog(HloInstruction* log, HloInstruction* operand) override { + Status HandleLog(HloInstruction* log) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[log], ElementWiseUnaryOp(log, [](ReturnT elem_operand) { return std::log(elem_operand); })); return Status::OK(); - }; + } - Status HandleLogicalNot(HloInstruction* logical_not, - HloInstruction* operand) override { - TF_ASSIGN_OR_RETURN( - parent_->evaluated_[logical_not], - ElementWiseUnaryOp(logical_not, - [](ReturnT elem_operand) { return !elem_operand; })); + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[not_], + ElementWiseUnaryOp(not_, [](ReturnT elem_operand) { + return !elem_operand; + })); return Status::OK(); - }; + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleNot(HloInstruction* not_) { + return InvalidArgument("Unsupported type for Not"); + } - Status HandleNegate(HloInstruction* negate, - HloInstruction* operand) override { + Status HandleNot(HloInstruction* not_) override { + return HandleNot(not_); + } + + Status HandleNegate(HloInstruction* negate) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[negate], ElementWiseUnaryOp(negate, [](ReturnT elem_operand) { return -elem_operand; })); return Status::OK(); - }; + } - Status HandleSign(HloInstruction* sign, HloInstruction* operand) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], ElementWiseUnaryOp(sign, [](ReturnT elem_operand) { return (ReturnT(0) < elem_operand) - (elem_operand < ReturnT(0)); })); return Status::OK(); - }; + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleSign(HloInstruction* sign) { + TF_ASSIGN_OR_RETURN(parent_->evaluated_[sign], + ElementWiseUnaryOp(sign, [](ReturnT elem_operand) { + auto abs_val = std::abs(elem_operand); + return 0 == abs_val ? ReturnT(0) + : elem_operand / abs_val; + })); + return Status::OK(); + } + + Status HandleSign(HloInstruction* sign) override { + return HandleSign(sign); + } - Status HandleTanh(HloInstruction* tanh, HloInstruction* operand) override { + Status HandleTanh(HloInstruction* tanh) override { TF_ASSIGN_OR_RETURN(parent_->evaluated_[tanh], ElementWiseUnaryOp(tanh, [](ReturnT elem_operand) { return std::tanh(elem_operand); })); return Status::OK(); - }; + } - Status HandleMultiply(HloInstruction* multiply, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleMultiply(HloInstruction* multiply) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[multiply], ElementWiseBinaryOp(multiply, [](ReturnT lhs_elem, ReturnT rhs_elem) { return lhs_elem * rhs_elem; })); return Status::OK(); - }; + } - Status HandleSubtract(HloInstruction* subtract, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleSubtract(HloInstruction* subtract) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[subtract], ElementWiseBinaryOp(subtract, [](ReturnT lhs_elem, ReturnT rhs_elem) { return lhs_elem - rhs_elem; })); return Status::OK(); - }; + } - Status HandleAdd(HloInstruction* add, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleAdd(HloInstruction* add) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[add], ElementWiseBinaryOp(add, [](ReturnT lhs_elem, ReturnT rhs_elem) { return lhs_elem + rhs_elem; })); return Status::OK(); - }; + } - Status HandleDivide(HloInstruction* divide, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleDivide(HloInstruction* divide) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[divide], ElementWiseBinaryOp(divide, [](ReturnT lhs_elem, ReturnT rhs_elem) { return lhs_elem / rhs_elem; })); return Status::OK(); - }; + } - Status HandleMaximum(HloInstruction* maximum) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleMaximum(HloInstruction* maximum) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[maximum], ElementWiseBinaryOp(maximum, [](ReturnT lhs, ReturnT rhs) { return std::fmax(lhs, rhs); })); return Status::OK(); - }; + } - Status HandleMinimum(HloInstruction* minimum) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleMaximum(HloInstruction* maximum) { + return InvalidArgument("Unsupported type for Maximum"); + } + + Status HandleMaximum(HloInstruction* maximum) override { + return HandleMaximum(maximum); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleMinimum(HloInstruction* minimum) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[minimum], ElementWiseBinaryOp(minimum, [](ReturnT lhs_el, ReturnT rhs_el) { return std::fmin(lhs_el, rhs_el); })); return Status::OK(); - }; + } - Status HandlePower(HloInstruction* power, HloInstruction* lhs, - HloInstruction* rhs) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleMinimum(HloInstruction* minimum) { + return InvalidArgument("Unsupported type for Minimum"); + } + + Status HandleMinimum(HloInstruction* minimum) override { + return HandleMinimum(minimum); + } + + Status HandlePower(HloInstruction* power) override { TF_ASSIGN_OR_RETURN( parent_->evaluated_[power], ElementWiseBinaryOp(power, [](ReturnT lhs_el, ReturnT rhs_el) { return std::pow(lhs_el, rhs_el); })); return Status::OK(); - }; + } - Status HandleRemainder(HloInstruction* remainder, HloInstruction* lhs, - HloInstruction* rhs) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRemainder(HloInstruction* remainder) { TF_ASSIGN_OR_RETURN( parent_->evaluated_[remainder], ElementWiseBinaryOp(remainder, [](ReturnT lhs_el, ReturnT rhs_el) { return std::fmod(lhs_el, rhs_el); })); return Status::OK(); - }; + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleRemainder(HloInstruction* remainder) { + return InvalidArgument("Unsupported type for Remainder"); + } + + Status HandleRemainder(HloInstruction* remainder) override { + return HandleRemainder(remainder); + } - Status HandleLogicalAnd(HloInstruction* logical_and, HloInstruction* lhs, - HloInstruction* rhs) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAnd(HloInstruction* and_) { TF_ASSIGN_OR_RETURN( - parent_->evaluated_[logical_and], - ElementWiseBinaryOp(logical_and, [](ReturnT lhs_el, ReturnT rhs_el) { + parent_->evaluated_[and_], + ElementWiseBinaryOp(and_, [](ReturnT lhs_el, ReturnT rhs_el) { return lhs_el && rhs_el; })); return Status::OK(); - }; + } - Status HandleLogicalOr(HloInstruction* logical_or, HloInstruction* lhs, - HloInstruction* rhs) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleAnd(HloInstruction* and_) { + return InvalidArgument("Unsupported type for And"); + } + + Status HandleAnd(HloInstruction* and_) override { + return HandleAnd(and_); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleOr(HloInstruction* or_) { TF_ASSIGN_OR_RETURN( - parent_->evaluated_[logical_or], - ElementWiseBinaryOp(logical_or, [](ReturnT lhs_el, ReturnT rhs_el) { + parent_->evaluated_[or_], + ElementWiseBinaryOp(or_, [](ReturnT lhs_el, ReturnT rhs_el) { return lhs_el || rhs_el; })); return Status::OK(); - }; + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleOr(HloInstruction* or_) { + return InvalidArgument("Unsupported type for Or"); + } + + Status HandleOr(HloInstruction* or_) override { + return HandleOr(or_); + } + + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftLeft(HloInstruction* shl) { + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shl], + ElementWiseBinaryOp(shl, [](NativeT lhs_elem, NativeT rhs_elem) { + return lhs_elem << rhs_elem; + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftLeft(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftLeft"); + } + + Status HandleShiftLeft(HloInstruction* shl) override { + return HandleShiftLeft(shl); + } + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftRightArithmetic(HloInstruction* shr) { + typedef typename std::make_signed::type SignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + return static_cast(static_cast(lhs_elem) >> + rhs_elem); + })); + return Status::OK(); + } - Status HandleClamp(HloInstruction* clamp, HloInstruction* min, - HloInstruction* arg, HloInstruction* max) override { + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftRightArithmetic(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftRightArithmetic"); + } + + Status HandleShiftRightArithmetic(HloInstruction* shra) override { + return HandleShiftRightArithmetic(shra); + } + + template ::value && + !std::is_same::value>::type* = nullptr> + Status HandleShiftRightLogical(HloInstruction* shr) { + typedef typename std::make_unsigned::type UnsignedT; + TF_ASSIGN_OR_RETURN( + parent_->evaluated_[shr], + ElementWiseBinaryOp(shr, [](NativeT lhs_elem, NativeT rhs_elem) { + return static_cast(static_cast(lhs_elem) >> + rhs_elem); + })); + return Status::OK(); + } + + template ::value || + std::is_same::value>::type* = + nullptr> + Status HandleShiftRightLogical(HloInstruction*) { + return InvalidArgument("Unsupported type for ShiftRightLogical"); + } + + Status HandleShiftRightLogical(HloInstruction* shrl) override { + return HandleShiftRightLogical(shrl); + } + + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleClamp(HloInstruction* clamp) { std::function clamp_op = [](ReturnT low, ReturnT high, ReturnT value) { return std::fmax(low, std::fmin(value, high)); @@ -397,11 +651,20 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(parent_->evaluated_[clamp], ElementWiseTernaryOp(clamp, std::move(clamp_op))); return Status::OK(); - }; + } - Status HandleSelect(HloInstruction* select, HloInstruction* pred, - HloInstruction* on_true, - HloInstruction* on_false) override { + template < + typename NativeT, + typename std::enable_if::value>::type* = nullptr> + Status HandleClamp(HloInstruction*) { + return InvalidArgument("Unsupported type for Clamp"); + } + + Status HandleClamp(HloInstruction* clamp) override { + return HandleClamp(clamp); + } + + Status HandleSelect(HloInstruction* select) override { CHECK(!ShapeUtil::IsTuple(select->shape())); std::function select_op = [](bool pred, ReturnT on_true, ReturnT on_false) { @@ -413,13 +676,13 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_ASSIGN_OR_RETURN(parent_->evaluated_[select], ElementWiseTernaryOp(select, std::move(select_op))); return Status::OK(); - }; + } - Status HandleReverse(HloInstruction* reverse, - HloInstruction* operand) override { + Status HandleReverse(HloInstruction* reverse) override { const auto result_shape = reverse->shape(); const auto reverse_dimensions = reverse->dimensions(); + auto operand = reverse->operand(0); TF_ASSIGN_OR_RETURN(auto inferred_return_shape, ShapeInference::InferReverseShape(operand->shape(), reverse_dimensions)); @@ -443,10 +706,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->evaluated_[reverse] = std::move(result); return Status::OK(); - }; + } - Status HandleConvolution(HloInstruction* conv, HloInstruction* lhs, - HloInstruction* rhs, const Window& window) override { + Status HandleConvolution(HloInstruction* conv) override { + auto lhs = conv->operand(0); + auto rhs = conv->operand(1); + const auto& window = conv->window(); const Shape& result_shape = conv->shape(); const Shape& lhs_shape = lhs->shape(); const Shape& rhs_shape = rhs->shape(); @@ -461,7 +726,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const auto& dnums = conv->convolution_dimension_numbers(); const int64 num_spatial_dims = dnums.spatial_dimensions_size(); CHECK_EQ(num_spatial_dims, dnums.kernel_spatial_dimensions_size()); - CHECK_GE(num_spatial_dims, 1); + CHECK_GE(num_spatial_dims, 0); CHECK_EQ(window.dimensions_size(), num_spatial_dims); const auto lhs_rank = ShapeUtil::Rank(lhs_shape); @@ -481,14 +746,17 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs); const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs); - // Dimension number applicable for both input (lhs), and output. - const int64 batch_dim = dnums.batch_dimension(); - const int64 z_dim = dnums.feature_dimension(); + // Dimension number applicable for input (lhs). + const int64 input_batch_dim = dnums.input_batch_dimension(); + const int64 input_z_dim = dnums.input_feature_dimension(); // Dimension number applicable for kernel (rhs). const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension(); const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension(); + // Dimension number applicable for output. + const int64 output_batch_dim = dnums.output_batch_dimension(); + const int64 output_z_dim = dnums.output_feature_dimension(); - const int64 z_size = ShapeUtil::GetDimension(lhs_shape, z_dim); + const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim); std::vector window_dimension_sizes; for (auto i : dnums.kernel_spatial_dimensions()) { @@ -509,13 +777,13 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { std::fill(rhs_index.begin(), rhs_index.end(), 0); std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0); - lhs_index[batch_dim] = out_index[batch_dim]; - rhs_index[kernel_output_z_dim] = out_index[z_dim]; + lhs_index[input_batch_dim] = out_index[output_batch_dim]; + rhs_index[kernel_output_z_dim] = out_index[output_z_dim]; // Convolve input feature with kernel. do { for (int64 iz = 0; iz < z_size; ++iz) { - lhs_index[z_dim] = iz; + lhs_index[input_z_dim] = iz; rhs_index[kernel_input_z_dim] = iz; // Find corresponding spatial dimension index for input (lhs). @@ -563,10 +831,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->evaluated_[conv] = std::move(result); return Status::OK(); - }; + } - Status HandleDot(HloInstruction* dot, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleDot(HloInstruction* dot) override { + auto lhs = dot->operand(0); + auto rhs = dot->operand(1); CHECK(ShapeUtil::IsArray(dot->shape())); CHECK(ShapeUtil::IsArray(lhs->shape())); CHECK(ShapeUtil::IsArray(rhs->shape())); @@ -630,7 +899,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->evaluated_[dot] = std::move(result); return Status::OK(); - }; + } Status HandlePad(HloInstruction* pad) override { CHECK(!ShapeUtil::IsTuple(pad->operand(0)->shape())); @@ -699,11 +968,11 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->evaluated_[pad] = std::move(result); return Status::OK(); - }; + } - Status HandleDynamicSlice(HloInstruction* dynamic_slice, - HloInstruction* operand, - HloInstruction* start_indices) override { + Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { + auto operand = dynamic_slice->operand(0); + auto start_indices = dynamic_slice->operand(1); auto result_shape = dynamic_slice->shape(); TF_ASSIGN_OR_RETURN(auto inferred_return_shape, ShapeInference::InferDynamicSliceShape( @@ -752,12 +1021,13 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } return Status::OK(); - }; + } - Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, - HloInstruction* operand, - HloInstruction* update, - HloInstruction* start_indices) override { + Status HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) override { + auto operand = dynamic_update_slice->operand(0); + auto update = dynamic_update_slice->operand(1); + auto start_indices = dynamic_update_slice->operand(2); auto result_shape = dynamic_update_slice->shape(); TF_ASSIGN_OR_RETURN( auto inferred_return_shape, @@ -808,12 +1078,13 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } return Status::OK(); - }; + } - Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, - HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, - HloComputation* function) override { + Status HandleReduce(HloInstruction* reduce) override { + auto arg = reduce->operand(0); + auto init_value = reduce->operand(1); + tensorflow::gtl::ArraySlice dimensions(reduce->dimensions()); + HloComputation* function = reduce->to_apply(); TF_RET_CHECK(ShapeUtil::Rank(reduce->shape()) == ShapeUtil::Rank(arg->shape()) - dimensions.size()); TF_ASSIGN_OR_RETURN(auto inferred_return_shape, @@ -896,11 +1167,12 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->evaluated_[reduce] = std::move(result); return Status::OK(); - }; + } - Status HandleReduceWindow(HloInstruction* reduce_window, - HloInstruction* operand, const Window& window, - HloComputation* function) override { + Status HandleReduceWindow(HloInstruction* reduce_window) override { + auto operand = reduce_window->operand(0); + const Window& window = reduce_window->window(); + HloComputation* function = reduce_window->to_apply(); TF_ASSIGN_OR_RETURN( auto inferred_return_shape, ShapeInference::InferReduceWindowShape( @@ -983,9 +1255,10 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { parent_->evaluated_[reduce_window] = std::move(result); return Status::OK(); - }; + } - Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override { + Status HandleSlice(HloInstruction* slice) override { + auto operand = slice->operand(0); const Shape& shape = slice->shape(); TF_ASSIGN_OR_RETURN(auto inferred_return_shape, ShapeInference::InferSliceShape( @@ -1012,7 +1285,7 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { TF_RETURN_IF_ERROR(result->Populate(func)); parent_->evaluated_[slice] = std::move(result); return Status::OK(); - }; + } private: template @@ -1155,32 +1428,33 @@ class HloEvaluator::TypedVisitor : public DfsHloVisitorWithDefault { } HloEvaluator* parent_; -}; // namespace xla +}; // class HloEvaluator::TypedVisitor HloEvaluator::HloEvaluator() { typed_visitors_[PRED] = MakeUnique>(this); typed_visitors_[U8] = MakeUnique>(this); typed_visitors_[U16] = MakeUnique([](HloInstruction*) { - return Unimplemented("unhandled primitive type: U16."); + return Unimplemented("HloEvaluator: unhandled primitive type: U16."); }); typed_visitors_[U32] = MakeUnique>(this); typed_visitors_[U64] = MakeUnique>(this); typed_visitors_[S8] = MakeUnique>(this); typed_visitors_[S16] = MakeUnique([](HloInstruction*) { - return Unimplemented("unhandled primitive type: S16."); + return Unimplemented("HloEvaluator: unhandled primitive type: S16."); }); typed_visitors_[S32] = MakeUnique>(this); typed_visitors_[S64] = MakeUnique>(this); typed_visitors_[F16] = MakeUnique([](HloInstruction*) { - return Unimplemented("unhandled primitive type: F16."); + return Unimplemented("HloEvaluator: unhandled primitive type: F16."); }); typed_visitors_[F32] = MakeUnique>(this); typed_visitors_[F64] = MakeUnique>(this); + typed_visitors_[C64] = MakeUnique>(this); typed_visitors_[TUPLE] = MakeUnique([](HloInstruction*) { - return Unimplemented("unhandled primitive type: TUPLE."); + return Unimplemented("HloEvaluator: unhandled primitive type: TUPLE."); }); typed_visitors_[OPAQUE] = MakeUnique([](HloInstruction*) { - return Unimplemented("unhandled primitive type: OPAQUE."); + return Unimplemented("HloEvaluator: unhandled primitive type: OPAQUE."); }); } @@ -1241,8 +1515,14 @@ StatusOr> HloEvaluator::Evaluate( StatusOr> HloEvaluator::Evaluate( HloInstruction* instruction) { - TF_RET_CHECK(hlo_query::AllOperandsAreConstants(*instruction)); - TF_RET_CHECK(instruction->opcode() != HloOpcode::kParameter); + if (instruction->opcode() == HloOpcode::kParameter) { + return tensorflow::errors::FailedPrecondition( + "Cannot evaluate a parameter."); + } + if (!hlo_query::AllOperandsAreConstants(*instruction)) { + return tensorflow::errors::FailedPrecondition( + "Not all operands are constants."); + } TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(instruction->shape())); arg_literals_.clear(); @@ -1265,6 +1545,39 @@ std::unique_ptr HloEvaluator::TryEvaluate( return result_or.ConsumeValueOrDie(); } +StatusOr> HloEvaluator::EvaluateWithSubstitutions( + const HloInstruction* instruction, + const std::unordered_map& + substitutions) { + std::vector> owned_operands; + for (const HloInstruction* operand : instruction->operands()) { + auto it = substitutions.find(operand); + if (it == substitutions.end()) { + owned_operands.push_back(operand->Clone()); + } else { + owned_operands.push_back( + HloInstruction::CreateConstant(it->second->CloneToUnique())); + } + } + + std::vector operands; + for (auto& operand : owned_operands) { + operands.push_back(operand.get()); + } + + std::unique_ptr cloned_instruction = + instruction->CloneWithNewOperands(instruction->shape(), operands); + auto result = Evaluate(cloned_instruction.get()); + + // Clean up our cloned instructions before returning. + cloned_instruction->DetachFromOperands(); + for (auto& operand : owned_operands) { + operand->DetachFromOperands(); + } + + return result; +} + Status HloEvaluator::HandleParameter(HloInstruction* parameter) { const Literal* input_literal = arg_literals_[parameter->parameter_number()]; VLOG(2) << "Parameter evaluated to: " << input_literal->ToString(); @@ -1274,10 +1587,7 @@ Status HloEvaluator::HandleParameter(HloInstruction* parameter) { return Status::OK(); } -Status HloEvaluator::HandleConstant(HloInstruction* constant, - const Literal& literal) { - return Status::OK(); -} +Status HloEvaluator::HandleConstant(HloInstruction*) { return Status::OK(); } Status HloEvaluator::HandleReshape(HloInstruction* reshape) { TF_ASSIGN_OR_RETURN( @@ -1293,9 +1603,9 @@ Status HloEvaluator::HandleTranspose(HloInstruction* transpose) { return Status::OK(); } -Status HloEvaluator::HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands) { +Status HloEvaluator::HandleConcatenate(HloInstruction* concatenate) { + tensorflow::gtl::ArraySlice operands( + concatenate->operands()); // The result concatenate dimension is going to be the sum of all concatenate // dimensions of the operands taking part of the operation. const Shape& reference_shape = operands[0]->shape(); @@ -1335,8 +1645,8 @@ Status HloEvaluator::HandleConcatenate( return Status::OK(); } -Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite, - HloInstruction* operand) { +Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite) { + auto operand = is_finite->operand(0); if (!ShapeUtil::ElementIsFloating(operand->shape())) { return InvalidArgument( "expected element type in shape to be float for IsFinite op, got: %s", @@ -1370,8 +1680,10 @@ Status HloEvaluator::HandleIsFinite(HloInstruction* is_finite, return Status::OK(); } -Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode, - HloInstruction* lhs, HloInstruction* rhs) { +Status HloEvaluator::HandleCompare(HloInstruction* compare) { + HloOpcode opcode = compare->opcode(); + auto lhs = compare->operand(0); + auto rhs = compare->operand(1); // TODO(b/35950897, b/27796129): add DCHECK back once implicit broadcast is // removed. if (!(ShapeUtil::SameDimensions(compare->shape(), rhs->shape()) && @@ -1442,6 +1754,11 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode, evaluated_[compare], Compare(compare->shape(), opcode, lhs_literal, rhs_literal)); } break; + case C64: { + TF_ASSIGN_OR_RETURN(evaluated_[compare], + Compare(compare->shape(), opcode, + lhs_literal, rhs_literal)); + } break; default: LOG(FATAL) << "HandleCompare: unknown primitive type: " << PrimitiveType_Name(lhs->shape().element_type()); @@ -1450,11 +1767,9 @@ Status HloEvaluator::HandleCompare(HloInstruction* compare, HloOpcode opcode, return Status::OK(); } -Status HloEvaluator::HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) { +Status HloEvaluator::HandleTuple(HloInstruction* tuple) { std::vector operand_literals; - for (auto operand : operands) { + for (auto operand : tuple->operands()) { operand_literals.push_back(&GetEvaluatedLiteralFor(operand)); } @@ -1462,11 +1777,11 @@ Status HloEvaluator::HandleTuple( return Status::OK(); } -Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* operand) { +Status HloEvaluator::HandleGetTupleElement(HloInstruction* get_tuple_element) { const auto result_shape = get_tuple_element->shape(); const int64 index = get_tuple_element->tuple_index(); + auto operand = get_tuple_element->operand(0); TF_ASSIGN_OR_RETURN( auto inferred_return_shape, ShapeInference::InferGetTupleElementShape(operand->shape(), index)); diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h index 66a53e1fa5a219a60665198a03026ad36cc4c117..67b6e215fcb23598f1a8ab6212d6e7e58a64e976 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator.h +++ b/tensorflow/compiler/xla/service/hlo_evaluator.h @@ -84,6 +84,16 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Same as Evaluate, except returning nullptr on error. std::unique_ptr TryEvaluate(HloInstruction* instruction); + // Evaluates a single HLO instruction, substituting the given literals for + // some of the instruction's operands. + // + // For example, given instruction = op(A, B, C) and the map + // {A = x, C = y}, this evaluates op(x, B, y). + StatusOr> EvaluateWithSubstitutions( + const HloInstruction* instruction, + const std::unordered_map& + substitutions); + protected: // Templated DfsHloVisitor. Typically ReturnT here indicates the resulting // literal type of each evaluated Handle* method of a TypedVisitor. @@ -110,28 +120,20 @@ class HloEvaluator : public DfsHloVisitorWithDefault { // Status HandleParameter(HloInstruction* parameter) override; - Status HandleConstant(HloInstruction* constant, - const Literal& literal) override; + Status HandleConstant(HloInstruction* constant) override; - Status HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands) override; + Status HandleConcatenate(HloInstruction* concatenate) override; Status HandleReshape(HloInstruction* reshape) override; Status HandleTranspose(HloInstruction* transpose) override; - Status HandleIsFinite(HloInstruction* is_finite, - HloInstruction* operand) override; + Status HandleIsFinite(HloInstruction* is_finite) override; - Status HandleCompare(HloInstruction* compare, HloOpcode opcode, - HloInstruction* lhs, HloInstruction* rhs) override; - Status HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) override; + Status HandleCompare(HloInstruction* compare) override; + Status HandleTuple(HloInstruction* tuple) override; - Status HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* operand) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleCopy(HloInstruction* copy) override; diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc index 010d38bbb48303a0ba13dbd7dc344d5167ad6e8f..85477af6fe26f53504c07204348566c16a24392c 100644 --- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc +++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc @@ -30,17 +30,18 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/test.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace { -class HloEvaluatorTest : public HloTestBase { +class HloEvaluatorTest : public HloVerifiedTestBase { protected: HloEvaluatorTest() { evaluator_ = MakeUnique(); } @@ -61,8 +62,7 @@ TEST_F(HloEvaluatorTest, DoesClamp) { auto c3 = b.AddInstruction(HloInstruction::CreateConstant(std::move(value))); auto instruction = b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kClamp, c1, c2, c3)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); @@ -88,8 +88,7 @@ TEST_F(HloEvaluatorTest, DoesSelect) { b.AddInstruction(HloInstruction::CreateConstant(std::move(on_false))); auto instruction = b.AddInstruction( HloInstruction::CreateTernary(shape, HloOpcode::kSelect, c1, c2, c3)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); @@ -111,8 +110,7 @@ TEST_F(HloEvaluatorTest, DoesAdd) { auto c2 = b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs))); auto instruction = b.AddInstruction( HloInstruction::CreateBinary(shape, HloOpcode::kAdd, c1, c2)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); @@ -124,111 +122,100 @@ TEST_F(HloEvaluatorTest, DoesAdd) { // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise divide with 2 operands. -TEST_F(HloEvaluatorTest, DoesDivide) { - { - auto lhs_s64 = Literal::CreateR2({{1, 0}, {-100, 4}}); - auto rhs_s64 = Literal::CreateR2({{2, 4}, {4, 4}}); - - Shape shape_s64 = ShapeUtil::MakeShape(S64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1_s64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_s64))); - auto c2_s64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_s64))); - auto instruction = b.AddInstruction(HloInstruction::CreateBinary( - shape_s64, HloOpcode::kDivide, c1_s64, c2_s64)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); - - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - - auto expected = Literal::CreateR2({{0, 0}, {-25, 1}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); - } - { - auto lhs_f64 = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); - auto rhs_f64 = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); - - Shape shape_f64 = ShapeUtil::MakeShape(F64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1_f64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_f64))); - auto c2_f64 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_f64))); - auto instruction = b.AddInstruction(HloInstruction::CreateBinary( - shape_f64, HloOpcode::kDivide, c1_f64, c2_f64)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); - - auto result = evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - - auto expected = - Literal::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); - } +TEST_F(HloEvaluatorTest, DoesDivideInt64) { + auto lhs_s64 = Literal::CreateR2({{1, 0}, {-100, 4}}); + auto rhs_s64 = Literal::CreateR2({{2, 4}, {4, 4}}); + + Shape shape_s64 = ShapeUtil::MakeShape(S64, {2, 2}); + HloComputation::Builder b(TestName()); + auto c1_s64 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_s64))); + auto c2_s64 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_s64))); + auto instruction = b.AddInstruction(HloInstruction::CreateBinary( + shape_s64, HloOpcode::kDivide, c1_s64, c2_s64)); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + + auto expected = Literal::CreateR2({{0, 0}, {-25, 1}}); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} +TEST_F(HloEvaluatorTest, DoesDivideDouble) { + auto lhs_f64 = Literal::CreateR2({{1.0, 0.0}, {-100.0, 4.0}}); + auto rhs_f64 = Literal::CreateR2({{2.2, 4.0}, {4.0, 4.0}}); + + Shape shape_f64 = ShapeUtil::MakeShape(F64, {2, 2}); + HloComputation::Builder b(TestName()); + auto c1_f64 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_f64))); + auto c2_f64 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_f64))); + auto instruction = b.AddInstruction(HloInstruction::CreateBinary( + shape_f64, HloOpcode::kDivide, c1_f64, c2_f64)); + module().AddEntryComputation(b.Build()); + + auto result = evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + + auto expected = + Literal::CreateR2({{0.45454545454545453, 0}, {-25, 1}}); + + LiteralTestUtil::ExpectEqual(*expected, *result); } // Verifies that HloEvaluator evaluates a HLO instruction that performs // element-wise abs op with 1 operand. -TEST_F(HloEvaluatorTest, DoesAbs) { - { - auto operand = Literal::CreateR2({{1, -20}, {-100, 4}}); - const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2}); - HloComputation::Builder b(TestName()); - auto c1 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); - auto instruction = b.AddInstruction( - HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); - - std::unique_ptr result = - evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); - - auto expected = Literal::CreateR2({{1, 20}, {100, 4}}); - - LiteralTestUtil::ExpectEqual(*expected, *result); - } +TEST_F(HloEvaluatorTest, DoesAbsR2) { + auto operand = Literal::CreateR2({{1, -20}, {-100, 4}}); + const Shape& shape = ShapeUtil::MakeShape(S64, {2, 2}); + HloComputation::Builder b(TestName()); + auto c1 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); + auto instruction = + b.AddInstruction(HloInstruction::CreateUnary(shape, HloOpcode::kAbs, c1)); + module().AddEntryComputation(b.Build()); + + std::unique_ptr result = + evaluator_->Evaluate(instruction, {}).ConsumeValueOrDie(); + + auto expected = Literal::CreateR2({{1, 20}, {100, 4}}); + LiteralTestUtil::ExpectEqual(*expected, *result); +} +TEST_F(HloEvaluatorTest, DoesAbsR0) { // For R0 literal. - { - const Shape& r0 = ShapeUtil::MakeShape(F32, {}); - auto operand = Literal::CreateR0(-1.0f); - HloComputation::Builder b(TestName()); - auto c1 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); - auto instruction = - b.AddInstruction(HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); - - auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie(); - auto expected = Literal::CreateR0(1.0f); - - LiteralTestUtil::ExpectEqual(*expected, *result); - } + const Shape& r0 = ShapeUtil::MakeShape(F32, {}); + auto operand = Literal::CreateR0(-1.0f); + HloComputation::Builder b(TestName()); + auto c1 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); + auto instruction = + b.AddInstruction(HloInstruction::CreateUnary(r0, HloOpcode::kAbs, c1)); + module().AddEntryComputation(b.Build()); + auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie(); + auto expected = Literal::CreateR0(1.0f); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} +TEST_F(HloEvaluatorTest, DoesAbsR1WithZeroSize) { // For R1 literal with dimension of size 0. - { - Shape empty_r1 = ShapeUtil::MakeShape(F32, {0}); - auto operand = Literal::CreateR1({}); - HloComputation::Builder b(TestName()); - auto c1 = - b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); - auto instruction = b.AddInstruction( - HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); - - auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie(); - auto expected = Literal::CreateR1({}); - - LiteralTestUtil::ExpectEqual(*expected, *result); - } -} // namespace + Shape empty_r1 = ShapeUtil::MakeShape(F32, {0}); + auto operand = Literal::CreateR1({}); + HloComputation::Builder b(TestName()); + auto c1 = + b.AddInstruction(HloInstruction::CreateConstant(std::move(operand))); + auto instruction = b.AddInstruction( + HloInstruction::CreateUnary(empty_r1, HloOpcode::kAbs, c1)); + module().AddEntryComputation(b.Build()); + + auto result = evaluator_->Evaluate(instruction).ConsumeValueOrDie(); + auto expected = Literal::CreateR1({}); + + LiteralTestUtil::ExpectEqual(*expected, *result); +} // Verifies that HloEvaluator evaluates a HLO Computation with non-parameter nor // constant operands. @@ -252,8 +239,7 @@ TEST_F(HloEvaluatorTest, DoesTraverseInstructions) { b.AddInstruction(HloInstruction::CreateParameter(2, shape, "rhs2")); b.AddInstruction(HloInstruction::CreateBinary(shape, HloOpcode::kAdd, lhs_instruction, param_rhs2)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, args).ConsumeValueOrDie(); @@ -278,8 +264,7 @@ TEST_F(HloEvaluatorTest, DoesReshape) { const int64 permutation[] = {1, 2, 0, 4, 3}; b.AddInstruction( HloInstruction::CreateTranspose(shape, literal_instruction, permutation)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -302,8 +287,7 @@ TEST_F(HloEvaluatorTest, DoesBroadcast) { HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateBroadcast( output_literal->shape(), literal_instruction, {1, 2})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -323,8 +307,7 @@ TEST_F(HloEvaluatorTest, DoesBroadcastScalar) { b.AddInstruction(HloInstruction::CreateBroadcast( output_literal->shape(), literal_instruction, /*broadcast_dimensions=*/{})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -342,11 +325,10 @@ TEST_F(HloEvaluatorTest, DoesConcatenateSimple) { std::vector operands = {operand1, operand2}; - Shape shape = ShapeUtil::MakeShape(S64, {2, 2}); + Shape shape = ShapeUtil::MakeShape(S64, {4, 2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -369,8 +351,7 @@ TEST_F(HloEvaluatorTest, ConcatenateHandlesShapeWithZeroElement) { Shape shape = ShapeUtil::MakeShape(S64, {2}); b.AddInstruction(HloInstruction::CreateConcatenate(shape, operands, 0)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -391,8 +372,7 @@ TEST_F(HloEvaluatorTest, ConvertWithSameLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -413,8 +393,7 @@ TEST_F(HloEvaluatorTest, ConvertWithDifferentLayout) { HloInstruction* constant = b.AddInstruction( HloInstruction::CreateConstant(std::move(input_literal))); b.AddInstruction(HloInstruction::CreateConvert(expected->shape(), constant)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -450,8 +429,7 @@ TEST_F(HloEvaluatorTest, Pad2DIntegerArrayWithZeroDimension) { Shape shape = ShapeUtil::MakeShape(S32, {5, 2}); auto pad_instruction = b.AddInstruction(HloInstruction::CreatePad( shape, operand_instruction, padding_value_instruction, padding_config)); - HloModule module(TestName()); - module.AddEntryComputation(b.Build()); + module().AddEntryComputation(b.Build()); auto result = evaluator_->Evaluate(pad_instruction).ConsumeValueOrDie(); @@ -478,8 +456,7 @@ TEST_F(HloEvaluatorTest, Pad4DFloatArrayWithInteriorPadding) { CreatePaddingConfig({{{1, 0, 2}}, {{0, 2, 1}}, {{0, 0, 0}}, {{0, 0, 0}}}); b.AddInstruction(HloInstruction::CreatePad( shape, input_instruction, pad_instruction, r4_padding_on_dim0_dim1)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -524,8 +501,7 @@ TEST_F(HloEvaluatorTest, NegativePadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -571,8 +547,7 @@ TEST_F(HloEvaluatorTest, NegativeAndInteriorPadding2D) { pad_value_instruction, r2_padding_on_dim0_dim1)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -608,8 +583,7 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank1) { Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -652,8 +626,7 @@ TEST_F(HloEvaluatorTest, DotRank1AndRank2) { Shape shape = ShapeUtil::MakeShape(F32, {2}); b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -694,8 +667,7 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank2) { Shape shape = ShapeUtil::MakeShape(F32, {4, 2}); b.AddInstruction(HloInstruction::CreateBinary( shape, HloOpcode::kDot, lhs_instruction, rhs_instruction)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -735,8 +707,10 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(0); - dnums.set_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(1); + dnums.set_output_feature_dimension(1); dnums.add_spatial_dimensions(2); dnums.set_kernel_output_feature_dimension(0); @@ -746,8 +720,7 @@ TEST_F(HloEvaluatorTest, SimpleConv1D) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -802,8 +775,7 @@ TEST_F(HloEvaluatorTest, Simple4x4Conv2DWith2x2Kernel) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -867,8 +839,10 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { *window.add_dimensions() = dim; ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(2); - dnums.set_feature_dimension(0); + dnums.set_input_batch_dimension(2); + dnums.set_output_batch_dimension(2); + dnums.set_input_feature_dimension(0); + dnums.set_output_feature_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(3); @@ -880,8 +854,7 @@ TEST_F(HloEvaluatorTest, Conv2DGeneralDimensions) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -939,8 +912,7 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithHighPadding) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 7, 7}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1004,8 +976,7 @@ TEST_F(HloEvaluatorTest, DilatedBaseConv2DWithLowAndHighPadding) { const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 8, 8}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1077,8 +1048,7 @@ TEST_F(HloEvaluatorTest, const Shape& shape = ShapeUtil::MakeShape(F32, {1, 1, 9, 3}); b.AddInstruction(HloInstruction::CreateConvolve( shape, lhs_instruction, rhs_instruction, window, dnums)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1126,15 +1096,14 @@ TEST_F(HloEvaluatorTest, ReduceAdd) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - HloModule module(TestName()); - auto add_func = module.AddEmbeddedComputation(add_computation.Build()); + auto add_func = module().AddEmbeddedComputation(add_computation.Build()); Shape shape = ShapeUtil::MakeShape(F32, {2}); b.AddInstruction( HloInstruction::CreateReduce(shape, arg_instruction, init_value, /*dimensions_to_reduce=*/{1}, add_func)); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1170,8 +1139,7 @@ TEST_F(HloEvaluatorTest, ReduceWindowMax) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); max_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kMaximum, param_lhs, param_rhs)); - HloModule module(TestName()); - auto max_func = module.AddEmbeddedComputation(max_computation.Build()); + auto max_func = module().AddEmbeddedComputation(max_computation.Build()); Window window; WindowDimension dim; @@ -1188,7 +1156,7 @@ TEST_F(HloEvaluatorTest, ReduceWindowMax) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, max_func)); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1222,8 +1190,7 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - HloModule module(TestName()); - auto add_func = module.AddEmbeddedComputation(add_computation.Build()); + auto add_func = module().AddEmbeddedComputation(add_computation.Build()); Window window; WindowDimension dim; @@ -1246,7 +1213,7 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1276,8 +1243,7 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd6D) { HloInstruction::CreateParameter(1, scalar_shape, "rhs")); add_computation.AddInstruction(HloInstruction::CreateBinary( scalar_shape, HloOpcode::kAdd, param_lhs, param_rhs)); - HloModule module(TestName()); - auto add_func = module.AddEmbeddedComputation(add_computation.Build()); + auto add_func = module().AddEmbeddedComputation(add_computation.Build()); Window window; @@ -1308,7 +1274,7 @@ TEST_F(HloEvaluatorTest, ReduceWindowAdd6D) { b.AddInstruction(HloInstruction::CreateReduceWindow( shape, arg_instruction, init_value, window, add_func)); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1339,8 +1305,7 @@ TEST_F(HloEvaluatorTest, StridedSlice) { /*start_indices=*/{0, 2}, /*limit_indices=*/{3, 5}, /*strides=*/{2, 3})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1374,8 +1339,7 @@ TEST_F(HloEvaluatorTest, DynamicSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1411,8 +1375,7 @@ TEST_F(HloEvaluatorTest, DynamicSliceModSlice) { Shape shape = ShapeUtil::MakeShape(F32, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicSlice(shape, operand, start_indices, {2, 3})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1449,8 +1412,7 @@ TEST_F(HloEvaluatorTest, DynamicSliceUpdate) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateDynamicUpdateSlice( shape, operand, update, start_indices)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1486,8 +1448,7 @@ TEST_F(HloEvaluatorTest, SetAndGetTuples) { Shape shape = ShapeUtil::MakeShape(F64, {2, 3}); b.AddInstruction(HloInstruction::CreateGetTupleElement(shape, tuple, 1)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1526,8 +1487,7 @@ TEST_F(HloEvaluatorTest, SetAndGetNestedTuples) { b.AddInstruction( HloInstruction::CreateGetTupleElement(tuple2->shape(), outer_tuple, 1)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1567,8 +1527,7 @@ TEST_F(HloEvaluatorTest, Reverse) { const Shape shape = ShapeUtil::MakeShape(F32, {4, 3, 2, 1}); b.AddInstruction(HloInstruction::CreateReverse(shape, operand, {0, 1})); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(b.Build()); + auto computation = module().AddEntryComputation(b.Build()); std::unique_ptr result = evaluator_->Evaluate(*computation, {}).ConsumeValueOrDie(); @@ -1596,5 +1555,50 @@ TEST_F(HloEvaluatorTest, Reverse) { LiteralTestUtil::ExpectEqual(*expected, *result); } +TEST_F(HloEvaluatorTest, EvaluateWithSubstitutions) { + HloComputation::Builder b(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4}); + + HloInstruction* param0 = + b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param0")); + HloInstruction* square = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kMultiply, param0, param0)); + HloInstruction* add = b.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param0, square)); + + // Evaluate add with param0 = {1, 2, 3, 4}, square = {10, 20, 30, 40}. + HloEvaluator evaluator; + auto result = evaluator.EvaluateWithSubstitutions( + add, {{param0, Literal::CreateR1({1, 2, 3, 4}).get()}, + {square, Literal::CreateR1({10, 20, 30, 40}).get()}}); + TF_ASSERT_OK(result.status()); + LiteralTestUtil::ExpectEqual(*Literal::CreateR1({11, 22, 33, 44}), + *result.ValueOrDie()); +} + +// Check that EvaluateWithSubstitutions works if one of the operands to the op +// we're evaluating is a constant. +TEST_F(HloEvaluatorTest, EvaluateWithSubstitutionsWithConstantOperand) { + HloComputation::Builder b(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4}); + + HloInstruction* param0 = + b.AddInstruction(HloInstruction::CreateParameter(0, shape, "param0")); + HloInstruction* square = b.AddInstruction(HloInstruction::CreateBinary( + shape, HloOpcode::kMultiply, param0, param0)); + HloInstruction* constant = b.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + HloInstruction* add = b.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, constant, square)); + + // Evaluate add with square = {10, 20, 30, 40}. + HloEvaluator evaluator; + auto result = evaluator.EvaluateWithSubstitutions( + add, {{square, Literal::CreateR1({10, 20, 30, 40}).get()}}); + TF_ASSERT_OK(result.status()); + LiteralTestUtil::ExpectEqual(*Literal::CreateR1({11, 22, 33, 44}), + *result.ValueOrDie()); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index eaeb352183bdf6cc7f4a164c31af4f641e37440e..bf19bc9309b95f09fc5a36daf3e150f5191d1b8e 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -27,13 +27,13 @@ limitations under the License. namespace xla { -void HloExecutionProfile::AddProfileResult(const HloInstruction* hlo, +void HloExecutionProfile::SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken) { hlo_to_cycles_taken_[hlo] = cycles_taken; profiled_computations_.insert(hlo->parent()); } -uint64 HloExecutionProfile::GetProfileResult(const HloInstruction& hlo) const { +uint64 HloExecutionProfile::GetCyclesTakenBy(const HloInstruction& hlo) const { auto iter = hlo_to_cycles_taken_.find(&hlo); if (iter == hlo_to_cycles_taken_.end()) { return 0; diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.h b/tensorflow/compiler/xla/service/hlo_execution_profile.h index a980c1617f395fc6668b8f8739e04d18fd1b689e..cdce77cff427da376109db77c65ec70364e36140 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.h +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.h @@ -36,11 +36,11 @@ class HloExecutionProfile { using DeviceDescription = perftools::gputools::DeviceDescription; // Record how many cycles this HLO took to execute. - void AddProfileResult(const HloInstruction* hlo, uint64 cycles_taken); + void SetCyclesTakenBy(const HloInstruction* hlo, uint64 cycles_taken); // Returns how many cycles this HLO took to execute. Profiling information // may not be available for some instructions in which case zero is returned. - uint64 GetProfileResult(const HloInstruction& hlo) const; + uint64 GetCyclesTakenBy(const HloInstruction& hlo) const; // Return the number of cycles this computation took to execute. uint64 total_cycles_executed(const HloComputation& computation) const { diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index cf1ae07ee4cfe3b02d99f064e8b2b298c5afe267..fd162622ce2a56bcfbcd4fa1c56d5afc56249a8f 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -231,9 +231,9 @@ string HtmlLikeStringSanitize(tensorflow::StringPiece s) { // commutative, we also support them with param0 and param1 swapped. // // This is useful primarily for reduce and map nodes. These take a -// subcomputation which is almost always one of the four above, and pattern -// matching it to a short string lets us tell the user what the subcomputation -// is without drawing it as a graph. +// subcomputation which is almost always one of the above, and pattern matching +// it to a short string lets us tell the user what the subcomputation is without +// drawing it as a graph. optional MatchTrivialComputation(const HloComputation* computation) { if (computation->instruction_count() != 3) { return nullopt; @@ -342,6 +342,11 @@ class HloDotDumper { bool ShouldShowSubcomputation(const HloComputation* subcomp); bool ShouldShowFusionSubcomputation(const HloInstruction* instr); + + // We omit some nodes from the graph, instead drawing them inlined into the + // nodes that use them. + bool ShouldMergeIntoUsers(const HloInstruction* instr) const; + string DumpSubcomputation(const HloComputation* subcomp, const HloInstruction* parent_instr); string DumpComputation(const HloComputation* comp); @@ -352,9 +357,24 @@ class HloDotDumper { string GetInstructionNodeLabel(const HloInstruction* instr); string GetInstructionNodeMetadata(const HloInstruction* instr); string GetInstructionNodeExtraInfo(const HloInstruction* instr); - string GetInstructionNodeInlinedConstants(const HloInstruction* instr); + string GetInstructionNodeInlinedOperands(const HloInstruction* instr); void AddInstructionIncomingEdges(const HloInstruction* instr); + // For most instructions, GetNodeForEdge(instr) returns instr. + // + // The exception is fusion nodes. For these, we walk up the chain of nested + // fusion nodes starting at instr until we reach a node that either (a) isn't + // a fusion node, or (b) is a fusion node for which + // ShouldShowFusionSubcomputation is false. + // + // We do this because fusion nodes are expanded inline -- if + // ShouldShowFusionSubcomputation is true, the fusion node won't be present in + // the graph. + // + // In general when you want to draw an edge from A to B, you should actually + // draw an edge from GetNodeForEdge(A) to GetNodeForEdge(B). + const HloInstruction* GetNodeForEdge(const HloInstruction* instr); + // If instr has just one computation and it's trivial (e.g. "return param0 + // param1"), returns a string you can put into the node's body that names the // subcomputation, e.g. "Subcomputation: add". @@ -537,11 +557,9 @@ bool HloDotDumper::ShouldShowSubcomputation(const HloComputation* subcomp) { } // Show the subcomputation if we're showing any of its members. - return std::any_of(computation_->instructions().begin(), - computation_->instructions().end(), - [&](const std::unique_ptr& instr) { - return filter_.Show(instr.get()); - }); + return std::any_of( + computation_->instructions().begin(), computation_->instructions().end(), + [&](const HloInstruction* instr) { return filter_.Show(instr); }); } string HloDotDumper::DumpSubcomputation(const HloComputation* subcomp, @@ -592,16 +610,15 @@ tooltip = " "; // belongs to a fusion node, it's drawn in place of the fusion instruction, // so there's no need to link those. if (parent_instr->opcode() != HloOpcode::kFusion) { - VLOG(2) << "Edge: from " << subcomp->root_instruction()->name() << " to " - << parent_instr->name() << " as " << next_edge_id_; - edge_ids_.insert( - {{subcomp->root_instruction(), parent_instr}, next_edge_id_++}); + const HloInstruction* from = GetNodeForEdge(subcomp->root_instruction()); + VLOG(2) << "Edge: from " << from->name() << " to " << parent_instr->name() + << " as " << next_edge_id_; + edge_ids_.insert({{from, parent_instr}, next_edge_id_++}); const char* edge_fmt = R"(%s -> %s [ltail="%s", style="dashed" tooltip="%s -> %s"];)"; - edges_.push_back( - Printf(edge_fmt, InstructionId(subcomp->root_instruction()), - InstructionId(parent_instr), SubcomputationId(subcomp), - subcomp->name(), parent_instr->name())); + edges_.push_back(Printf( + edge_fmt, InstructionId(from), InstructionId(parent_instr), + SubcomputationId(subcomp), subcomp->name(), parent_instr->name())); } string computation = @@ -612,33 +629,25 @@ tooltip = " "; string HloDotDumper::DumpComputation(const HloComputation* comp) { string g; - for (const auto& instr : comp->instructions()) { - if (!filter_.Show(instr.get())) { + for (const auto* instr : comp->instructions()) { + if (!filter_.Show(instr)) { continue; } // Dump subcomputations within instr. for (const HloComputation* subcomp : instr->called_computations()) { if (ShouldShowSubcomputation(subcomp)) { - StrAppend(&g, DumpSubcomputation(subcomp, instr.get())); + StrAppend(&g, DumpSubcomputation(subcomp, instr)); } } - StrAppend(&g, DumpInstruction(instr.get())); + StrAppend(&g, DumpInstruction(instr)); } return g; } string HloDotDumper::DumpRootTag() { - HloInstruction* from = computation_->root_instruction(); - - // Fusion nodes are expanded inline, so if root is an expanded fusion node, - // walk up the graph until we find a node that isn't. - while (from->opcode() == HloOpcode::kFusion && - ShouldShowFusionSubcomputation(from)) { - from = from->fused_expression_root(); - } - + const HloInstruction* from = GetNodeForEdge(computation_->root_instruction()); auto from_id = InstructionId(from); if (!filter_.Show(from)) { @@ -670,12 +679,42 @@ string HloDotDumper::DumpRootTag() { to_id, node_body, node_shape, NodeColorAttributes(color)); } +bool HloDotDumper::ShouldMergeIntoUsers(const HloInstruction* instr) const { + // If a node: + // + // - is a tuple-shaped parameter, + // - is not a parameter to a fusion node, + // - has at least kMinUsersToOmit users shown, and + // - all of the shown users are get-tuple-elements, + // + // then we omit it from the graph, merging it with its users. + // + // This helps us handle the common case where a while loop body has one big + // tuple-shaped parameter. + const int kMinUsersToOmit = 3; + return instr->opcode() == HloOpcode::kParameter && + ShapeUtil::IsTuple(instr->shape()) && !instr->IsFused() && + std::count_if(instr->users().begin(), instr->users().end(), + [&](const HloInstruction* user) { + return filter_.Show(user); + }) > kMinUsersToOmit && + std::all_of(instr->users().begin(), instr->users().end(), + [&](const HloInstruction* user) { + return !filter_.Show(user) || + user->opcode() == HloOpcode::kGetTupleElement; + }); +} + string HloDotDumper::DumpInstruction(const HloInstruction* instr) { // We don't display constants as separate nodes; they're merged into their // users. if (instr->opcode() == HloOpcode::kConstant) { return ""; } + // Skip this node if it's merged into its users. + if (ShouldMergeIntoUsers(instr)) { + return ""; + } // Omit the fusion node if its subcomputation is drawn, since the // subcomputation will be drawn inline. if (instr->opcode() == HloOpcode::kFusion && @@ -691,7 +730,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { string node_label = GetInstructionNodeLabel(instr); string node_metadata = GetInstructionNodeMetadata(instr); string extra_info = GetInstructionNodeExtraInfo(instr); - string inlined_constants = GetInstructionNodeInlinedConstants(instr); + string inlined_constants = GetInstructionNodeInlinedOperands(instr); string trivial_subcomputation = GetInstructionTrivialComputationStr(instr); AddInstructionIncomingEdges(instr); @@ -719,7 +758,7 @@ string HloDotDumper::DumpInstruction(const HloInstruction* instr) { NodeColorAttributes(color)); } -string HloDotDumper::GetInstructionNodeInlinedConstants( +string HloDotDumper::GetInstructionNodeInlinedOperands( const HloInstruction* instr) { auto stringify_constant = [](const HloInstruction* constant) { if (ShapeUtil::IsEffectiveScalar(constant->shape())) { @@ -728,10 +767,14 @@ string HloDotDumper::GetInstructionNodeInlinedConstants( return Printf("%s (%s)", constant->literal().GetAsString(elem_idx), ShapeUtil::HumanString(constant->shape())); } + string constant_name; if (tensorflow::StringPiece(constant->name()).starts_with("%constant")) { - return constant->name(); + constant_name = constant->name(); + } else { + constant_name = StrCat("constant ", constant->name()); } - return StrCat("constant ", constant->name()); + return Printf("%s %s", constant_name, + ShapeUtil::HumanString(constant->shape())); }; // Special case: If instr is a parameter to a fusion node, check whether the @@ -748,16 +791,44 @@ string HloDotDumper::GetInstructionNodeInlinedConstants( std::vector lines; for (int64 i = 0; i < instr->operand_count(); ++i) { const HloInstruction* operand = instr->operand(i); - if (operand->opcode() != HloOpcode::kConstant) { - continue; + optional operand_str; + if (operand->opcode() == HloOpcode::kConstant) { + operand_str = stringify_constant(operand); + } else if (ShouldMergeIntoUsers(operand)) { + // Special case: If the operand is a parameter, use its parameter number + // rather than its name, because that's generally how people think of the + // node. + if (operand->opcode() == HloOpcode::kParameter) { + operand_str = Printf("Parameter %lld", operand->parameter_number()); + } else { + operand_str = operand->name(); + } + } + + if (operand_str) { + if (instr->operand_count() > 1) { + lines.push_back(Printf("operand %lld = %s", i, *operand_str)); + } else { + lines.push_back(Printf("operand = %s", *operand_str)); + } } - lines.push_back( - Printf("operand %lld = %s", i, stringify_constant(operand))); } return Join(lines, "
"); } ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { + const auto kParameterColor = kOrange; + + // Special case: If this instruction has a parameter merged into it, paint it + // the same color as a parameter. + if (std::any_of(instr->operands().begin(), instr->operands().end(), + [&](const HloInstruction* operand) { + return operand->opcode() == HloOpcode::kParameter && + ShouldMergeIntoUsers(operand); + })) { + return kParameterColor; + } + // Pick different colors or shapes for instructions which are particularly // expensive (eg, dot) and those which are unusual in some way or unique // (eg, parameter). @@ -765,8 +836,10 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kAbs: case HloOpcode::kRoundNearestAfz: case HloOpcode::kAdd: + case HloOpcode::kAtan2: case HloOpcode::kCeil: case HloOpcode::kClamp: + case HloOpcode::kComplex: case HloOpcode::kConvert: case HloOpcode::kCos: case HloOpcode::kDivide: @@ -775,13 +848,13 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: - case HloOpcode::kIndex: + case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalNot: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kNot: + case HloOpcode::kOr: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -789,8 +862,11 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kNe: case HloOpcode::kNegate: case HloOpcode::kPower: + case HloOpcode::kReal: case HloOpcode::kRemainder: - case HloOpcode::kSelect: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSlice: @@ -798,22 +874,46 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kSubtract: case HloOpcode::kTanh: case HloOpcode::kRng: - case HloOpcode::kBroadcast: - case HloOpcode::kTranspose: + // De-emphasize scalar-shaped elementwise ops -- they're generally + // uninteresting. + if (ShapeUtil::IsEffectiveScalar(instr->shape())) { + return kWhite; + } return kYellow; case HloOpcode::kBitcast: case HloOpcode::kTuple: case HloOpcode::kTrace: case HloOpcode::kGetTupleElement: return kWhite; + case HloOpcode::kBroadcast: + // De-emphasize nodes which broadcast a scalar within a fusion node -- + // these are essentially free. + if (instr->IsFused() && + ShapeUtil::IsEffectiveScalar(instr->operand(0)->shape())) { + return kWhite; + } + return kGreen; case HloOpcode::kConcatenate: case HloOpcode::kCopy: case HloOpcode::kDynamicSlice: - case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kPad: case HloOpcode::kReshape: case HloOpcode::kReverse: - case HloOpcode::kUpdate: + case HloOpcode::kSelect: + case HloOpcode::kTranspose: + // De-emphasize scalar-shaped data movement ops and all data movement ops + // inside fusion nodes, both of which are essentially free. + if (ShapeUtil::IsEffectiveScalar(instr->shape()) || instr->IsFused()) { + return kWhite; + } + return kGreen; + case HloOpcode::kDynamicUpdateSlice: + // Unlike the data-movement ops above, dynamic-update-slice is not ~free + // inside of fusion nodes, so we de-emphasize it only if it's + // scalar-shaped. + if (ShapeUtil::IsEffectiveScalar(instr->shape())) { + return kWhite; + } return kGreen; case HloOpcode::kConvolution: case HloOpcode::kDot: @@ -821,7 +921,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kReducePrecision: return kRed; case HloOpcode::kParameter: - return kOrange; + return kParameterColor; case HloOpcode::kBatchNormTraining: case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: @@ -926,6 +1026,9 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { [](int64 stride) { return stride == 1; }) ? "" : StrCat("stride=", VectorString(instr->slice_strides())); + case HloOpcode::kSend: + case HloOpcode::kRecv: + return StrCat("channel_id=", instr->channel_id()); default: return ""; } @@ -935,7 +1038,9 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { if (!opcode_specific_info.empty()) { lines.push_back(opcode_specific_info); } - + if (instr->has_sharding()) { + lines.push_back(StrCat("sharding=", instr->sharding().ToString())); + } // Show the shape and layout of the instruction, unless it's an inlined fusion // node -- there the shape and layout is present in the output node. if (instr->opcode() != HloOpcode::kFusion || @@ -965,7 +1070,7 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { lines.push_back(Printf("[%p]", instr)); } if (profile_ != nullptr) { - double hlo_cycles_executed = profile_->GetProfileResult(*instr); + double hlo_cycles_executed = profile_->GetCyclesTakenBy(*instr); double total_cycles_executed = profile_->total_cycles_executed(*instr->parent()); if (hlo_cycles_executed > 0 && total_cycles_executed > 0) { @@ -980,14 +1085,10 @@ string HloDotDumper::GetInstructionNodeExtraInfo(const HloInstruction* instr) { void HloDotDumper::AddInstructionIncomingEdges(const HloInstruction* instr) { auto add_edge = [&](const HloInstruction* from, const HloInstruction* to, int64 operand_num, bool control_edge = false) { - // Fusion nodes' subcomputations are displayed inline, so if 'from' is a - // fusion node and the node's subcomputation is shown, we draw our edge - // starting at the fusion node's root instead of at the fusion node itself. - if (from->opcode() == HloOpcode::kFusion && - ShouldShowFusionSubcomputation(from)) { - from = from->fused_expression_root(); - } - if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant) { + from = GetNodeForEdge(from); + + if (!filter_.Show(from) || from->opcode() == HloOpcode::kConstant || + ShouldMergeIntoUsers(from)) { return; } VLOG(2) << "Adding edge from " << from->name() << " to " << to->name() @@ -1053,6 +1154,15 @@ string HloDotDumper::GetInstructionTrivialComputationStr( return Join(lines, "
"); } +const HloInstruction* HloDotDumper::GetNodeForEdge( + const HloInstruction* instr) { + while (instr->opcode() == HloOpcode::kFusion && + ShouldShowFusionSubcomputation(instr)) { + instr = instr->fused_expression_root(); + } + return instr; +} + tensorflow::mutex& RendererMutex() { static tensorflow::mutex* mu = new tensorflow::mutex; return *mu; @@ -1281,7 +1391,8 @@ void DumpText(const HloModule& module, const string& label, string filename = do_prefix ? StrCat(prefix, "-", label, ".txt") : StrCat(label, ".txt"); string path = JoinPath(directory_path, filename); - TF_CHECK_OK(WriteStringToFile(env, path, module.ToString())); + TF_CHECK_OK(WriteStringToFile( + env, path, module.ToString(/*include_large_constants=*/true))); LOG(INFO) << "dumping module '" << module.name() << "' to " << path; } diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc index 4015ee6cace2a1e66110e1db98fea547e57939be..7b0f937f383a416f805a799bd6787afe15b324b0 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper_test.cc @@ -95,8 +95,7 @@ TEST(HloGraphDumperTest, NestedFusion) { {root_computation, // inner_fusion->fused_instructions_computation(), outer_fusion->fused_instructions_computation()}) { - for (const std::unique_ptr& instruction : - computation->instructions()) { + for (const HloInstruction* instruction : computation->instructions()) { EXPECT_THAT(graph, HasSubstr(instruction->name())); } } @@ -105,10 +104,10 @@ TEST(HloGraphDumperTest, NestedFusion) { // care that the outer nodes are omitted -- whether they are or not is based // fiddly heuristics -- but we do care that the node we asked for is printed. const HloInstruction* inner_sum = nullptr; - for (const std::unique_ptr& instruction : + for (const HloInstruction* instruction : inner_fusion->fused_instructions_computation()->instructions()) { if (instruction->opcode() == HloOpcode::kAdd) { - inner_sum = instruction.get(); + inner_sum = instruction; break; } } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 4f2cf1c2b88cb2b1ae264d74d15c2541dab15db5..5107ac782d7c93dfa17969338bf97c9fd9bb1516 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" +#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" @@ -47,6 +47,101 @@ using ::tensorflow::str_util::Join; using ::tensorflow::strings::StrAppend; using ::tensorflow::strings::StrCat; +/* static */ +StatusOr> HloInstruction::CreateFromProto( + HloModule* module, const HloInstructionProto& proto, + const tensorflow::gtl::FlatMap& instruction_map, + tensorflow::gtl::FlatMap* computation_map) { + TF_RET_CHECK(!proto.opcode().empty()); + TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode())); + TF_RET_CHECK(proto.has_shape()); + + auto instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); + for (const string& operand_name : proto.operand_names()) { + TF_RET_CHECK(ContainsKey(instruction_map, operand_name)) + << "No instruction named " << operand_name; + instruction->AppendOperand(instruction_map.at(operand_name)); + } + for (const string& predecessor_name : proto.control_predecessor_names()) { + TF_RET_CHECK(ContainsKey(instruction_map, predecessor_name)) + << "No instruction named " << predecessor_name; + TF_RETURN_IF_ERROR(instruction_map.at(predecessor_name) + ->AddControlDependencyTo(instruction.get())); + } + + // In the proto, fused computations are held exclusively within the + // HloInstructionProto and do not appear as an HloComputationProto within the + // HloModuleProto. + if (instruction->opcode() == HloOpcode::kFusion) { + TF_RET_CHECK(proto.has_fused_instructions_computation()); + TF_RET_CHECK(!proto.fusion_kind().empty()); + TF_ASSIGN_OR_RETURN(instruction->fusion_kind_, + StringToFusionKind(proto.fusion_kind())); + TF_ASSIGN_OR_RETURN( + std::unique_ptr fused_computation, + HloComputation::CreateFromProto( + module, proto.fused_instructions_computation(), computation_map, + /*fusion_instruction=*/instruction.get())); + instruction->called_computations_.push_back( + module->AddEmbeddedComputation(std::move(fused_computation))); + } else { + for (const string& computation_name : proto.called_computation_names()) { + TF_RET_CHECK(ContainsKey(*computation_map, computation_name)) + << "No computation named " << computation_name; + instruction->called_computations_.push_back( + computation_map->at(computation_name)); + } + } + + TF_RET_CHECK(!proto.name().empty()); + instruction->name_ = proto.name(); + + instruction->metadata_ = proto.metadata(); + if (proto.has_literal()) { + instruction->literal_ = MakeUnique(proto.literal()); + } + instruction->parameter_number_ = proto.parameter_number(); + instruction->parameter_name_ = proto.parameter_name(); + + instruction->tuple_index_ = proto.tuple_index(); + for (int64 dimension : proto.dimensions()) { + instruction->dimensions_.push_back(dimension); + } + if (proto.has_window()) { + instruction->window_ = MakeUnique(proto.window()); + } + if (proto.has_convolution_dimension_numbers()) { + instruction->convolution_dimension_numbers_ = + MakeUnique( + proto.convolution_dimension_numbers()); + } + for (const HloInstructionProto::SliceDimensions& slice_dimensions : + proto.slice_dimensions()) { + instruction->slice_starts_.push_back(slice_dimensions.start()); + instruction->slice_limits_.push_back(slice_dimensions.limit()); + instruction->slice_strides_.push_back(slice_dimensions.stride()); + } + instruction->exponent_bits_ = proto.exponent_bits(); + instruction->mantissa_bits_ = proto.mantissa_bits(); + for (int64 dynamic_slice_size : proto.dynamic_slice_sizes()) { + instruction->dynamic_slice_sizes_.push_back(dynamic_slice_size); + } + if (proto.has_padding_config()) { + instruction->padding_config_ = + MakeUnique(proto.padding_config()); + } + instruction->outfeed_config_ = proto.outfeed_config(); + instruction->distribution_ = proto.distribution(); + instruction->epsilon_ = proto.epsilon(); + instruction->feature_index_ = proto.feature_index(); + instruction->channel_id_ = proto.channel_id(); + instruction->infeed_config_ = proto.infeed_config(); + instruction->custom_call_target_ = proto.custom_call_target(); + instruction->outfeed_shape_ = proto.outfeed_shape(); + + return std::move(instruction); +} + /* static */ std::unique_ptr HloInstruction::CreateParameter( int64 parameter_number, const Shape& shape, const string& name) { auto instruction = @@ -124,10 +219,12 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: + case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: - case HloOpcode::kLogicalNot: + case HloOpcode::kNot: case HloOpcode::kNegate: + case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSort: @@ -146,23 +243,28 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, // Only certain opcodes are supported with CreateBinary: opcodes of binary // instructions with no auxiliary fields. switch (opcode) { - case (HloOpcode::kAdd): - case (HloOpcode::kDivide): - case (HloOpcode::kDot): - case (HloOpcode::kEq): - case (HloOpcode::kGe): - case (HloOpcode::kGt): - case (HloOpcode::kLe): - case (HloOpcode::kLt): - case (HloOpcode::kMaximum): - case (HloOpcode::kMinimum): - case (HloOpcode::kMultiply): - case (HloOpcode::kNe): - case (HloOpcode::kPower): - case (HloOpcode::kRemainder): - case (HloOpcode::kSubtract): - case (HloOpcode::kLogicalAnd): - case (HloOpcode::kLogicalOr): + case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kDivide: + case HloOpcode::kComplex: + case HloOpcode::kDot: + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kMultiply: + case HloOpcode::kNe: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kSubtract: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: break; default: LOG(FATAL) << "Invalid binary instruction opcode " @@ -618,10 +720,12 @@ void HloInstruction::MergeFusionInstructionIntoMultiOutput( // Fuse the root instruction and generate multiple outputs. FuseInstructionIntoMultiOutput(unfused_root); + TF_CHECK_OK(unfused_root->parent()->RemoveInstruction(unfused_root)); // The rest instructions are of normal fusing. for (int64 i = 1; i < unfused_instructions.size(); i++) { auto instruction = unfused_instructions[i]; FuseInstruction(instruction); + TF_CHECK_OK(instruction->parent()->RemoveInstruction(instruction)); } } @@ -857,13 +961,16 @@ bool HloInstruction::HasSideEffect() const { std::unique_ptr HloInstruction::CloneWithNewOperands( const Shape& shape, - tensorflow::gtl::ArraySlice new_operands) { + tensorflow::gtl::ArraySlice new_operands, + HloModule* module) const { VLOG(3) << "CloneWithNewOperands:\n " << ToString(); VLOG(3) << " new operands:"; for (const HloInstruction* new_operand : new_operands) { VLOG(3) << " " << new_operand->name(); } + std::unique_ptr clone; + // Explicitly call the factory for the instruction type. This is more robust // in the face of code changes than copying fields explicitly. This also // properly sets the user fields of the operands. @@ -876,19 +983,24 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kExp: + case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kFloor: case HloOpcode::kLog: - case HloOpcode::kLogicalNot: + case HloOpcode::kNot: case HloOpcode::kNegate: + case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSort: case HloOpcode::kTanh: CHECK_EQ(new_operands.size(), 1); - return CreateUnary(shape, opcode_, new_operands[0]); + clone = CreateUnary(shape, opcode_, new_operands[0]); + break; // Binary ops. case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kComplex: case HloOpcode::kDivide: case HloOpcode::kMultiply: case HloOpcode::kSubtract: @@ -903,132 +1015,173 @@ std::unique_ptr HloInstruction::CloneWithNewOperands( case HloOpcode::kMinimum: case HloOpcode::kPower: case HloOpcode::kRemainder: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: CHECK_EQ(new_operands.size(), 2); - return CreateBinary(shape, opcode_, new_operands[0], new_operands[1]); + clone = CreateBinary(shape, opcode_, new_operands[0], new_operands[1]); + break; // Ternary ops. case HloOpcode::kClamp: case HloOpcode::kSelect: CHECK_EQ(new_operands.size(), 3); - return CreateTernary(shape, opcode_, new_operands[0], new_operands[1], - new_operands[2]); + clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1], + new_operands[2]); + break; // Other supported ops. case HloOpcode::kBroadcast: CHECK_EQ(new_operands.size(), 1); - return CreateBroadcast(shape, new_operands[0], dimensions_); + clone = CreateBroadcast(shape, new_operands[0], dimensions_); + break; case HloOpcode::kCall: - return CreateCall(shape, new_operands, to_apply()); + clone = CreateCall(shape, new_operands, to_apply()); + break; case HloOpcode::kCustomCall: - return CreateCustomCall(shape, new_operands, custom_call_target_); + clone = CreateCustomCall(shape, new_operands, custom_call_target_); + break; case HloOpcode::kConcatenate: - return CreateConcatenate(shape, new_operands, dimensions(0)); + clone = CreateConcatenate(shape, new_operands, dimensions(0)); + break; case HloOpcode::kConvert: CHECK_EQ(new_operands.size(), 1); - return CreateConvert(shape, new_operands[0]); + clone = CreateConvert(shape, new_operands[0]); + break; case HloOpcode::kReducePrecision: CHECK_EQ(new_operands.size(), 1); - return CreateReducePrecision(shape, new_operands[0], exponent_bits_, - mantissa_bits_); + clone = CreateReducePrecision(shape, new_operands[0], exponent_bits_, + mantissa_bits_); + break; case HloOpcode::kConvolution: CHECK_EQ(new_operands.size(), 2); - return CreateConvolve(shape, new_operands[0], new_operands[1], *window_, - *convolution_dimension_numbers_); + clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_, + *convolution_dimension_numbers_); + break; case HloOpcode::kCrossReplicaSum: CHECK_EQ(new_operands.size(), 1); - return CreateCrossReplicaSum(shape, new_operands[0]); + clone = CreateCrossReplicaSum(shape, new_operands[0]); + break; case HloOpcode::kGetTupleElement: CHECK_EQ(new_operands.size(), 1); - return CreateGetTupleElement(shape, new_operands[0], tuple_index()); + clone = CreateGetTupleElement(shape, new_operands[0], tuple_index()); + break; case HloOpcode::kMap: - return CreateMap(shape, new_operands, to_apply()); + clone = CreateMap(shape, new_operands, to_apply()); + break; case HloOpcode::kPad: CHECK_EQ(new_operands.size(), 2); - return CreatePad(shape, new_operands[0], new_operands[1], - *padding_config_); + clone = + CreatePad(shape, new_operands[0], new_operands[1], *padding_config_); + break; case HloOpcode::kReduce: CHECK_EQ(new_operands.size(), 2); - return CreateReduce(shape, new_operands[0], new_operands[1], dimensions_, - to_apply()); + clone = CreateReduce(shape, new_operands[0], new_operands[1], dimensions_, + to_apply()); + break; case HloOpcode::kReduceWindow: CHECK_EQ(new_operands.size(), 2); - return CreateReduceWindow(shape, new_operands[0], new_operands[1], - *window_, to_apply()); + clone = CreateReduceWindow(shape, new_operands[0], new_operands[1], + *window_, to_apply()); + break; case HloOpcode::kSelectAndScatter: CHECK_EQ(new_operands.size(), 3); - return CreateSelectAndScatter(shape, new_operands[0], select(), *window_, - new_operands[1], new_operands[2], - scatter()); + clone = + CreateSelectAndScatter(shape, new_operands[0], select(), *window_, + new_operands[1], new_operands[2], scatter()); + break; case HloOpcode::kReverse: CHECK_EQ(new_operands.size(), 1); - return CreateReverse(shape, new_operands[0], dimensions_); + clone = CreateReverse(shape, new_operands[0], dimensions_); + break; case HloOpcode::kRng: - return CreateRng(shape, distribution_, new_operands); + clone = CreateRng(shape, distribution_, new_operands); + break; case HloOpcode::kReshape: CHECK_EQ(new_operands.size(), 1); - return CreateReshape(shape, new_operands[0]); + clone = CreateReshape(shape, new_operands[0]); + break; case HloOpcode::kSlice: CHECK_EQ(new_operands.size(), 1); - return CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_, - slice_strides_); + clone = CreateSlice(shape, new_operands[0], slice_starts_, slice_limits_, + slice_strides_); + break; case HloOpcode::kDynamicSlice: - return CreateDynamicSlice(shape, new_operands[0], new_operands[1], - dynamic_slice_sizes_); + clone = CreateDynamicSlice(shape, new_operands[0], new_operands[1], + dynamic_slice_sizes_); + break; case HloOpcode::kDynamicUpdateSlice: CHECK_EQ(new_operands.size(), 3); - return CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], - new_operands[2]); + clone = CreateDynamicUpdateSlice(shape, new_operands[0], new_operands[1], + new_operands[2]); + break; case HloOpcode::kTranspose: CHECK_EQ(new_operands.size(), 1); - return CreateTranspose(shape, new_operands[0], dimensions_); + clone = CreateTranspose(shape, new_operands[0], dimensions_); + break; case HloOpcode::kTuple: - return CreateTuple(new_operands); + clone = CreateTuple(new_operands); + *clone->mutable_shape() = shape; + break; case HloOpcode::kWhile: CHECK_EQ(new_operands.size(), 1); - return CreateWhile(shape, while_condition(), while_body(), - new_operands[0]); + clone = + CreateWhile(shape, while_condition(), while_body(), new_operands[0]); + break; case HloOpcode::kConstant: - return CreateConstant(literal_->CloneToUnique()); + clone = CreateConstant(literal_->CloneToUnique()); + break; case HloOpcode::kFusion: - return CloneFusionWithNewOperands(shape, new_operands); + clone = CloneFusionWithNewOperands(shape, new_operands, module); + break; case HloOpcode::kParameter: - return CreateParameter(parameter_number_, shape, parameter_name_); + clone = CreateParameter(parameter_number_, shape, parameter_name_); + break; case HloOpcode::kBatchNormTraining: CHECK_EQ(new_operands.size(), 3); - return CreateBatchNormTraining(shape, new_operands[0], new_operands[1], - new_operands[2], epsilon(), - feature_index()); - + clone = + CreateBatchNormTraining(shape, new_operands[0], new_operands[1], + new_operands[2], epsilon(), feature_index()); + break; case HloOpcode::kBatchNormInference: CHECK_EQ(new_operands.size(), 5); - return CreateBatchNormInference( + clone = CreateBatchNormInference( shape, new_operands[0], new_operands[1], new_operands[2], new_operands[3], new_operands[4], epsilon(), feature_index()); + break; case HloOpcode::kInfeed: CHECK_EQ(new_operands.size(), 0); - return CreateInfeed(shape, infeed_config()); + clone = CreateInfeed(shape, infeed_config()); + break; case HloOpcode::kOutfeed: CHECK_EQ(new_operands.size(), 1); - return CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config()); + clone = CreateOutfeed(outfeed_shape_, new_operands[0], outfeed_config()); + break; case HloOpcode::kBatchNormGrad: CHECK_EQ(new_operands.size(), 5); - return CreateBatchNormGrad(shape, new_operands[0], new_operands[1], - new_operands[2], new_operands[3], - new_operands[4], epsilon(), feature_index()); + clone = CreateBatchNormGrad(shape, new_operands[0], new_operands[1], + new_operands[2], new_operands[3], + new_operands[4], epsilon(), feature_index()); + break; case HloOpcode::kRecv: case HloOpcode::kSend: - case HloOpcode::kUpdate: - case HloOpcode::kIndex: case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } + clone->set_metadata(metadata_); + if (has_sharding()) { + clone->set_sharding(sharding()); + } + clone->set_parent(parent_); + return clone; } HloInstruction::~HloInstruction() {} -std::unique_ptr HloInstruction::Clone(const string& suffix) { +std::unique_ptr HloInstruction::Clone(const string& suffix, + HloModule* module) const { std::unique_ptr clone = - CloneWithNewOperands(shape_, operands_); + CloneWithNewOperands(shape_, operands_, module); if (suffix.empty()) { clone->name_ = name(); } else { @@ -1062,13 +1215,12 @@ std::unique_ptr HloInstruction::Clone(const string& suffix) { } } } - clone->set_parent(parent()); - clone->set_metadata(metadata_); return clone; } std::unique_ptr HloInstruction::CloneFusionWithNewOperands( - const Shape& shape, tensorflow::gtl::ArraySlice operands) { + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloModule* module) const { CHECK_EQ(opcode_, HloOpcode::kFusion); CHECK(parent() != nullptr); @@ -1079,13 +1231,14 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( new_instruction->AppendOperand(new_operand); } // Clone all the fused instructions for the new fusion instruction. - std::map old_to_new; + HloInstructionMap old_to_new; std::list> new_fused_instructions; // Create the list of fused parameters by mapping through the cloned, // fused instructions. for (HloInstruction* old_fused_parameter : fused_instructions_computation()->parameter_instructions()) { - new_fused_instructions.push_back(old_fused_parameter->Clone()); + new_fused_instructions.push_back( + old_fused_parameter->Clone("clone", module)); HloInstruction* new_fusion_parameter = new_fused_instructions.back().get(); InsertOrDie(&old_to_new, old_fused_parameter, new_fusion_parameter); } @@ -1104,9 +1257,9 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( } new_fused_instructions.push_back( old_fused_instruction->CloneWithNewOperands( - old_fused_instruction->shape(), new_operands)); + old_fused_instruction->shape(), new_operands, module)); HloInstruction* new_fused_instruction = new_fused_instructions.back().get(); - new_fused_instruction->set_parent(parent()); + new_fused_instruction->set_parent(parent_); InsertOrDie(&old_to_new, old_fused_instruction, new_fused_instruction); } new_instruction->fusion_kind_ = fusion_kind_; @@ -1120,15 +1273,39 @@ std::unique_ptr HloInstruction::CloneFusionWithNewOperands( ++new_fused_instruction_iter) { computation_builder.AddInstruction(std::move(*new_fused_instruction_iter)); } + if (module == nullptr) { + module = GetModule(); + } auto fused_root_ = fused_expression_root(); new_instruction->called_computations_.push_back( - CHECK_NOTNULL(GetModule()) - ->AddEmbeddedComputation( - computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); - new_instruction->set_parent(parent()); + CHECK_NOTNULL(module)->AddEmbeddedComputation( + computation_builder.Build(FindOrDie(old_to_new, fused_root_)))); return new_instruction; } +std::pair +HloInstruction::LatestNonGteAncestorAndIndex() const { + const HloInstruction* hlo = this; + ShapeIndex index; + while (hlo->opcode() == HloOpcode::kGetTupleElement) { + index.push_back(hlo->tuple_index()); + hlo = hlo->operand(0); + } + + // We built up index in the reverse order from what we want. + std::reverse(index.begin(), index.end()); + + return {hlo, index}; +} + +const HloInstruction* HloInstruction::LatestNonGteAncestor() const { + const HloInstruction* hlo = this; + while (hlo->opcode() == HloOpcode::kGetTupleElement) { + hlo = hlo->operand(0); + } + return hlo; +} + const Literal& HloInstruction::literal() const { CHECK_EQ(HloOpcode::kConstant, opcode_); return *literal_; @@ -1239,10 +1416,12 @@ bool HloInstruction::IdenticalSlowPath( // The result of these instructions only depend upon their opcode and // operands. case HloOpcode::kAbs: + case HloOpcode::kAtan2: case HloOpcode::kRoundNearestAfz: case HloOpcode::kAdd: case HloOpcode::kCeil: case HloOpcode::kClamp: + case HloOpcode::kComplex: case HloOpcode::kCopy: case HloOpcode::kCos: case HloOpcode::kCrossReplicaSum: @@ -1253,12 +1432,13 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kFloor: case HloOpcode::kGe: case HloOpcode::kGt: + case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLe: case HloOpcode::kLog: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalNot: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kNot: + case HloOpcode::kOr: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -1266,8 +1446,12 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kNe: case HloOpcode::kNegate: case HloOpcode::kPower: + case HloOpcode::kReal: case HloOpcode::kRemainder: case HloOpcode::kSelect: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kSubtract: @@ -1354,7 +1538,8 @@ bool HloInstruction::IdenticalSlowPath( other.padding_config()); case HloOpcode::kSlice: return slice_starts_ == other.slice_starts_ && - slice_limits_ == other.slice_limits_; + slice_limits_ == other.slice_limits_ && + slice_strides_ == other.slice_strides_; case HloOpcode::kDynamicSlice: return ShapeUtil::Compatible(shape(), other.shape()) && dynamic_slice_sizes_ == other.dynamic_slice_sizes_; @@ -1369,11 +1554,9 @@ bool HloInstruction::IdenticalSlowPath( return dimensions() == other.dimensions(); // These opcodes are not yet supported. - case HloOpcode::kIndex: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kSort: - case HloOpcode::kUpdate: case HloOpcode::kSend: case HloOpcode::kRecv: return false; @@ -1461,6 +1644,9 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) { if (new_producer_is_user) { AddUser(new_producer); } + if (parent_ && parent_->root_instruction() == this) { + parent_->set_root_instruction(new_producer); + } return Status::OK(); } @@ -1592,11 +1778,12 @@ string HloInstruction::ExtendedOpcodeStr() const { return opc_name; } -string HloInstruction::ToString(bool compact_operands, - bool include_metadata) const { +string HloInstruction::ToString(bool compact_operands, bool include_metadata, + bool include_large_constants) const { string result = StrCat(name(), " = ", ShapeUtil::HumanStringWithLayout(shape()), " ", - ExtendedOpcodeStr(), "(", OperandsToString(compact_operands), ")"); + ExtendedOpcodeStr(), "(", + OperandsToString(compact_operands, include_large_constants), ")"); for (const string& extra : ExtraAttributesToString()) { StrAppend(&result, ", ", extra); } @@ -1608,11 +1795,14 @@ string HloInstruction::ToString(bool compact_operands, return result; } -string HloInstruction::OperandsToString(bool compact) const { +string HloInstruction::OperandsToString(bool compact, + bool include_large_constants) const { string operands; if (opcode() == HloOpcode::kConstant) { // For constants, show the actual value in place of an empty operand list. - if (!ShapeUtil::IsTuple(shape()) && ShapeUtil::ElementsIn(shape()) <= 10) { + if ((!ShapeUtil::IsTuple(shape()) && + ShapeUtil::ElementsIn(shape()) <= 10) || + include_large_constants) { // Literal::ToString emits multidimensional arrays over multiple // lines. Compact this into one line by stripping out white space. string tmp = literal().ToString(); @@ -1684,28 +1874,40 @@ std::vector HloInstruction::ExtraAttributesToString() const { } if (opcode() == HloOpcode::kWhile) { - extra.push_back(StrCat("condition=", while_condition()->name())); - extra.push_back(StrCat("body=", while_body()->name())); + extra.push_back(StrCat("condition=%", while_condition()->name())); + extra.push_back(StrCat("body=%", while_body()->name())); } else if (opcode() == HloOpcode::kSelectAndScatter) { - extra.push_back(StrCat("select=", select()->name())); - extra.push_back(StrCat("scatter=", scatter()->name())); + extra.push_back(StrCat("select=%", select()->name())); + extra.push_back(StrCat("scatter=%", scatter()->name())); + } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap || + opcode() == HloOpcode::kReduceWindow || + opcode() == HloOpcode::kReduce) { + extra.push_back(StrCat("to_apply=%", to_apply()->name())); } else if (!called_computations().empty()) { extra.push_back(StrCat( "calls=", Join(called_computations(), ", ", [](string* out, const HloComputation* computation) { - StrAppend(out, computation->name()); + StrAppend(out, "%", computation->name()); }))); } + if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv) { + extra.push_back(StrCat("channel_id=", channel_id_)); + } + if (opcode() == HloOpcode::kGetTupleElement) { extra.push_back(StrCat("index=", tuple_index())); } - if (!control_successors_.empty()) { - extra.push_back(StrCat( - "control-successors=", - Join(control_successors_, ", ", [](string* out, HloInstruction* succ) { - StrAppend(out, succ->name()); - }))); + if (has_sharding()) { + extra.push_back(StrCat("sharding=", sharding().ToString())); + } + if (!control_predecessors_.empty()) { + extra.push_back(StrCat("control-predecessors={", + Join(control_predecessors_, ", ", + [](string* out, HloInstruction* pre) { + StrAppend(out, pre->name()); + }), + "}")); } return extra; } @@ -1730,37 +1932,59 @@ HloInstructionProto HloInstruction::ToProto() const { for (const HloInstruction* control : control_predecessors_) { *proto.add_control_predecessor_names() = control->name(); } - for (const HloComputation* computation : called_computations_) { - *proto.add_called_computation_names() = computation->name(); - } + *proto.mutable_metadata() = metadata_; - switch (opcode_) { - case HloOpcode::kConstant: - *proto.mutable_literal() = literal_->ToProto(); - break; - case HloOpcode::kParameter: - proto.set_parameter_number(parameter_number_); - proto.set_parameter_name(parameter_name_); - break; - case HloOpcode::kFusion: { - HloComputationProto* proto_fused_computation = - proto.mutable_fused_instructions_computation(); - proto_fused_computation->set_name(name()); - - // Fill in fused instructions in post order. - auto fused_instructions = - fused_instructions_computation()->MakeInstructionPostOrder(); - for (auto fused_instruction : fused_instructions) { - HloInstructionProto fused_proto = fused_instruction->ToProto(); - proto_fused_computation->add_instructions()->Swap(&fused_proto); - } - break; + if (literal_ != nullptr) { + *proto.mutable_literal() = literal_->ToProto(); + } + proto.set_parameter_number(parameter_number_); + proto.set_parameter_name(parameter_name_); + if (opcode() == HloOpcode::kFusion) { + proto.set_fusion_kind(xla::ToString(fusion_kind())); + *proto.mutable_fused_instructions_computation() = + fused_instructions_computation()->ToProto(); + } else { + for (const HloComputation* computation : called_computations_) { + *proto.add_called_computation_names() = computation->name(); } - case HloOpcode::kGetTupleElement: - proto.set_tuple_index(tuple_index_); - break; - default: {} // Nothing to do } + + proto.set_tuple_index(tuple_index_); + for (int64 dimension : dimensions_) { + proto.add_dimensions(dimension); + } + if (window_ != nullptr) { + *proto.mutable_window() = *window_; + } + if (convolution_dimension_numbers_ != nullptr) { + *proto.mutable_convolution_dimension_numbers() = + *convolution_dimension_numbers_; + } + for (int i = 0; i < slice_starts_.size(); ++i) { + auto* slice_dimension = proto.add_slice_dimensions(); + slice_dimension->set_start(slice_starts_[i]); + slice_dimension->set_limit(slice_limits_[i]); + slice_dimension->set_stride(slice_strides_[i]); + } + proto.set_exponent_bits(exponent_bits_); + proto.set_mantissa_bits(mantissa_bits_); + for (int64 slice_size : dynamic_slice_sizes_) { + proto.add_dynamic_slice_sizes(slice_size); + } + if (padding_config_ != nullptr) { + *proto.mutable_padding_config() = *padding_config_; + } + proto.set_outfeed_config(outfeed_config_); + if (opcode() == HloOpcode::kRng) { + proto.set_distribution(distribution_); + } + proto.set_epsilon(epsilon_); + proto.set_feature_index(feature_index_); + proto.set_channel_id(channel_id_); + proto.set_infeed_config(infeed_config_); + proto.set_custom_call_target(custom_call_target_); + *proto.mutable_outfeed_shape() = outfeed_shape_; + return proto; } @@ -1884,12 +2108,25 @@ const std::vector& HloInstruction::fused_parameters() const { return fused_instructions_computation()->parameter_instructions(); } -const std::list>& +const tensorflow::gtl::iterator_range>::const_iterator>> HloInstruction::fused_instructions() const { + CHECK_EQ(opcode_, HloOpcode::kFusion); + const HloComputation* subcomp = fused_instructions_computation(); + return subcomp->instructions(); +} + +const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> +HloInstruction::fused_instructions() { CHECK_EQ(opcode_, HloOpcode::kFusion); return fused_instructions_computation()->instructions(); } +int64 HloInstruction::fused_instruction_count() const { + return fused_instructions_computation()->instruction_count(); +} + HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) : unique_id_(-1), opcode_(opcode), @@ -1898,10 +2135,13 @@ HloInstruction::HloInstruction(HloOpcode opcode, const Shape& shape) TF_DCHECK_OK(ShapeUtil::ValidateShapeWithOptionalLayout(shape_)); } -Status HloInstruction::Visit(DfsHloVisitor* visitor) { +template +Status HloInstruction::Visit(DfsHloVisitorBase* visitor) { switch (opcode_) { case HloOpcode::kAbs: - return visitor->HandleAbs(this, operands_[0]); + return visitor->HandleAbs(this); + case HloOpcode::kAtan2: + return visitor->HandleAtan2(this); case HloOpcode::kRoundNearestAfz: return visitor->HandleRound(this); case HloOpcode::kBatchNormTraining: @@ -1911,11 +2151,11 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kBatchNormGrad: return visitor->HandleBatchNormGrad(this); case HloOpcode::kSign: - return visitor->HandleSign(this, operands_[0]); + return visitor->HandleSign(this); case HloOpcode::kConstant: - return visitor->HandleConstant(this, *literal_); + return visitor->HandleConstant(this); case HloOpcode::kGetTupleElement: - return visitor->HandleGetTupleElement(this, operands_[0]); + return visitor->HandleGetTupleElement(this); case HloOpcode::kParameter: return visitor->HandleParameter(this); case HloOpcode::kEq: @@ -1924,78 +2164,85 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kLe: case HloOpcode::kLt: case HloOpcode::kNe: - return visitor->HandleCompare(this, opcode_, operands_[0], operands_[1]); + return visitor->HandleCompare(this); + case HloOpcode::kComplex: + return visitor->HandleComplex(this); case HloOpcode::kAdd: - return visitor->HandleAdd(this, operands_[0], operands_[1]); + return visitor->HandleAdd(this); case HloOpcode::kDivide: - return visitor->HandleDivide(this, operands_[0], operands_[1]); + return visitor->HandleDivide(this); case HloOpcode::kSubtract: - return visitor->HandleSubtract(this, operands_[0], operands_[1]); + return visitor->HandleSubtract(this); case HloOpcode::kMaximum: return visitor->HandleMaximum(this); case HloOpcode::kMinimum: return visitor->HandleMinimum(this); - case HloOpcode::kLogicalAnd: - return visitor->HandleLogicalAnd(this, operands_[0], operands_[1]); - case HloOpcode::kLogicalOr: - return visitor->HandleLogicalOr(this, operands_[0], operands_[1]); + case HloOpcode::kAnd: + return visitor->HandleAnd(this); + case HloOpcode::kOr: + return visitor->HandleOr(this); + case HloOpcode::kShiftLeft: + return visitor->HandleShiftLeft(this); + case HloOpcode::kShiftRightArithmetic: + return visitor->HandleShiftRightArithmetic(this); + case HloOpcode::kShiftRightLogical: + return visitor->HandleShiftRightLogical(this); case HloOpcode::kConcatenate: - return visitor->HandleConcatenate(this, operands_); + return visitor->HandleConcatenate(this); case HloOpcode::kConvert: return visitor->HandleConvert(this); case HloOpcode::kCopy: return visitor->HandleCopy(this); case HloOpcode::kMultiply: - return visitor->HandleMultiply(this, operands_[0], operands_[1]); + return visitor->HandleMultiply(this); case HloOpcode::kDot: - return visitor->HandleDot(this, operands_[0], operands_[1]); + return visitor->HandleDot(this); case HloOpcode::kPower: - return visitor->HandlePower(this, operands_[0], operands_[1]); + return visitor->HandlePower(this); case HloOpcode::kRemainder: - return visitor->HandleRemainder(this, operands_[0], operands_[1]); + return visitor->HandleRemainder(this); case HloOpcode::kSelect: - return visitor->HandleSelect(this, operands_[0], operands_[1], - operands_[2]); + return visitor->HandleSelect(this); case HloOpcode::kConvolution: - return visitor->HandleConvolution(this, operands_[0], operands_[1], - window()); + return visitor->HandleConvolution(this); case HloOpcode::kCrossReplicaSum: return visitor->HandleCrossReplicaSum(this); case HloOpcode::kTuple: - return visitor->HandleTuple(this, operands_); + return visitor->HandleTuple(this); case HloOpcode::kMap: - return visitor->HandleMap(this, operands_, to_apply(), {}); + return visitor->HandleMap(this); case HloOpcode::kClamp: - return visitor->HandleClamp(this, operands_[0], operands_[1], - operands_[2]); + return visitor->HandleClamp(this); case HloOpcode::kReduce: - return visitor->HandleReduce(this, operands_[0], operands_[1], - dimensions_, to_apply()); + return visitor->HandleReduce(this); case HloOpcode::kReduceWindow: - return visitor->HandleReduceWindow(this, operands_[0], window(), - to_apply()); + return visitor->HandleReduceWindow(this); case HloOpcode::kSelectAndScatter: return visitor->HandleSelectAndScatter(this); case HloOpcode::kNegate: - return visitor->HandleNegate(this, operands_[0]); + return visitor->HandleNegate(this); case HloOpcode::kExp: - return visitor->HandleExp(this, operands_[0]); + return visitor->HandleExp(this); case HloOpcode::kFloor: - return visitor->HandleFloor(this, operands_[0]); + return visitor->HandleFloor(this); case HloOpcode::kCeil: - return visitor->HandleCeil(this, operands_[0]); + return visitor->HandleCeil(this); case HloOpcode::kLog: - return visitor->HandleLog(this, operands_[0]); + return visitor->HandleLog(this); case HloOpcode::kTanh: - return visitor->HandleTanh(this, operands_[0]); + return visitor->HandleTanh(this); case HloOpcode::kCos: - return visitor->HandleCos(this, operands_[0]); + return visitor->HandleCos(this); case HloOpcode::kSin: - return visitor->HandleSin(this, operands_[0]); + return visitor->HandleSin(this); + case HloOpcode::kReal: + return visitor->HandleReal(this); + case HloOpcode::kImag: + return visitor->HandleImag(this); case HloOpcode::kIsFinite: - return visitor->HandleIsFinite(this, operands_[0]); - case HloOpcode::kLogicalNot: - return visitor->HandleLogicalNot(this, operands_[0]); + return visitor->HandleIsFinite(this); + case HloOpcode::kNot: + return visitor->HandleNot(this); case HloOpcode::kBitcast: return visitor->HandleBitcast(this); case HloOpcode::kBroadcast: @@ -2007,24 +2254,23 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kTranspose: return visitor->HandleTranspose(this); case HloOpcode::kReverse: - return visitor->HandleReverse(this, operands_[0]); + return visitor->HandleReverse(this); case HloOpcode::kReducePrecision: return visitor->HandleReducePrecision(this); case HloOpcode::kSlice: - return visitor->HandleSlice(this, operands_[0]); + return visitor->HandleSlice(this); case HloOpcode::kDynamicSlice: - return visitor->HandleDynamicSlice(this, operands_[0], operands_[1]); + return visitor->HandleDynamicSlice(this); case HloOpcode::kDynamicUpdateSlice: - return visitor->HandleDynamicUpdateSlice(this, operands_[0], operands_[1], - operands_[2]); + return visitor->HandleDynamicUpdateSlice(this); case HloOpcode::kSort: - return visitor->HandleSort(this, operands_[0]); + return visitor->HandleSort(this); case HloOpcode::kInfeed: return visitor->HandleInfeed(this); case HloOpcode::kOutfeed: return visitor->HandleOutfeed(this); case HloOpcode::kRng: - return visitor->HandleRng(this, distribution_); + return visitor->HandleRng(this); case HloOpcode::kWhile: return visitor->HandleWhile(this); case HloOpcode::kFusion: @@ -2032,41 +2278,44 @@ Status HloInstruction::Visit(DfsHloVisitor* visitor) { case HloOpcode::kCall: return visitor->HandleCall(this); case HloOpcode::kCustomCall: - return visitor->HandleCustomCall(this, operands_, custom_call_target_); + return visitor->HandleCustomCall(this); case HloOpcode::kSend: return visitor->HandleSend(this); case HloOpcode::kRecv: return visitor->HandleRecv(this); // These opcodes are not handled here. - case HloOpcode::kIndex: case HloOpcode::kTrace: - case HloOpcode::kUpdate: break; } return Unimplemented("unhandled HloOpcode for DfsHloVisitor: %s", HloOpcodeString(opcode_).c_str()); } +// Explicit instantiations. +template Status HloInstruction::Visit(DfsHloVisitor* visitor); +template Status HloInstruction::Visit(ConstDfsHloVisitor* visitor); + using DFSStack = tensorflow::gtl::InlinedVector, 16>; // Push "child" onto the dfs_stack if not already visited. Returns false if a // cycle was detected, and true otherwise. -inline bool PushDFSChild(DfsHloVisitor* visitor, DFSStack* dfs_stack, +template +inline bool PushDFSChild(Visitor* visitor, DFSStack* dfs_stack, HloInstruction* child) { CHECK(child != nullptr); const int id = child->unique_id(); CHECK_GE(id, 0) << "instruction may not have a parent computation"; switch (visitor->GetVisitState(id)) { - case DfsHloVisitor::kVisiting: + case Visitor::kVisiting: return false; - case DfsHloVisitor::kVisited: + case Visitor::kVisited: // Nothing to do return true; - case DfsHloVisitor::kNotVisited: + case Visitor::kNotVisited: dfs_stack->push_back(std::make_pair(id, child)); return true; } @@ -2075,7 +2324,8 @@ inline bool PushDFSChild(DfsHloVisitor* visitor, DFSStack* dfs_stack, using InternalCompareFunction = std::function, std::pair)>; -static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor, +template +static Status PostOrderDFS(HloInstruction* root, Visitor* visitor, const InternalCompareFunction* operand_order, bool ignore_control_predecessors) { visitor->ReserveVisitStates(root->GetModule()->NumUniqueInstructionIds()); @@ -2084,7 +2334,7 @@ static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor, // // We need to keep track of both the id and the instruction because // instructions can get deleted while they are on the stack, so we - // can't always use the (potentiall dead) instruction object to grab + // can't always use the (potentially dead) instruction object to grab // its id. DFSStack dfs_stack; dfs_stack.emplace_back(root->unique_id(), root); @@ -2096,26 +2346,27 @@ static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor, HloInstruction* current_node = dfs_stack.back().second; CHECK_GE(current_id, 0) << current_id << ": " << current_node << ": instruction may not have parent computation"; - DfsHloVisitor::VisitState visit_state = visitor->GetVisitState(current_id); - if (visit_state == DfsHloVisitor::kVisited) { + typename Visitor::VisitState visit_state = + visitor->GetVisitState(current_id); + if (visit_state == Visitor::kVisited) { dfs_stack.pop_back(); VLOG(3) << "Not visiting HLO " << current_node->name() << " as it was already visited."; continue; } - if (visit_state == DfsHloVisitor::kVisiting) { + if (visit_state == Visitor::kVisiting) { dfs_stack.pop_back(); TF_RETURN_IF_ERROR(visitor->Preprocess(current_node)); VLOG(2) << "Visiting HLO " << current_node->name(); TF_RETURN_IF_ERROR(current_node->Visit(visitor)); - visitor->SetVisitState(current_id, DfsHloVisitor::kVisited); + visitor->SetVisitState(current_id, Visitor::kVisited); TF_RETURN_IF_ERROR(visitor->Postprocess(current_node)); continue; } - visitor->SetVisitState(current_id, DfsHloVisitor::kVisiting); + visitor->SetVisitState(current_id, Visitor::kVisiting); const size_t old_dfs_stack_size = dfs_stack.size(); for (HloInstruction* child : current_node->operands()) { @@ -2149,7 +2400,9 @@ static Status PostOrderDFS(HloInstruction* root, DfsHloVisitor* visitor, return Status::OK(); } -Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit, +template +Status HloInstruction::Accept(DfsHloVisitorBase* visitor, + bool call_finish_visit, bool ignore_control_predecessors) { VLOG(3) << "HloInstruction::Accept(" << name() << ")"; TF_RETURN_IF_ERROR( @@ -2160,6 +2413,10 @@ Status HloInstruction::Accept(DfsHloVisitor* visitor, bool call_finish_visit, return Status::OK(); } +// Explicit instantiations. +template Status HloInstruction::Accept(DfsHloVisitor*, bool, bool); +template Status HloInstruction::Accept(ConstDfsHloVisitor*, bool, bool); + Status HloInstruction::AcceptWithOperandOrder( DfsHloVisitor* visitor, const CompareFunction& operand_order, bool call_finish_visit) { @@ -2213,11 +2470,17 @@ bool OrderIsTopologicalSort(const std::vector& order) { } // namespace Status HloInstruction::Accept( - const FunctionVisitor::VisitorFunction& visitor_func) { + const std::function& visitor_func) { FunctionVisitor visitor(visitor_func); return this->Accept(&visitor); } +Status HloInstruction::Accept( + const std::function& visitor_func) const { + ConstFunctionVisitor visitor(visitor_func); + return this->Accept(&visitor); +} + Status HloInstruction::AcceptOrdered( DfsHloVisitor* visitor, const std::vector& order) { VLOG(2) << "HloInstruction::AcceptOrdered(" << name() << ")"; @@ -2280,29 +2543,7 @@ std::vector HloInstruction::OperandIndices( } bool HloInstruction::IsElementwiseBinary() const { - switch (opcode_) { - // Binary elementwise operations. If you update this, please update - // IsElementwise() accordingly. - case HloOpcode::kAdd: - case HloOpcode::kDivide: - case HloOpcode::kEq: - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kMaximum: - case HloOpcode::kMinimum: - case HloOpcode::kMultiply: - case HloOpcode::kNe: - case HloOpcode::kPower: - case HloOpcode::kRemainder: - case HloOpcode::kSubtract: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalOr: - return true; - default: - return false; - } + return IsElementwise() && operand_count() == 2; } bool HloInstruction::IsElementwise() const { @@ -2320,19 +2561,23 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kCos: case HloOpcode::kExp: case HloOpcode::kFloor: + case HloOpcode::kImag: case HloOpcode::kIsFinite: case HloOpcode::kLog: - case HloOpcode::kLogicalNot: + case HloOpcode::kNot: case HloOpcode::kNegate: + case HloOpcode::kReal: case HloOpcode::kReducePrecision: case HloOpcode::kSign: case HloOpcode::kSin: case HloOpcode::kTanh: + CHECK_EQ(1, operand_count()); return true; // Binary elementwise operations, the same as in IsElementwiseBinary(). - // If you update this, please update IsElementwiseBinary() accordingly. case HloOpcode::kAdd: + case HloOpcode::kAtan2: + case HloOpcode::kComplex: case HloOpcode::kDivide: case HloOpcode::kEq: case HloOpcode::kGe: @@ -2346,8 +2591,12 @@ bool HloInstruction::IsElementwise() const { case HloOpcode::kPower: case HloOpcode::kRemainder: case HloOpcode::kSubtract: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: + CHECK_EQ(2, operand_count()); return true; // Ternary elementwise operations. @@ -2364,7 +2613,7 @@ bool HloInstruction::IsElementwise() const { if (fusion_kind() != FusionKind::kLoop) { return false; } - for (auto& fused : fused_instructions()) { + for (auto* fused : fused_instructions()) { if (fused->opcode() != HloOpcode::kParameter && !fused->IsElementwise()) { return false; @@ -2377,6 +2626,11 @@ bool HloInstruction::IsElementwise() const { } } +bool HloInstruction::ImplicitlyBroadcastsOperand(int64 operand_idx) const { + CHECK(IsElementwise()); + return !ShapeUtil::Equal(shape(), operand(operand_idx)->shape()); +} + namespace { bool IsInstructionElementwiseOnOperand(const HloInstruction* instruction, const HloInstruction* operand) { @@ -2433,10 +2687,10 @@ class HloInstruction::FusionReusesParamElements { public: using UseKind = HloInstruction::UseKind; - // We could rather iterate backwards thru fused_instructions_ here, as it is - // in reverse postorder, and compute whether each fused instruction reuses - // the value of this parameter, which would save stack space but not allow - // us to finish early if we find a reuse. + // We could rather iterate backwards through fused_instructions_ here, as it + // is in reverse postorder, and compute whether each fused instruction reuses + // the value of this parameter, which would save stack space but not allow us + // to finish early if we find a reuse. static UseKind Compute(int64 i, const HloInstruction& hlo) { tensorflow::gtl::FlatMap memoization_cache; return ComputeInternal(i, hlo, &memoization_cache); @@ -2527,7 +2781,9 @@ HloInstruction::UseKind HloInstruction::OperandElementUse(int64 i) const { } return UseKind::kReuse; default: - return IsElementwise() ? UseKind::kUse : UseKind::kReuse; + return IsElementwise() && !ImplicitlyBroadcastsOperand(i) + ? UseKind::kUse + : UseKind::kReuse; } } @@ -2559,6 +2815,32 @@ string ToString(HloInstruction::FusionKind kind) { } } +StatusOr StringToFusionKind( + const string& kind_name) { + if (kind_name == "kLoop") { + return HloInstruction::FusionKind::kLoop; + } + if (kind_name == "kInput") { + return HloInstruction::FusionKind::kInput; + } + if (kind_name == "kOutput") { + return HloInstruction::FusionKind::kOutput; + } + if (kind_name == "kTransposeDot") { + return HloInstruction::FusionKind::kTransposeDot; + } + if (kind_name == "kConvBackwardFilter") { + return HloInstruction::FusionKind::kConvBackwardFilter; + } + if (kind_name == "kConvBackwardInput") { + return HloInstruction::FusionKind::kConvBackwardInput; + } + if (kind_name == "kCustom") { + return HloInstruction::FusionKind::kCustom; + } + return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str()); +} + std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } @@ -2586,8 +2868,8 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { // lhs_dims[i] is the symbol of the logical dimension i for the lhs // operand. E.g. if batch has dimension number 2, then lhs_dims[2] == "b". std::vector lhs_dims(2 + dnums.spatial_dimensions().size()); - lhs_dims[dnums.batch_dimension()] = 'b'; - lhs_dims[dnums.feature_dimension()] = 'f'; + lhs_dims[dnums.input_batch_dimension()] = 'b'; + lhs_dims[dnums.input_feature_dimension()] = 'f'; for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { lhs_dims[dnums.spatial_dimensions(i)] = StrCat(i); } @@ -2599,12 +2881,19 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { rhs_dims[dnums.kernel_spatial_dimensions(i)] = StrCat(i); } + std::vector output_dims(2 + dnums.spatial_dimensions().size()); + output_dims[dnums.output_batch_dimension()] = 'b'; + output_dims[dnums.output_feature_dimension()] = 'f'; + for (int64 i = 0; i < dnums.spatial_dimensions().size(); ++i) { + output_dims[dnums.spatial_dimensions(i)] = StrCat(i); + } + result += "dim_labels="; append_dims(lhs_dims, operand(0)->shape()); result += "_"; append_dims(rhs_dims, operand(1)->shape()); result += "->"; - append_dims(lhs_dims, shape()); + append_dims(output_dims, shape()); return result; } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 9b42f1756d0f15da238946efe07a868c4c1c3dca..5ff04a48882497ef546aa095c346f4318a61f02b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -27,15 +27,17 @@ limitations under the License. #include #include #include +#include #include #include +#include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/map_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" -#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/service/hlo_sharding.h" #include "tensorflow/compiler/xla/service/name_uniquer.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" @@ -43,6 +45,7 @@ limitations under the License. #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/inlined_vector.h" +#include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" @@ -70,6 +73,23 @@ class HloInstruction { }; ~HloInstruction(); + + // Creates an instruction from the given proto. Arguments: + // + // module: the module which will contain the instruction. The newly created + // instruction is *not* added to the module or any computation, however. + // proto: the proto to convert from. + // instruction_map: a map from instruction name to HloInstruction*. This map + // must contain all operands of the newly constructed instruction. + // computation_map: a map from computation name to HloComputation*. This map + // must contain all computations which the newly constructed instruction + // calls. If the instruction is a fusion instruction, then the fusion + // computation is added to this map and the module. + static StatusOr> CreateFromProto( + HloModule* module, const HloInstructionProto& proto, + const tensorflow::gtl::FlatMap& instruction_map, + tensorflow::gtl::FlatMap* computation_map); + // Creates a parameter-retrieving instruction. static std::unique_ptr CreateParameter(int64 parameter_number, const Shape& shape, @@ -182,7 +202,7 @@ class HloInstruction { tensorflow::gtl::ArraySlice strides); // Creates a slice instruction, where the first operand is sliced by - // start indices specified in the second operand, and by size specfied in + // start indices specified in the second operand, and by size specified in // 'slice_sizes'. static std::unique_ptr CreateDynamicSlice( const Shape& shape, HloInstruction* operand, @@ -422,6 +442,9 @@ class HloInstruction { // Replaces all uses of this instruction with the new producer. If // new_producer is a user of this instruction then new_producer remains a use // of this instruction to avoid introducing cycles into the graph. + // + // If this instruction is the root of its computation, sets the computation's + // root to new_producer. Status ReplaceAllUsesWith(HloInstruction* new_producer); // Detaches an instruction from its operands. That is, remove the instruction @@ -435,8 +458,15 @@ class HloInstruction { // reachable via control dependencies will not be visited, and the postorder // will not take control dependencies into account. It is as if the control // dependencies didn't exist in the graph at all. - Status Accept(DfsHloVisitor* visitor, bool call_finish_visit = true, + template + Status Accept(DfsHloVisitorBase* visitor, + bool call_finish_visit = true, bool ignore_control_predecessors = false); + Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true, + bool ignore_control_predecessors = false) const { + return const_cast(this)->Accept( + visitor, call_finish_visit, ignore_control_predecessors); + } // Same as Accept() above, but the order of operand and control predecessor // visitation is determined by the given operand order; if compare(A, B) == @@ -449,7 +479,9 @@ class HloInstruction { // Performs a postorder DFS visit using this node as the root. Calls the given // visitor function at each instruction. - Status Accept(const FunctionVisitor::VisitorFunction& visitor_func); + Status Accept(const std::function& visitor_func); + Status Accept( + const std::function& visitor_func) const; // Visits all instructions rooted at this instruction using the given visitor // in the given order. 'order' must contain at least the set of instructions @@ -462,7 +494,8 @@ class HloInstruction { const std::vector& order); // Visit this instruction and only this instruction with the given visitor. - Status Visit(DfsHloVisitor* visitor); + template + Status Visit(DfsHloVisitorBase* visitor); // Returns the literal associated with this instruction. // @@ -503,6 +536,26 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kGetTupleElement int64 tuple_index() const; + // Returns the first non-GetTupleElement ancestor instruction of 'hlo'. + // If the first non-GTE ancestor is tuple-shaped, populates 'index' with the + // (possibly nested) tuple indices used on the path from ancestor to 'hlo'. + std::pair LatestNonGteAncestorAndIndex() + const; + + std::pair LatestNonGteAncestorAndIndex() { + auto rv = + const_cast(this)->LatestNonGteAncestorAndIndex(); + return {const_cast(rv.first), rv.second}; + } + + // Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction. + const HloInstruction* LatestNonGteAncestor() const; + + HloInstruction* LatestNonGteAncestor() { + return const_cast( + const_cast(this)->LatestNonGteAncestor()); + } + // Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc. // The setter should only be called by HloModule or HloComputation methods. // @@ -545,13 +598,13 @@ class HloInstruction { string SignatureString() const; // Returns a debugging string that represents this instruction. - string ToString(bool compact_operands = false, - bool include_metadata = true) const; + string ToString(bool compact_operands = false, bool include_metadata = true, + bool include_large_constants = false) const; // Components of the ToString() representation: // Returns a string representation of the operand list. - string OperandsToString(bool compact) const; + string OperandsToString(bool compact, bool include_large_constants) const; // Returns string representation of op-specific attributes. std::vector ExtraAttributesToString() const; @@ -626,13 +679,22 @@ class HloInstruction { // Precondition: opcode() == HloOpcode::kFusion HloInstruction* fused_expression_root() const; - // Returns the list of fused instructions inside this fusioninstruction. + // Returns the list of fused instructions inside this fusion instruction. The + // returned type is a range of HloInstruction*s. // - // Note: although the list itself is const, the instructions contained in the - // list returned here are mutable. + // Precondition: opcode() == HloOpcode::kFusion + const tensorflow::gtl::iterator_range>::const_iterator>> + fused_instructions() const; + + const tensorflow::gtl::iterator_range< + UnwrappingIterator>::iterator>> + fused_instructions(); + + // Gets the number of instructions inside this fusion instruction. // // Precondition: opcode() == HloOpcode::kFusion - const std::list>& fused_instructions() const; + int64 fused_instruction_count() const; // Returns the fused parameter instruction in this fusion instruction // corresponding to the given parameter number. @@ -662,6 +724,26 @@ class HloInstruction { fusion_kind_ = kind; } + // Returns the sharding applied to this operator. + // REQUIRES: has_sharding() is true. + const HloSharding& sharding() const { + CHECK(has_sharding()); + return *sharding_; + } + // Returns the sharding applied to this operator, or default_ if none exists. + const HloSharding& sharding_or_default(const HloSharding& default_) const { + return sharding_ ? *sharding_ : default_; + } + // Sets the sharding of this operator. Should only be called by HloModule or + // HloComputation methods. + void set_sharding(const HloSharding& sharding) { + sharding_ = MakeUnique(sharding); + } + // Remove any sharding from this operator. + void clear_sharding() { sharding_ = nullptr; } + // Return true if this operator has a sharding assigned. + bool has_sharding() const { return sharding_ != nullptr; } + // Merges the fused instructions from 'instruction_to_merge' into the // fused instruction set of 'this', updating operands as necessary. // @@ -669,11 +751,11 @@ class HloInstruction { // Predondition: 'instruction_to_merge' must be an operand of 'this'. void MergeFusionInstruction(HloInstruction* instruction_to_merge); - // Merges the fused instructions from 'instruction_to_merge' into the - // fused instruction set of 'this' and generate multioutput fusion - // instructions. All the user of instruction_to_merge will be redirected - // to 'this' instruction. `instruction_to_merge' will be removed from its - // parent computation. + // Merges the fused instructions from instruction_to_merge into the fused + // instruction set of 'this' and generates multioutput fusion instructions. + // All the users of instruction_to_merge will be redirected to 'this' + // instruction. instruction_to_merge will be removed from its parent + // computation. // // Precondition: opcode() == HloOpcode::kFusion void MergeFusionInstructionIntoMultiOutput( @@ -798,12 +880,19 @@ class HloInstruction { // operands. After creation the clone has no uses. "this" (the instruction // cloned from) is not changed. Suffix is the string to append to the name of // the instruction to form the name of the cloned instruction. - std::unique_ptr Clone(const string& suffix = "clone"); + // If the module pointer is not nullptr, it will be the module where + // the cloned computations will be added to (in order to support deep + // cloning). + std::unique_ptr Clone(const string& suffix = "clone", + HloModule* module = nullptr) const; // Clones the HLO instruction as above but with new shape and operands. + // If the module pointer is not nullptr, it will be the module where + // the cloned computations will be added to (in order to support deep + // cloning). std::unique_ptr CloneWithNewOperands( - const Shape& shape, - tensorflow::gtl::ArraySlice operands); + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloModule* module = nullptr) const; // Returns the computations this instruction directly calls (if any). const std::vector& called_computations() const { @@ -820,6 +909,16 @@ class HloInstruction { } } + // Clears out the called computations. + // + // This is, in particular, necessary when inlining function bodies into their + // caller. If there were side-effecting operations in the called computations, + // the call itself is considered side-effecting and thus cannot be removed. By + // clearing out the computations, we reflect the fact that all side-effecting + // properties have been reflected in the caller, and make the call HLO + // removable. + void ClearCalledComputations() { called_computations_.clear(); } + // Returns true if this instruction performs an elementwise operation on // `operand_idx`-th operand. An instruction is elementwise on an operand iff, // after performing necessary implicit broadcast @@ -835,6 +934,12 @@ class HloInstruction { // Returns true if this instruction is elementwise on all its operands. bool IsElementwise() const; + // Returns true if this elementwise instruction implicitly broadcasts operand + // `operand_idx`. + // + // Precondition: this instruction should be an elementwise operation. + bool ImplicitlyBroadcastsOperand(int64 operand_idx) const; + // Returns true if this instruction is binary and elementwise. bool IsElementwiseBinary() const; @@ -917,14 +1022,6 @@ class HloInstruction { void RelayoutConstant(const Layout& new_layout, const ShapeIndex& shape_index = {}); - // Gets/sets the device assignment. - const OpDeviceAssignment& device_assignment() const { - return device_assignment_; - } - void set_device_assignment(const OpDeviceAssignment& device_assignment) { - device_assignment_ = device_assignment; - } - private: enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse }; @@ -981,8 +1078,8 @@ class HloInstruction { // Clones a fusion instruction with a new shape and operands. std::unique_ptr CloneFusionWithNewOperands( - const Shape& shape, - tensorflow::gtl::ArraySlice operands); + const Shape& shape, tensorflow::gtl::ArraySlice operands, + HloModule* module = nullptr) const; // Returns true if this instruction can legally have the dimensions field // set. Used for checking precondition of dimensions field accessors. @@ -1025,7 +1122,7 @@ class HloInstruction { std::unique_ptr literal_; // Constant index, only present for kGetTupleElement. - int64 tuple_index_ = 0; + int64 tuple_index_ = -1; // Dimensions present for some operations that require reshaping or // broadcasting, including Reshape, Reduce, ReduceWindow, and Reverse. @@ -1043,8 +1140,8 @@ class HloInstruction { std::vector slice_strides_; // The bit sizes for a reduce-precision operation. - int32 exponent_bits_; - int32 mantissa_bits_; + int32 exponent_bits_ = 0; + int32 mantissa_bits_ = 0; // Describes the [start, start + size) range size for a dynamic slice // ('start' is specified dynamically in the second operand of the operation). @@ -1057,6 +1154,9 @@ class HloInstruction { // The type of the fusion. Used by kFusion only. FusionKind fusion_kind_; + // The sharding, if one exists. + std::unique_ptr sharding_; + // For parameter instructions this field holds the parameter number. int64 parameter_number_ = 0; string parameter_name_; @@ -1094,11 +1194,11 @@ class HloInstruction { // A small float number added to the variance to avoid divide-by-zero error. // Only present for kBatchNormTraining. - float epsilon_; + float epsilon_ = 0.0f; // An integer value representing the index of the feature dimension. // Only present for kBatchNormTraining. - int64 feature_index_; + int64 feature_index_ = -1; // Represents a unique identifier for each Send/Recv instruction pair. // Only present for kSend or kRecv. @@ -1117,16 +1217,34 @@ class HloInstruction { // outer-most dimension first). std::vector outer_dimension_partitions_; - // Device assignment for the instruction. - OpDeviceAssignment device_assignment_; - TF_DISALLOW_COPY_AND_ASSIGN(HloInstruction); }; string ToString(HloInstruction::FusionKind kind); +StatusOr StringToFusionKind( + const string& kind_name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); +// Map classes that guarantee a deterministic iteration order when the key is +// an HloInstruction* or a const HloInstruction*. +// To make the iteration order over the map deterministic, the comparator +// should not be using the pointer values, but rather an intrinsic property of +// the hlo. +struct HloPtrComparator { + bool operator()(const HloInstruction* const& lhs, + const HloInstruction* const& rhs) const { + return lhs->unique_id() < rhs->unique_id(); + } +}; + +template +using HloInstructionMap = std::map; + +template +using ConstHloInstructionMap = + std::map; + } // namespace xla #endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_ diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc index 3601d5cdbe66b305b4fa23fa5e1c519704befbb9..ddb623332c905fe406473e0c1a7adcea9782fdd0 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc @@ -59,15 +59,15 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleConstant(HloInstruction* constant, - const Literal& literal) override { + Status HandleConstant(HloInstruction* constant) override { EXPECT_EQ(0, count_.count(constant)); count_[constant] = GetCountsForNode(constant); return Status::OK(); } - Status HandleAdd(HloInstruction* add, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleAdd(HloInstruction* add) override { + auto lhs = add->operand(0); + auto rhs = add->operand(1); EXPECT_EQ(0, count_.count(add)); EXPECT_GT(count_.count(lhs), 0); EXPECT_GT(count_.count(rhs), 0); @@ -75,32 +75,26 @@ class OpAndUserCollectingVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleNegate(HloInstruction* negate, - HloInstruction* operand) override { + Status HandleNegate(HloInstruction* negate) override { + auto operand = negate->operand(0); EXPECT_EQ(0, count_.count(negate)); EXPECT_GT(count_.count(operand), 0); count_[negate] = GetCountsForNode(negate); return Status::OK(); } - Status HandleMap( - HloInstruction* map, - tensorflow::gtl::ArraySlice operands, - HloComputation* /*function*/, - tensorflow::gtl::ArraySlice /*static_operands*/) - override { + Status HandleMap(HloInstruction* map) override { EXPECT_EQ(0, count_.count(map)); - for (HloInstruction* arg : operands) { + for (HloInstruction* arg : map->operands()) { EXPECT_GT(count_.count(arg), 0); } count_[map] = GetCountsForNode(map); return Status::OK(); } - Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, - HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, - HloComputation* function) override { + Status HandleReduce(HloInstruction* reduce) override { + auto arg = reduce->operand(0); + auto init_value = reduce->operand(1); EXPECT_EQ(0, count_.count(reduce)); EXPECT_GT(count_.count(arg), 0); EXPECT_GT(count_.count(init_value), 0); @@ -706,6 +700,9 @@ TEST_F(HloInstructionTest, PreserveMetadataInFusionAndClone) { metadata, fusion->fused_expression_root()->metadata())); EXPECT_TRUE(protobuf_util::ProtobufEquals( metadata, fusion->fused_expression_root()->operand(0)->metadata())); + + auto cloned = fusion->CloneWithNewOperands(fusion->shape(), {}); + EXPECT_TRUE(protobuf_util::ProtobufEquals(metadata, fusion->metadata())); } TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { @@ -729,6 +726,23 @@ TEST_F(HloInstructionTest, PreserveOutfeedShapeThroughClone) { EXPECT_TRUE(ShapeUtil::Equal(clone10->outfeed_shape(), shape10)); } +TEST_F(HloInstructionTest, PreserveTupleShapeThroughClone) { + HloComputation::Builder builder(TestName()); + auto* constant = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR2({ + {1, 2}, + {3, 4}, + }))); + auto* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({constant, constant})); + *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {0}) + ->mutable_layout() = LayoutUtil::MakeLayout({0, 1}); + *ShapeUtil::GetMutableSubshape(tuple->mutable_shape(), {1}) + ->mutable_layout() = LayoutUtil::MakeLayout({1, 0}); + auto tuple_clone = tuple->Clone(); + EXPECT_TRUE(ShapeUtil::Equal(tuple_clone->shape(), tuple->shape())); +} + TEST_F(HloInstructionTest, FusionOpWithCalledComputations) { // Create a fusion instruction containing a single unary operation. const Shape scalar_shape = ShapeUtil::MakeShape(F32, {}); @@ -778,8 +792,8 @@ TEST_F(HloInstructionTest, ComplexFusionOp) { // sub = Sub(mul, clamp) // tuple = Tuple({sub, sub, mul, C1}) // - // Notable complexities are repeated operands in a same instruction, different - // shapes, use of value in different expressions. + // Notable complexities are repeated operands in the same instruction, + // different shapes, use of value in different expressions. auto c1 = builder.AddInstruction( HloInstruction::CreateConstant(Literal::CreateR0(1.1f))); auto c2 = builder.AddInstruction( @@ -1183,13 +1197,13 @@ TEST_F(HloInstructionTest, Stringification) { EXPECT_EQ(fusion->ToString(false, false), "%fusion = f32[5,20]{1,0} fusion:kTransposeDot(f32[5,10]{1,0} %x, " - "f32[20,10]{1,0} %y), calls=fused_computation"); + "f32[20,10]{1,0} %y), calls=%fused_computation"); HloInstruction* loop = builder.AddInstruction( HloInstruction::CreateWhile(sout, computation, computation, x)); EXPECT_EQ(loop->ToString(false, false), "%while = f32[5,20]{1,0} while(f32[5,10]{1,0} %x), " - "condition=TransposeDot, body=TransposeDot"); + "condition=%TransposeDot, body=%TransposeDot"); } } // namespace diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc index e022c4836d87866925ab7e56c2250d87d0f5dfec..4255d6086625dfb9a045e4431e968a5ee0106ac7 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers.cc @@ -73,5 +73,43 @@ void HloMatcher::DescribeTo(::std::ostream* os) const { } } +bool HloParameterMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + if (instruction->parameter_number() != parameter_number_) { + *listener << "has wrong parameter number (got " + << instruction->parameter_number() << ", want " + << parameter_number_ << ")"; + return false; + } + return true; +} + +bool HloGetTupleElementMatcher::MatchAndExplain( + const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const { + if (!HloMatcher::MatchAndExplain(instruction, listener)) { + return false; + } + if (instruction->tuple_index() != tuple_index_) { + *listener << "has wrong tuple index (got " << instruction->tuple_index() + << ", want " << tuple_index_ << ")"; + return false; + } + return true; +} + } // namespace testing + +void PrintTo(const HloInstruction* inst, ::std::ostream* os) { + *os << (inst ? inst->ToString() : "nullptr"); +} + +void PrintTo(HloInstruction* inst, ::std::ostream* os) { + PrintTo(const_cast(inst), os); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h index 79f17bbb6bd9bfc0c6ed48c68599ef51fbd27af8..4d4010b0253c57eec3587776308f0a5fbaa31304 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers.h +++ b/tensorflow/compiler/xla/service/hlo_matchers.h @@ -38,6 +38,36 @@ class HloMatcher : public ::testing::MatcherInterface { std::vector<::testing::Matcher> operands_; }; +// Custom matcher for parameters, which accepts a parameter number. +class HloParameterMatcher : public HloMatcher { + public: + explicit HloParameterMatcher(int64 parameter_number) + : HloMatcher(HloOpcode::kParameter, /*operands=*/{}), + parameter_number_(parameter_number) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + + private: + int64 parameter_number_; +}; + +// Custom matcher for get-tuple-element instructions, which accepts a tuple +// index to match. +class HloGetTupleElementMatcher : public HloMatcher { + public: + explicit HloGetTupleElementMatcher( + ::testing::Matcher operand, int64 tuple_index) + : HloMatcher(HloOpcode::kGetTupleElement, /*operands=*/{operand}), + tuple_index_(tuple_index) {} + + bool MatchAndExplain(const HloInstruction* instruction, + ::testing::MatchResultListener* listener) const override; + + private: + int64 tuple_index_; +}; + // HloInstruction* matchers for opcode and operands. Example: // namespace op = xla::opcode_matchers; // EXPECT_THAT(instruction, @@ -72,16 +102,14 @@ HLO_MATCHER(Exp); HLO_MATCHER(Floor); HLO_MATCHER(Fusion); HLO_MATCHER(Ge); -HLO_MATCHER(GetTupleElement); HLO_MATCHER(Gt); -HLO_MATCHER(Index); HLO_MATCHER(Infeed); HLO_MATCHER(IsFinite); HLO_MATCHER(Le); HLO_MATCHER(Log); -HLO_MATCHER(LogicalAnd); -HLO_MATCHER(LogicalNot); -HLO_MATCHER(LogicalOr); +HLO_MATCHER(And); +HLO_MATCHER(Not); +HLO_MATCHER(Or); HLO_MATCHER(Lt); HLO_MATCHER(Map); HLO_MATCHER(Maximum); @@ -91,7 +119,6 @@ HLO_MATCHER(Ne); HLO_MATCHER(Negate); HLO_MATCHER(Outfeed); HLO_MATCHER(Pad); -HLO_MATCHER(Parameter); HLO_MATCHER(Power); HLO_MATCHER(Recv); HLO_MATCHER(Reduce); @@ -104,6 +131,9 @@ HLO_MATCHER(Rng); HLO_MATCHER(Select); HLO_MATCHER(SelectAndScatter); HLO_MATCHER(Send); +HLO_MATCHER(ShiftLeft); +HLO_MATCHER(ShiftRightLogical); +HLO_MATCHER(ShiftRightArithmetic); HLO_MATCHER(Sign); HLO_MATCHER(Slice); HLO_MATCHER(Sort); @@ -112,8 +142,44 @@ HLO_MATCHER(Tanh); HLO_MATCHER(Trace); HLO_MATCHER(Transpose); HLO_MATCHER(Tuple); -HLO_MATCHER(Update); HLO_MATCHER(While); + +// The special cases below let you check additional information about the +// HloInstruction, beyond just its opcode and operands. In all cases you can +// still use the generic matcher which doesn't check this info. +// +// Feel free to add additional custom matchers below. + +// - Parameter(N) matches parameter number N. +// - Parameter() matches any parameter. +inline ::testing::Matcher Parameter( + int64 parameter_number) { + return ::testing::MakeMatcher( + new ::xla::testing::HloParameterMatcher(parameter_number)); +} +inline ::testing::Matcher Parameter() { + return ::testing::MakeMatcher( + new ::xla::testing::HloMatcher(HloOpcode::kParameter, {})); +} + +// GetTupleElement(operand, N) matches a GTE instruction which gets the N'th +// tuple element of operand, while GetTupleElement(operand) matches any GTE +// operation on operand, and GetTupleElement() matches any GTE operation at all. +inline ::testing::Matcher GetTupleElement( + ::testing::Matcher operand, int64 tuple_index) { + return ::testing::MakeMatcher( + new ::xla::testing::HloGetTupleElementMatcher(operand, tuple_index)); +} +inline ::testing::Matcher GetTupleElement( + ::testing::Matcher operand) { + return ::testing::MakeMatcher( + new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {operand})); +} +inline ::testing::Matcher GetTupleElement() { + return ::testing::MakeMatcher( + new ::xla::testing::HloMatcher(HloOpcode::kGetTupleElement, {})); +} + #undef HLO_MATCHER } // namespace opcode_matchers @@ -130,13 +196,8 @@ std::vector Pointers(const Container& container) { // Tell GMock to print HloInstruction* by value, so error messages are nice. // Has to be in the same namespace as 'HloInstruction'. -void PrintTo(const HloInstruction* inst, ::std::ostream* os) { - *os << (inst ? inst->ToString() : "nullptr"); -} - -void PrintTo(HloInstruction* inst, ::std::ostream* os) { - PrintTo(const_cast(inst), os); -} +void PrintTo(const HloInstruction* inst, ::std::ostream* os); +void PrintTo(HloInstruction* inst, ::std::ostream* os); } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index 3bdc73cafe910efb73276224d4677eb193519a89..659f3d8c26be97a45e5a219b5081334e4f5dcdab 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -45,10 +45,37 @@ HloModule::HloModule(const string& name, const HloModuleConfig& config) : name_(name), config_(config) {} HloComputation* HloModule::AddComputationInternal( - std::unique_ptr computation) { - computation->UniquifyName(&computation_name_uniquer_); - for (auto& instruction : computation->instructions()) { - instruction->UniquifyName(&instruction_name_uniquer_); + std::unique_ptr computation, bool is_entry, + bool uniquify_names) { + if (is_entry) { + CHECK_EQ(nullptr, entry_computation_); + entry_computation_ = computation.get(); + + // If the module configuration has no entry layout computation set, create a + // default one based on the program shape. + if (!config_.has_entry_computation_layout()) { + config_.SetDefaultComputationLayout( + entry_computation_->ComputeProgramShape()); + } + } + + if (uniquify_names) { + computation->UniquifyName(&computation_name_uniquer_); + for (auto* instruction : computation->instructions()) { + instruction->UniquifyName(&instruction_name_uniquer_); + } + } else { + // Don't uniquify the names of the computation or instruction, but we must + // run the names through the uniquifiers to prevent future name collisions + // for computations and instructions created later. + computation_name_uniquer_.GetUniqueName(computation->name()); + for (auto* instruction : computation->instructions()) { + instruction_name_uniquer_.GetUniqueName(instruction->name()); + } + } + + // Pick unique IDs for each instruction. + for (auto* instruction : computation->instructions()) { instruction->SetUniqueId(NewUniqueInstructionId()); } computation->set_parent(this); @@ -58,16 +85,8 @@ HloComputation* HloModule::AddComputationInternal( HloComputation* HloModule::AddEntryComputation( std::unique_ptr computation) { - CHECK_EQ(nullptr, entry_computation_); - entry_computation_ = computation.get(); - - // If the module configuration has no entry layout computation set, create a - // default one based on the program shape. - if (!config_.has_entry_computation_layout()) { - config_.SetDefaultComputationLayout( - entry_computation_->ComputeProgramShape()); - } - return AddComputationInternal(std::move(computation)); + return AddComputationInternal(std::move(computation), /*is_entry=*/true, + /*uniquify_names=*/true); } Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { @@ -83,7 +102,8 @@ Status HloModule::RemoveEmbeddedComputation(HloComputation* to_remove) { HloComputation* HloModule::AddEmbeddedComputation( std::unique_ptr computation) { - return AddComputationInternal(std::move(computation)); + return AddComputationInternal(std::move(computation), /*is_entry=*/false, + /*uniquify_names=*/true); } void HloModule::ReplaceComputations( @@ -94,7 +114,7 @@ void HloModule::ReplaceComputations( new_computations.reserve(computations_.size()); for (std::unique_ptr& computation : computations_) { - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { switch (instruction->opcode()) { case HloOpcode::kCall: case HloOpcode::kMap: @@ -150,14 +170,23 @@ void HloModule::ReplaceComputations( computations_ = std::move(new_computations); } -string HloModule::ToString() const { +string HloModule::ToString(bool include_large_constants) const { std::ostringstream s; s << "HloModule " << name() << ":\n\n"; - s << "ENTRY " << entry_computation()->ToString() << "\n\n"; - for (const std::unique_ptr& computation : computations_) { - if (computation.get() != entry_computation()) { - s << computation->ToString() << "\n\n"; + for (const HloComputation* computation : MakeComputationPostOrder()) { + // Fusion computations are emitted with their fusion instruction and + // therefore don't need to be emitted as a separate comptutation in the + // module. + if (computation->IsFusionComputation()) { + continue; } + if (computation == entry_computation()) { + s << "ENTRY "; + } + s << computation->ToString( + /*nested_level=*/0, + /*include_large_constants=*/include_large_constants) + << "\n\n"; } return s.str(); } @@ -167,12 +196,166 @@ HloModuleProto HloModule::ToProto() const { proto.set_name(name_); proto.set_entry_computation_name(entry_computation_->name()); for (const HloComputation* computation : MakeComputationPostOrder()) { + // Fusion computations are added when the fusion instructions are created by + // HloInstruction::CreateFromProto. + if (computation->IsFusionComputation()) { + continue; + } HloComputationProto computation_proto = computation->ToProto(); proto.add_computations()->Swap(&computation_proto); } return proto; } +namespace { + +// Construct a ProgramShape matching the shape of the parameters and root of the +// given module's entry computation. +StatusOr ProgramShapeFromProto(const HloModuleProto& module) { + const HloComputationProto* entry_computation = nullptr; + for (const HloComputationProto& computation : module.computations()) { + if (computation.name() == module.entry_computation_name()) { + entry_computation = &computation; + break; + } + } + TF_RET_CHECK(entry_computation != nullptr) + << "No computation with entry computation name" + << module.entry_computation_name(); + + tensorflow::gtl::FlatMap> parameters; + const HloInstructionProto* root = nullptr; + for (const HloInstructionProto& instruction : + entry_computation->instructions()) { + if (instruction.name() == entry_computation->root_name()) { + TF_RET_CHECK(root == nullptr) << "Entry computation has more than " + "one instruction with (root) name " + << instruction.name(); + root = &instruction; + } + if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) { + TF_RET_CHECK(!ContainsKey(parameters, instruction.parameter_number())) + << "Entry computation has more than one parameter instruction " + "with parameter number " + << instruction.parameter_number(); + parameters[instruction.parameter_number()] = { + instruction.parameter_name(), &instruction.shape()}; + } + } + TF_RET_CHECK(root != nullptr) + << "Entry computation is missing root instruction named " + << entry_computation->root_name(); + + ProgramShape program_shape; + *program_shape.mutable_result() = root->shape(); + for (int64 i = 0; i < parameters.size(); ++i) { + TF_RET_CHECK(ContainsKey(parameters, i)) + << "Entry computation missing parameter number " << i; + const string& name = parameters.at(i).first; + const Shape& shape = *parameters.at(i).second; + *program_shape.add_parameters() = shape; + program_shape.add_parameter_names(name); + } + + return std::move(program_shape); +} + +} // namespace + +/* static */ +StatusOr> HloModule::CreateFromProto( + const HloModuleProto& proto, const HloModuleConfig& module_config, + const VersionedComputationHandle& entry_computation_handle) { + // The ProgramShape in the passed in module config must match the shapes of + // the entry parameters and root. + TF_ASSIGN_OR_RETURN(ProgramShape expected_program_shape, + ProgramShapeFromProto(proto)); + TF_RET_CHECK(expected_program_shape.parameters_size() == + module_config.entry_computation_layout().parameter_count()); + for (int i = 0; i < expected_program_shape.parameters_size(); ++i) { + const Shape& parameter_shape = + module_config.entry_computation_layout().parameter_layout(i).shape(); + TF_RET_CHECK( + ShapeUtil::Equal(expected_program_shape.parameters(i), parameter_shape)) + << "HloModuleConfig has different shape for parameter " << i + << " than the HLO module. Expected: " + << ShapeUtil::HumanStringWithLayout( + expected_program_shape.parameters(i)) + << ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape); + } + const Shape& result_shape = + module_config.entry_computation_layout().result_layout().shape(); + TF_RET_CHECK(ShapeUtil::Equal(expected_program_shape.result(), result_shape)) + << "HloModuleConfig has different result shape than the HLO module. " + "Expected: " + << ShapeUtil::HumanStringWithLayout(expected_program_shape.result()) + << ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape); + + auto module = MakeUnique(proto.name(), entry_computation_handle, + module_config); + + tensorflow::gtl::FlatMap computation_map; + for (const HloComputationProto& computation_proto : proto.computations()) { + TF_ASSIGN_OR_RETURN(std::unique_ptr computation, + HloComputation::CreateFromProto( + module.get(), computation_proto, &computation_map)); + CHECK_NE(computation.get(), nullptr); + TF_RET_CHECK(!ContainsKey(computation_map, computation->name())); + string computation_name = computation->name(); + // Don't uniquify names because we want names to be stable across + // serialization and deserialization. + computation_map[computation_name] = module->AddComputationInternal( + std::move(computation), + /*is_entry=*/proto.entry_computation_name() == computation_name, + /*uniquify_names=*/false); + } + TF_RET_CHECK(module->entry_computation_ != nullptr); + + // Because we didn't uniquify the names, double-check that the instruction and + // computation names are unique from the proto. + tensorflow::gtl::FlatSet computation_names; + tensorflow::gtl::FlatSet instruction_names; + for (HloComputation* computation : module->computations()) { + if (computation->IsFusionComputation()) { + continue; + } + + TF_RET_CHECK(!ContainsKey(computation_names, computation->name())) + << "Computation name is not unique: " << computation->name(); + computation_names.insert(computation->name()); + for (HloInstruction* instruction : computation->instructions()) { + TF_RET_CHECK(!ContainsKey(instruction_names, instruction->name())) + << "Instruction name is not unique: " << instruction->name(); + instruction_names.insert(instruction->name()); + } + } + + return std::move(module); +} + +/* static */ +StatusOr HloModule::CreateModuleConfigFromProto( + const HloModuleProto& module) { + TF_ASSIGN_OR_RETURN(ProgramShape program_shape, + ProgramShapeFromProto(module)); + + HloModuleConfig module_config(program_shape); + + // The module config is constructed with default layouts regardless of what is + // passed in via the ProgramShape. Set the layouts to the appropriate values. + ComputationLayout* entry_layout = + module_config.mutable_entry_computation_layout(); + for (int64 i = 0; i < entry_layout->parameter_count(); ++i) { + TF_RETURN_IF_ERROR( + entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape( + program_shape.parameters(i))); + } + TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape( + program_shape.result())); + + return module_config; +} + namespace { // Returns whether `hlo` is used outside the given subcomputation. // `instructions_in_subcomputation` is the instruction set of the given @@ -266,7 +449,7 @@ HloInstruction* HloModule::OutlineExpressionFromComputation( VLOG(2) << "as a call " << call->ToString(); VLOG(2) << "to " << nested_computation->ToString(); - TF_CHECK_OK(computation->ReplaceUsesOfInstruction(output, call)); + TF_CHECK_OK(output->ReplaceAllUsesWith(call)); for (auto i = instructions_to_outline.rbegin(); i != instructions_to_outline.rend(); ++i) { TF_CHECK_OK(computation->RemoveInstruction(*i)); @@ -281,7 +464,7 @@ std::list HloModule::MakeComputationPostOrder() const { // module). std::set nonroot_computations; for (auto& computation : computations_) { - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { for (HloComputation* called_computation : instruction->called_computations()) { nonroot_computations.insert(called_computation); @@ -313,6 +496,17 @@ std::list HloModule::MakeComputationPostOrder() const { return post_order; } +std::vector HloModule::MakeNonfusionComputations() const { + std::vector result; + for (auto* c : computations()) { + if (c->IsFusionComputation()) { + continue; + } + result.push_back(c); + } + return result; +} + std::unique_ptr HloModule::Clone(const string& suffix) const { VLOG(1) << "Cloning module :" << name_ << " --> " << suffix << "\n"; auto module = MakeUnique(name_ + "-" + suffix); @@ -333,7 +527,7 @@ std::unique_ptr HloModule::Clone(const string& suffix) const { } for (auto& cloned_computation : module->computations_) { - for (auto& instruction : cloned_computation->instructions()) { + for (auto* instruction : cloned_computation->instructions()) { // Rewrite instruction's called_computation to point to the cloned // computations. instruction->ReplaceCalledComputations( diff --git a/tensorflow/compiler/xla/service/hlo_module.h b/tensorflow/compiler/xla/service/hlo_module.h index fe41fe2fd9f9f5c36805c4e2856a910e240d30dd..6469851791ddb66c6fb17aa8d7c80b04c879a67b 100644 --- a/tensorflow/compiler/xla/service/hlo_module.h +++ b/tensorflow/compiler/xla/service/hlo_module.h @@ -23,6 +23,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/iterator_util.h" #include "tensorflow/compiler/xla/service/hlo.pb.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -31,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/versioned_computation_handle.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/gtl/iterator_range.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/mutex.h" @@ -96,19 +98,60 @@ class HloModule { return entry_computation_handle_; } - const std::vector>& computations() const { - return computations_; + // Gets the computations in this module. + // + // Returns a view of HloComputation*s, so you can iterate over this in the + // natural way: + // + // for (HloComputation* c : module->computations()) { ... } + // + tensorflow::gtl::iterator_range>::const_iterator>> + computations() const { + return {MakeUnwrappingIterator(computations_.begin()), + MakeUnwrappingIterator(computations_.end())}; } + tensorflow::gtl::iterator_range>::iterator>> + computations() { + return {MakeUnwrappingIterator(computations_.begin()), + MakeUnwrappingIterator(computations_.end())}; + } + + // Gets the number of computations in this module. + int64 computation_count() const { return computations_.size(); } // Compute and return a post order of all computations in the module. The sort // is defined like so: if computation A has an instruction which calls // computation B, then A will appear after B in the sort. std::list MakeComputationPostOrder() const; + // Gets the computations in this module which aren't for fusion nodes. + // + // Postcondition: All computations in the returned list have + // !IsFusionComputation(). + // + // Note: Callers can and do rely on the return value here being a *snapshot* + // of the module's non-fusion computations -- that is, it's OK to add or + // remove computations from a module while iterating over + // MakeNonfusionComputations(). + std::vector MakeNonfusionComputations() const; + const HloModuleConfig& config() const { return config_; } - string ToString() const; + string ToString(bool include_large_constants = false) const; + + // Convert an HloModule to or from a proto. HloModuleProto ToProto() const; + static StatusOr> CreateFromProto( + const HloModuleProto& proto, const HloModuleConfig& module_config, + const VersionedComputationHandle& entry_computation_handle = + VersionedComputationHandle()); + + // Creates and returns an HloModuleConfig with an appropriate program shape + // for the HLO module in the given proto. + static StatusOr CreateModuleConfigFromProto( + const HloModuleProto& module); // Outlines the given expression from the given computation. // instructions_to_outline contains the instructions that form the expression. @@ -144,7 +187,8 @@ class HloModule { private: HloComputation* AddComputationInternal( - std::unique_ptr computation); + std::unique_ptr computation, bool is_entry, + bool uniquify_names); const string name_; HloModuleConfig config_; diff --git a/tensorflow/compiler/xla/service/hlo_module_config.h b/tensorflow/compiler/xla/service/hlo_module_config.h index 2299200b5be969c065fded840709a3d6034efe47..4a7ead9c104d2ed50d5c895b3cdf2d3767ae16e8 100644 --- a/tensorflow/compiler/xla/service/hlo_module_config.h +++ b/tensorflow/compiler/xla/service/hlo_module_config.h @@ -67,11 +67,6 @@ class HloModuleConfig { bool hlo_profiling_enabled() const { return hlo_profiling_enabled_; } void enable_hlo_profiling(bool enabled) { hlo_profiling_enabled_ = enabled; } - bool has_hybrid_result() const { return has_hybrid_result_; } - void set_has_hybrid_result(bool has_hybrid_result) { - has_hybrid_result_ = has_hybrid_result; - } - // Sets/returns the module seed set during execution. void set_seed(uint64 seed) { seed_ = seed; } uint64 seed() const { return seed_; } diff --git a/tensorflow/compiler/xla/service/hlo_module_test.cc b/tensorflow/compiler/xla/service/hlo_module_test.cc index 20eef2f7d53251a374971e55441f6a4585e9b35c..bf6440d66cac0d3a929c377202b212aba262f887 100644 --- a/tensorflow/compiler/xla/service/hlo_module_test.cc +++ b/tensorflow/compiler/xla/service/hlo_module_test.cc @@ -101,7 +101,7 @@ TEST_F(HloModuleTest, CloneTest) { for (auto origin = post_order.begin(), copied = post_order_copied.begin(); origin != post_order.end() && copied != post_order_copied.end(); ++origin, ++copied) { - EXPECT_EQ((*origin)->name() + "copy", (*copied)->name()); + EXPECT_EQ((*origin)->name() + ".copy", (*copied)->name()); } } @@ -125,6 +125,26 @@ TEST_F(HloModuleTest, DiamondComputationsPostOrder) { EXPECT_EQ(post_order.front(), computation1); } +TEST_F(HloModuleTest, LargeConstantToString) { + // Create a module with a single computation. + auto module = CreateNewModule(); + auto builder = HloComputation::Builder("Constant"); + std::vector values(16, 42.0); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1(values))); + module->AddEntryComputation(builder.Build()); + + EXPECT_EQ( + "HloModule LargeConstantToString:\n\nENTRY %Constant () -> f32[16] {\n " + "ROOT %constant = f32[16]{0} constant({...})\n}\n\n", + module->ToString(/*include_large_constants=*/false)); + EXPECT_EQ( + "HloModule LargeConstantToString:\n\nENTRY %Constant () -> f32[16] {\n " + "ROOT %constant = f32[16]{0} constant({42, 42, 42, 42, 42, 42, 42, 42, " + "42, 42, 42, 42, 42, 42, 42, 42})\n}\n\n", + module->ToString(/*include_large_constants=*/true)); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc index 83fe6ef6c967f865333eff51b04a33b1d11ffa7e..d1eaf357855205f1e9867e86f3042b96b6beff97 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode.cc @@ -15,188 +15,65 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { string HloOpcodeString(HloOpcode opcode) { - // Note: Do not use ':' in opcode strings. It is used as a special character - // in these places: - // - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to - // separate the opcode from the fusion kind - // - In fully qualified names (HloInstruction::FullyQualifiedName()), to - // separate the qualifiers (name of the computation and potentially the - // fusion instruction) from the name switch (opcode) { - case HloOpcode::kAbs: - return "abs"; - case HloOpcode::kAdd: - return "add"; - case HloOpcode::kBatchNormTraining: - return "batch-norm-training"; - case HloOpcode::kBatchNormInference: - return "batch-norm-inference"; - case HloOpcode::kBatchNormGrad: - return "batch-norm-grad"; - case HloOpcode::kBitcast: - return "bitcast"; - case HloOpcode::kBroadcast: - return "broadcast"; - case HloOpcode::kCall: - return "call"; - case HloOpcode::kClamp: - return "clamp"; - case HloOpcode::kConcatenate: - return "concatenate"; - case HloOpcode::kConstant: - return "constant"; - case HloOpcode::kConvert: - return "convert"; - case HloOpcode::kConvolution: - return "convolution"; - case HloOpcode::kCos: - return "cosine"; - case HloOpcode::kCrossReplicaSum: - return "cross-replica-sum"; - case HloOpcode::kCustomCall: - return "custom-call"; - case HloOpcode::kCopy: - return "copy"; - case HloOpcode::kDivide: - return "divide"; - case HloOpcode::kDot: - return "dot"; - case HloOpcode::kDynamicSlice: - return "dynamic-slice"; - case HloOpcode::kDynamicUpdateSlice: - return "dynamic-update-slice"; - case HloOpcode::kEq: - return "equal-to"; - case HloOpcode::kExp: - return "exponential"; - case HloOpcode::kFloor: - return "floor"; - case HloOpcode::kCeil: - return "ceil"; - case HloOpcode::kFusion: - return "fusion"; - case HloOpcode::kGe: - return "greater-than-or-equal-to"; - case HloOpcode::kGetTupleElement: - return "get-tuple-element"; - case HloOpcode::kGt: - return "greater-than"; - case HloOpcode::kIndex: - return "index"; - case HloOpcode::kInfeed: - return "infeed"; - case HloOpcode::kIsFinite: - return "is-finite"; - case HloOpcode::kLe: - return "less-than-or-equal-to"; - case HloOpcode::kLog: - return "log"; - case HloOpcode::kLogicalAnd: - return "logical-and"; - case HloOpcode::kLogicalOr: - return "logical-or"; - case HloOpcode::kLogicalNot: - return "logical-not"; - case HloOpcode::kLt: - return "less-than"; - case HloOpcode::kMap: - return "map"; - case HloOpcode::kMaximum: - return "maximum"; - case HloOpcode::kMinimum: - return "minimum"; - case HloOpcode::kMultiply: - return "multiply"; - case HloOpcode::kNe: - return "not-equal-to"; - case HloOpcode::kNegate: - return "negate"; - case HloOpcode::kOutfeed: - return "outfeed"; - case HloOpcode::kPad: - return "pad"; - case HloOpcode::kParameter: - return "parameter"; - case HloOpcode::kPower: - return "power"; - case HloOpcode::kRecv: - return "recv"; - case HloOpcode::kReduce: - return "reduce"; - case HloOpcode::kReducePrecision: - return "reduce-precision"; - case HloOpcode::kReduceWindow: - return "reduce-window"; - case HloOpcode::kRemainder: - return "remainder"; - case HloOpcode::kReshape: - return "reshape"; - case HloOpcode::kReverse: - return "reverse"; - case HloOpcode::kRng: - return "rng"; - case HloOpcode::kRoundNearestAfz: - return "round-nearest-afz"; - case HloOpcode::kSelectAndScatter: - return "select-and-scatter"; - case HloOpcode::kSelect: - return "select"; - case HloOpcode::kSend: - return "send"; - case HloOpcode::kSign: - return "sign"; - case HloOpcode::kSin: - return "sine"; - case HloOpcode::kSlice: - return "slice"; - case HloOpcode::kSort: - return "sort"; - case HloOpcode::kSubtract: - return "subtract"; - case HloOpcode::kTanh: - return "tanh"; - case HloOpcode::kTrace: - return "trace"; - case HloOpcode::kTranspose: - return "transpose"; - case HloOpcode::kTuple: - return "tuple"; - case HloOpcode::kUpdate: - return "update"; - case HloOpcode::kWhile: - return "while"; +#define CASE_OPCODE_STRING(enum_name, opcode_name, ...) \ + case HloOpcode::enum_name: \ + return opcode_name; + HLO_OPCODE_LIST(CASE_OPCODE_STRING) +#undef CASE_OPCODE_STRING } } +StatusOr StringToHloOpcode(const string& opcode_name) { + static auto* opcode_map = new tensorflow::gtl::FlatMap({ +#define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) \ + {opcode_name, HloOpcode::enum_name}, + HLO_OPCODE_LIST(STRING_TO_OPCODE_ENTRY) +#undef STRING_TO_OPCODE_ENTRY + }); + auto it = opcode_map->find(opcode_name); + if (it == opcode_map->end()) { + return InvalidArgument("Unknown opcode: %s", opcode_name.c_str()); + } + return it->second; +} + +#define CHECK_DEFAULT(property_name, opcode_name) false +#define CHECK_PROPERTY(property_name, opcode_name, value) \ + (value & property_name) +#define RESOLVE(_1, _2, target, ...) target +#define HAS_PROPERTY(property, ...) \ + RESOLVE(__VA_ARGS__, CHECK_PROPERTY, CHECK_DEFAULT)(property, __VA_ARGS__) + bool HloOpcodeIsComparison(HloOpcode opcode) { switch (opcode) { - case HloOpcode::kGe: - case HloOpcode::kGt: - case HloOpcode::kLe: - case HloOpcode::kLt: - case HloOpcode::kEq: - case HloOpcode::kNe: - return true; - default: - return false; +#define CASE_IS_COMPARISON(enum_name, ...) \ + case HloOpcode::enum_name: \ + return HAS_PROPERTY(kHloOpcodeIsComparison, __VA_ARGS__); + HLO_OPCODE_LIST(CASE_IS_COMPARISON) +#undef CASE_IS_COMPARISON } } bool HloOpcodeIsVariadic(HloOpcode opcode) { switch (opcode) { - case HloOpcode::kCall: - case HloOpcode::kConcatenate: - case HloOpcode::kFusion: - case HloOpcode::kMap: - case HloOpcode::kTuple: - return true; - default: - return false; +#define CASE_IS_VARIADIC(enum_name, ...) \ + case HloOpcode::enum_name: \ + return HAS_PROPERTY(kHloOpcodeIsVariadic, __VA_ARGS__); + HLO_OPCODE_LIST(CASE_IS_VARIADIC) +#undef CASE_IS_VARIADIC } } +#undef HAS_PROPERTY +#undef RESOLVE +#undef CHECK_DEFAULT +#undef CHECK_PROPERTY + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 7b23249640b0dcfdd510caf27bf57bb1f2f6850e..d68fc20321152f6a2ede1234180bee0db110f503 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" namespace xla { @@ -27,83 +28,120 @@ namespace xla { // present in the XLA service protobuf. // // See the XLA documentation for the semantics of each opcode. +// +// Each entry has the format: +// (enum_name, opcode_name) +// or +// (enum_name, opcode_name, p1 | p2 | ...) +// +// with p1, p2, ... are members of HloOpcodeProperty. They are combined +// using bitwise-or. +// +// Note: Do not use ':' in opcode names. It is used as a special character +// in these places: +// - In extended opcode strings (HloInstruction::ExtendedOpcodeString()), to +// separate the opcode from the fusion kind +// - In fully qualified names (HloInstruction::FullyQualifiedName()), to +// separate the qualifiers (name of the computation and potentially the +// fusion instruction) from the name +#define HLO_OPCODE_LIST(V) \ + V(kAbs, "abs") \ + V(kAdd, "add") \ + V(kAtan2, "atan2") \ + V(kBatchNormGrad, "batch-norm-grad") \ + V(kBatchNormInference, "batch-norm-inference") \ + V(kBatchNormTraining, "batch-norm-training") \ + V(kBitcast, "bitcast") \ + V(kBroadcast, "broadcast") \ + V(kCall, "call", kHloOpcodeIsVariadic) \ + V(kCeil, "ceil") \ + V(kClamp, "clamp") \ + V(kComplex, "complex") \ + V(kConcatenate, "concatenate", kHloOpcodeIsVariadic) \ + V(kConstant, "constant") \ + V(kConvert, "convert") \ + V(kConvolution, "convolution") \ + V(kCopy, "copy") \ + V(kCos, "cosine") \ + V(kCrossReplicaSum, "cross-replica-sum") \ + V(kCustomCall, "custom-call") \ + V(kDivide, "divide") \ + V(kDot, "dot") \ + V(kDynamicSlice, "dynamic-slice") \ + V(kDynamicUpdateSlice, "dynamic-update-slice") \ + V(kEq, "equal-to", kHloOpcodeIsComparison) \ + V(kExp, "exponential") \ + V(kFloor, "floor") \ + V(kFusion, "fusion", kHloOpcodeIsVariadic) \ + V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ + V(kGetTupleElement, "get-tuple-element") \ + V(kGt, "greater-than", kHloOpcodeIsComparison) \ + V(kImag, "imag") \ + V(kInfeed, "infeed") \ + V(kIsFinite, "is-finite") \ + V(kLe, "less-than-or-equal-to", kHloOpcodeIsComparison) \ + V(kLog, "log") \ + V(kAnd, "and") \ + V(kNot, "not") \ + V(kOr, "or") \ + V(kLt, "less-than", kHloOpcodeIsComparison) \ + V(kMap, "map", kHloOpcodeIsVariadic) \ + V(kMaximum, "maximum") \ + V(kMinimum, "minimum") \ + V(kMultiply, "multiply") \ + V(kNe, "not-equal-to", kHloOpcodeIsComparison) \ + V(kNegate, "negate") \ + V(kOutfeed, "outfeed") \ + V(kPad, "pad") \ + V(kParameter, "parameter") \ + V(kPower, "power") \ + V(kReal, "real") \ + V(kRecv, "recv") \ + V(kReduce, "reduce") \ + V(kReducePrecision, "reduce-precision") \ + V(kReduceWindow, "reduce-window") \ + V(kRemainder, "remainder") \ + V(kReshape, "reshape") \ + V(kReverse, "reverse") \ + V(kRng, "rng") \ + V(kRoundNearestAfz, "round-nearest-afz") \ + V(kSelect, "select") \ + V(kSelectAndScatter, "select-and-scatter") \ + V(kSend, "send") \ + V(kShiftLeft, "shift-left") \ + V(kShiftRightArithmetic, "shift-right-arithmetic") \ + V(kShiftRightLogical, "shift-right-logical") \ + V(kSign, "sign") \ + V(kSin, "sine") \ + V(kSlice, "slice") \ + V(kSort, "sort") \ + V(kSubtract, "subtract") \ + V(kTanh, "tanh") \ + V(kTrace, "trace") \ + V(kTranspose, "transpose") \ + V(kTuple, "tuple", kHloOpcodeIsVariadic) \ + V(kWhile, "while") + enum class HloOpcode { - kAbs, - kAdd, - kBatchNormGrad, - kBatchNormInference, - kBatchNormTraining, - kBitcast, - kBroadcast, - kCall, - kCeil, - kClamp, - kConcatenate, - kConstant, - kConvert, - kConvolution, - kCopy, - kCos, - kCrossReplicaSum, - kCustomCall, - kDivide, - kDot, - kDynamicSlice, - kDynamicUpdateSlice, - kEq, - kExp, - kFloor, - kFusion, - kGe, - kGetTupleElement, - kGt, - kIndex, - kInfeed, - kIsFinite, - kLe, - kLog, - kLogicalAnd, - kLogicalNot, - kLogicalOr, - kLt, - kMap, - kMaximum, - kMinimum, - kMultiply, - kNe, - kNegate, - kOutfeed, - kPad, - kParameter, - kPower, - kRecv, - kReduce, - kReducePrecision, - kReduceWindow, - kRemainder, - kReshape, - kReverse, - kRng, - kRoundNearestAfz, - kSelect, - kSelectAndScatter, - kSend, - kSign, - kSin, - kSlice, - kSort, - kSubtract, - kTanh, - kTrace, - kTranspose, - kTuple, - kUpdate, - kWhile, +#define DECLARE_ENUM(enum_name, opcode_name, ...) enum_name, + HLO_OPCODE_LIST(DECLARE_ENUM) +#undef DECLARE_ENUM +}; + +// List of properties associated with opcodes. +// Properties are defined as increasing powers of two, so that we can use +// bitwise-or to combine properties, and bitwise-and to test for them. +enum HloOpcodeProperty { + kHloOpcodeIsComparison = 1 << 0, + kHloOpcodeIsVariadic = 1 << 1, }; // Returns a string representation of the opcode. string HloOpcodeString(HloOpcode opcode); +// Returns a string representation of the opcode. +StatusOr StringToHloOpcode(const string& opcode_name); + inline std::ostream& operator<<(std::ostream& os, HloOpcode opcode) { return os << HloOpcodeString(opcode); } @@ -116,7 +154,9 @@ bool HloOpcodeIsVariadic(HloOpcode opcode); // Returns the number of HloOpcode values. inline const uint32_t HloOpcodeCount() { - return static_cast(HloOpcode::kWhile) + 1; +#define HLO_COUNT_ONE(...) +1 +#define HLO_XLIST_LENGTH(list) list(HLO_COUNT_ONE) + return HLO_XLIST_LENGTH(HLO_OPCODE_LIST); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_opcode_test.cc b/tensorflow/compiler/xla/service/hlo_opcode_test.cc index 892c89f9df209f2e39005a4901feae6699ce4d0b..cd2ce5c69f030c65b889d67e082a3677b8739ddb 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode_test.cc +++ b/tensorflow/compiler/xla/service/hlo_opcode_test.cc @@ -26,5 +26,46 @@ TEST(HloOpcodeTest, StringifyMultiply) { ASSERT_EQ("multiply", HloOpcodeString(HloOpcode::kMultiply)); } +TEST(HloOpcodeTest, OpcodeProperties) { + // Test counting macro. +#define SOME_LIST(X) \ + X(One) \ + X(Two) \ + X(Three) + EXPECT_EQ(3, HLO_XLIST_LENGTH(SOME_LIST)); +#undef SOME_LIST + + for (int i = 0; i < HloOpcodeCount(); ++i) { + auto opcode = static_cast(i); + // Test round-trip conversion to and from string. + EXPECT_EQ(opcode, StringToHloOpcode(HloOpcodeString(opcode)).ValueOrDie()); + + // Test some properties. + switch (opcode) { + case HloOpcode::kEq: + case HloOpcode::kNe: + case HloOpcode::kGt: + case HloOpcode::kLt: + case HloOpcode::kGe: + case HloOpcode::kLe: + EXPECT_TRUE(HloOpcodeIsComparison(opcode)); + break; + default: + EXPECT_FALSE(HloOpcodeIsComparison(opcode)); + } + switch (opcode) { + case HloOpcode::kCall: + case HloOpcode::kConcatenate: + case HloOpcode::kFusion: + case HloOpcode::kMap: + case HloOpcode::kTuple: + EXPECT_TRUE(HloOpcodeIsVariadic(opcode)); + break; + default: + EXPECT_FALSE(HloOpcodeIsVariadic(opcode)); + } + } +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc index 3612c51ee821b1aab9c27f730270cbebd596ebb1..37009369797693dcd06647fad845bb0c004cec67 100644 --- a/tensorflow/compiler/xla/service/hlo_ordering.cc +++ b/tensorflow/compiler/xla/service/hlo_ordering.cc @@ -253,7 +253,7 @@ bool PredecessorHloOrdering::ExecutesBeforeInSameComputation( string PredecessorHloOrdering::ToStringHelper(const string& name) const { std::vector pieces; pieces.push_back(name); - for (auto& computation : module_->computations()) { + for (auto* computation : module_->computations()) { pieces.push_back(tensorflow::strings::Printf("computation %s:", computation->name().c_str())); const auto all = computation->MakeInstructionPostOrder(); @@ -261,7 +261,7 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const { pieces.push_back(tensorflow::strings::Printf( " %s predecessors:", instruction->name().c_str())); for (auto predecessor : all) { - if (predecessors_.at(computation.get()) + if (predecessors_.at(computation) ->IsReachable(predecessor, instruction)) { pieces.push_back( tensorflow::strings::Printf(" %s", predecessor->name().c_str())); @@ -277,12 +277,8 @@ DependencyHloOrdering::DependencyHloOrdering(const HloModule* module) // Compute predecessor relationships between all instructions to determine // ordering based on dependencies. ExecutesBefore will return true iff there // exists a path in the HLO computation graph from 'a' to 'b'. - for (auto& computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } - predecessors_.emplace(computation.get(), - computation->ComputeReachability()); + for (auto* computation : module->MakeNonfusionComputations()) { + predecessors_.emplace(computation, computation->ComputeReachability()); } } @@ -323,7 +319,7 @@ SequentialHloOrdering::SequentialOrder( string SequentialHloOrdering::ToString() const { std::vector pieces; pieces.push_back("SequentialHloOrdering"); - for (auto& computation : module_->computations()) { + for (auto* computation : module_->computations()) { pieces.push_back(tensorflow::strings::Printf("computation %s order:", computation->name().c_str())); // Gather all instructions in the module sequence for this computation and @@ -331,7 +327,7 @@ string SequentialHloOrdering::ToString() const { std::vector instructions; for (auto& instruction_position : order_position_) { const HloInstruction* instruction = instruction_position.first; - if (instruction->parent() == computation.get()) { + if (instruction->parent() == computation) { instructions.push_back(instruction); } } diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc index ed7b6c71bc6619b0cb93f226eb10de1023749109..53bd46a641afcba1b9551895955742e74a9f374b 100644 --- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc +++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc @@ -59,6 +59,7 @@ StatusOr HloPassPipeline::Run(HloModule* module) { for (auto& invariant_checker : invariant_checkers_) { VLOG(1) << " Invariant checker " << invariant_checker->name(); StatusOr changed_status = invariant_checker->Run(module); + VLOG(1) << " Invariant checker done " << invariant_checker->name(); if (!changed_status.ok()) { VLOG(2) << "Module failed invariant check:"; XLA_VLOG_LINES(2, module->ToString()); diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index 8b1e343bd925cedc835efa93f5d1fce14bd60703..c96df50e79a3c6d4ca5f8e7e0abec33cdfca1c70 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -761,9 +761,9 @@ bool MemoryUsageTracker::Check() const { }; // Verify buffers_defined per instruction. - for (auto& instruction : computation_->instructions()) { + for (auto* instruction : computation_->instructions()) { const BufferIdList& defined_buffers = - instruction_list_.GetItem(instruction.get())->buffers_defined; + instruction_list_.GetItem(instruction)->buffers_defined; CHECK(elements_are_unique(defined_buffers)) << "Instruction " << instruction->name() << " does not have unique defined buffers: " @@ -774,7 +774,7 @@ bool MemoryUsageTracker::Check() const { }); for (const Buffer& buffer : buffers_) { - if (buffer.defining_instruction->instruction == instruction.get()) { + if (buffer.defining_instruction->instruction == instruction) { CHECK(std::find(defined_buffers.begin(), defined_buffers.end(), buffer.id) != defined_buffers.end()) << "Instruction " << instruction->name() @@ -784,9 +784,9 @@ bool MemoryUsageTracker::Check() const { } // Verify buffers_used per instruction. - for (auto& instruction : computation_->instructions()) { + for (auto* instruction : computation_->instructions()) { const BufferIdList& used_buffers = - instruction_list_.GetItem(instruction.get())->buffers_used; + instruction_list_.GetItem(instruction)->buffers_used; CHECK(elements_are_unique(used_buffers)) << "Instruction " << instruction->name() << " does not have unique used buffers: " @@ -1151,8 +1151,8 @@ StatusOr HloRematerialization::RematerializeComputation( // Verify some invariants on the memory tracker. CHECK_EQ(memory_tracker.memory_usage(), 0); - for (auto& instruction : computation->instructions()) { - CHECK(memory_tracker.IsPlaced(instruction.get())); + for (auto* instruction : computation->instructions()) { + CHECK(memory_tracker.IsPlaced(instruction)); } VLOG(1) << "In computation " << computation->name() << " rematerialized " @@ -1256,23 +1256,18 @@ StatusOr HloRematerialization::Run( // After DCE, the module sequence may include instructions which no longer // exist. - for (const auto& computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } - if (sequence->at(computation.get()).size() != - computation->instruction_count()) { + for (const auto* computation : module->MakeNonfusionComputations()) { + if (sequence->at(computation).size() != computation->instruction_count()) { // A size mismatch between the computation instruction count and the size // of the ordering of instructions can only be caused by DCE. Rebuild the // order by removing the deleted instructions from the order. tensorflow::gtl::FlatSet instruction_set; for (const auto& instruction : computation->instructions()) { - instruction_set.insert(instruction.get()); + instruction_set.insert(instruction); } // Move the old order into a temporary vector, then build new order // inplace. - std::vector& order = - sequence->at(computation.get()); + std::vector& order = sequence->at(computation); std::vector old_order; using std::swap; swap(order, old_order); @@ -1281,7 +1276,7 @@ StatusOr HloRematerialization::Run( [&instruction_set](const HloInstruction* instruction) { return ContainsKey(instruction_set, instruction); }); - TF_RET_CHECK(sequence->at(computation.get()).size() == + TF_RET_CHECK(sequence->at(computation).size() == computation->instruction_count()); } } diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc index 7dc42ae7972e3457179ef6119b978c4eec546542..d88aa4bb567c6c5f6eab54f12239bf7040339c39 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization_test.cc @@ -385,7 +385,7 @@ TEST_F(HloRematerializationTest, InstructionRematerializedMultipleTimes) { auto count_broadcasts = [](const HloComputation* computation) { int64 bcast_count = 0; - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kBroadcast) { bcast_count++; } diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc new file mode 100644 index 0000000000000000000000000000000000000000..f463e57d995c0f0549872a1a0bf20a3ead626dc8 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_runner.cc @@ -0,0 +1,213 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_runner.h" + +#include +#include +#include + +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/compiler/xla/layout_util.h" +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/executable.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/common_runtime/eigen_thread_pool.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace se = ::perftools::gputools; + +namespace xla { + +/*static*/ StatusOr> +HloRunner::ReadModuleFromHloProtoFile(const std::string& filename, + const DebugOptions& debug_options) { + HloProto proto; + + const Status s = + tensorflow::ReadBinaryProto(tensorflow::Env::Default(), filename, &proto); + + if (!s.ok()) { + const Status s2 = + tensorflow::ReadTextProto(tensorflow::Env::Default(), filename, &proto); + if (!s2.ok()) { + return Status(s2.code(), s.error_message() + "\n" + s2.error_message()); + } + } + + TF_ASSIGN_OR_RETURN( + HloModuleConfig config, + HloModule::CreateModuleConfigFromProto(proto.hlo_module())); + config.set_debug_options(debug_options); + TF_ASSIGN_OR_RETURN(auto module, + HloModule::CreateFromProto(proto.hlo_module(), config)); + return std::move(module); +} + +/*static*/ StatusOr> +HloRunner::ReadModuleFromHloTextDumpFile(const std::string& filename, + const DebugOptions& debug_options) { + string hlo_string; + TF_RETURN_IF_ERROR(tensorflow::ReadFileToString(tensorflow::Env::Default(), + filename, &hlo_string)); + HloModuleConfig config; + config.set_debug_options(debug_options); + return tools::Parse(hlo_string, config); +} + +/*static*/ StatusOr> HloRunner::ReadModule( + const std::string& filename, const DebugOptions& debug_options) { + auto module = HloRunner::ReadModuleFromHloProtoFile(filename, debug_options); + if (module.ok()) { + return module; + } + const std::string e = module.status().error_message(); + module = HloRunner::ReadModuleFromHloTextDumpFile(filename, debug_options); + return module.ok() ? std::move(module) + : Status(module.status().code(), + e + "\n" + module.status().error_message()); +} + +// Define this in .cc file to avoid having to include eigen or forward declare +// these types in the header. +struct HloRunner::EigenThreadPoolWrapper { + std::unique_ptr pool; + std::unique_ptr device; +}; + +HloRunner::HloRunner() {} + +HloRunner::HloRunner(se::Platform* platform) { + BackendOptions backend_options; + backend_options.set_platform(platform); + backend_ = Backend::CreateBackend(backend_options).ConsumeValueOrDie(); + VLOG(1) << "Created HloRunner for platform: " << platform->Name(); +} + +HloRunner::~HloRunner() { + // Deallocate all the memory allocated during the tests. + for (auto& allocation : allocations_) { + backend().default_stream_executor()->Deallocate(&allocation); + } +} + +StatusOr HloRunner::Execute( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments, + Shape* result_shape) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr executable, + backend().compiler()->Compile(std::move(module), + backend().default_stream_executor())); + + se::Stream stream(backend().default_stream_executor()); + stream.Init(); + + ExecutableRunOptions run_options; + run_options.set_stream(&stream); + run_options.set_allocator(backend().memory_allocator()); + run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool()); + run_options.set_intra_op_thread_pool( + backend().eigen_intra_op_thread_pool_device()); + + HloExecutionProfile hlo_execution_profile; + ServiceExecutableRunOptions service_run_options( + run_options, backend().StreamBorrower(), + backend().inter_op_thread_pool()); + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase result, + executable->ExecuteOnStream(&service_run_options, arguments, + &hlo_execution_profile)); + TF_RET_CHECK(stream.BlockHostUntilDone()); + + allocations_.push_back(result); + + *result_shape = executable->result_shape(); + + if (ShapeUtil::IsTuple(*result_shape)) { + // We must record element buffers of tuples as well to avoid leaks. + DCHECK(!ShapeUtil::IsNestedTuple(*result_shape)); + TF_ASSIGN_OR_RETURN( + std::vector element_buffers, + backend().transfer_manager()->ShallowCopyTupleFromDevice( + backend().default_stream_executor(), result, *result_shape)); + + // A tuple may contain the same buffer in more than one element. Keep track + // of the buffers already added to avoid duplicates in allocations_. + std::set added_opaques; + for (auto element_buffer : element_buffers) { + if (added_opaques.count(element_buffer.opaque()) == 0) { + CHECK(element_buffer.opaque() != nullptr); + added_opaques.insert(element_buffer.opaque()); + allocations_.push_back(element_buffer); + } + } + } + + return result; +} + +StatusOr HloRunner::TransferToDevice( + const Literal& literal) { + // Allocate memory on the device using the stream executor. + int64 allocation_size = + backend().transfer_manager()->GetByteSizeRequirement(literal.shape()); + se::DeviceMemoryBase allocation = + backend().default_stream_executor()->AllocateArray( + allocation_size); + allocations_.push_back(allocation); + + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralToDevice( + backend().default_stream_executor(), literal, &allocation)); + + return allocation; +} + +StatusOr> HloRunner::TransferFromDevice( + const Shape& shape, se::DeviceMemoryBase device_base) { + auto literal = MakeUnique(); + TF_RETURN_IF_ERROR(backend().transfer_manager()->TransferLiteralFromDevice( + backend().default_stream_executor(), device_base, shape, shape, + literal.get())); + return std::move(literal); +} + +StatusOr> HloRunner::ExecuteAndTransfer( + std::unique_ptr module, + tensorflow::gtl::ArraySlice arguments) { + Shape result_shape; + TF_ASSIGN_OR_RETURN(se::DeviceMemoryBase device_base, + Execute(std::move(module), arguments, &result_shape)); + return TransferFromDevice(result_shape, device_base); +} + +Backend& HloRunner::backend() { + if (!backend_) { + backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie(); + VLOG(1) << "executing on platform " << backend().platform()->Name(); + } + return *backend_; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_runner.h b/tensorflow/compiler/xla/service/hlo_runner.h new file mode 100644 index 0000000000000000000000000000000000000000..a5732848c6b4191faf8d7b07c749132ca8b14413 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_runner.h @@ -0,0 +1,127 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/backend.h" +#include "tensorflow/compiler/xla/service/compiler.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/stream_executor_no_cuda.h" + +namespace xla { + +// A base class for running an HloModule. This executes the given HloModule on a +// certain backend directly without using the client interface. HloModule can be +// explicitly built, or loaded from a serialization file (e.g., hlo proto file). +class HloRunner { + public: + HloRunner(); + + HloRunner(::perftools::gputools::Platform* platform); + + ~HloRunner(); + + // Reads the proto file in xla.HloProto format, creates and returns the + // HloModule. Will try to parse the filename as binary proto, then try as + // text proto if that fails. + static StatusOr> ReadModuleFromHloProtoFile( + const std::string& filename, const DebugOptions& debug_options); + + // Reads the hlo text dump file in HloModule::ToString format, creates and + // returns the HloModule. + static StatusOr> ReadModuleFromHloTextDumpFile( + const std::string& filename, const DebugOptions& debug_options); + + // Tries to parse the filename specified first as binary proto format, then + // as a textual proto format, then textual IR, then gives up if both fail. + // ReadModuleFromHloProtoFile or ReadModuleFromHloTextDumpFile should be used + // explicitly when you know the format, this if you don't. + static StatusOr> ReadModule( + const std::string& filename, const DebugOptions& debug_options); + + // Executes the given module with given literals as input and returns the + // result as a Literal. The LiteralPtr type accepts Literal* or + // std::unique_ptr. + template + StatusOr> Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice literals); + + // Executes the given module and returns a global data handle. + StatusOr Execute( + std::unique_ptr module, + tensorflow::gtl::ArraySlice + arguments, + Shape* result_shape); + + // Transfers the given literal to the device and returns the data handle. + StatusOr TransferToDevice( + const Literal& literal); + + // Transfers the array referred to by the given handle from the device and + // returns as a Literal. + StatusOr> TransferFromDevice( + const Shape& shape, perftools::gputools::DeviceMemoryBase device_base); + + // Executes the given module and return the result as a Literal. + StatusOr> ExecuteAndTransfer( + std::unique_ptr module, + tensorflow::gtl::ArraySlice + arguments); + + // If backend is not created in the constructor, creates and returns the + // default backend. If creation fails, crashes the program. + // + // This creates the backend lazily so it's possible to instantiate an + // HloRunner in a program without any backends linked in. + Backend& backend(); + + private: + struct EigenThreadPoolWrapper; + + std::vector allocations_; + + std::unique_ptr thread_pool_wrapper_; + + std::unique_ptr backend_; +}; + +template +StatusOr> HloRunner::Execute( + std::unique_ptr module, + const tensorflow::gtl::ArraySlice literals) { + std::vector arguments; + for (const auto& literal : literals) { + TF_ASSIGN_OR_RETURN(perftools::gputools::DeviceMemoryBase argument, + TransferToDevice(*literal)); + arguments.push_back(argument); + } + return ExecuteAndTransfer(std::move(module), arguments); +} + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_RUNNER_H_ diff --git a/tensorflow/compiler/xla/service/hlo_scheduling.cc b/tensorflow/compiler/xla/service/hlo_scheduling.cc index 25be448c8d186514e5d5d04382f4733fee3af68b..8ccbcaeee4a9c9e94b344231953e20ac8f4b2053 100644 --- a/tensorflow/compiler/xla/service/hlo_scheduling.cc +++ b/tensorflow/compiler/xla/service/hlo_scheduling.cc @@ -97,7 +97,7 @@ class ListScheduler { // instruction. An HLO instruction "uses" a LogicalBuffer if the // LogicalBuffer is in an operand of the instruction as indicated by // points-to analysis. - for (auto& instruction : computation.instructions()) { + for (auto* instruction : computation.instructions()) { std::unordered_set instr_uses; for (auto* operand : instruction->operands()) { for (const LogicalBuffer* buffer : @@ -105,20 +105,20 @@ class ListScheduler { instr_uses.insert(buffer); } } - buffer_uses_[instruction.get()] = std::vector( + buffer_uses_[instruction] = std::vector( instr_uses.begin(), instr_uses.end()); } // Create map containing the number of unscheduled uses (hlo instructions) // of each logical buffer. - for (auto& instruction : computation.instructions()) { - for (auto* buffer : points_to_analysis.GetBuffersDefinedByInstruction( - instruction.get())) { + for (auto* instruction : computation.instructions()) { + for (auto* buffer : + points_to_analysis.GetBuffersDefinedByInstruction(instruction)) { unscheduled_use_count_[buffer] = 0; } } - for (auto& instruction : computation.instructions()) { - for (const LogicalBuffer* buffer : buffer_uses_.at(instruction.get())) { + for (auto* instruction : computation.instructions()) { + for (const LogicalBuffer* buffer : buffer_uses_.at(instruction)) { ++unscheduled_use_count_[buffer]; } } @@ -204,7 +204,7 @@ class ListScheduler { // Populate the ready list with instructions which have no operands or // control predecessors. std::unordered_map unscheduled_pred_count; - for (auto& instruction : computation_.instructions()) { + for (auto* instruction : computation_.instructions()) { // TODO(b/34466113): Replace this and above with successors() or // predecessors() when these methods are added to HloInstruction. for (const HloInstruction* user : instruction->users()) { @@ -216,11 +216,11 @@ class ListScheduler { } std::list ready_list; - for (auto& instruction : computation_.instructions()) { + for (auto* instruction : computation_.instructions()) { // Instruction with no operands or control predecessors will // not be in the map. - if (unscheduled_pred_count.count(instruction.get()) == 0) { - ready_list.push_back(MakeReadyListEntry(instruction.get())); + if (unscheduled_pred_count.count(instruction) == 0) { + ready_list.push_back(MakeReadyListEntry(instruction)); } } @@ -267,9 +267,8 @@ class ListScheduler { update_pred_count(succ); } } - CHECK_EQ(schedule.size(), computation_.instructions().size()); - CHECK_EQ(scheduled_instructions_.size(), - computation_.instructions().size()); + CHECK_EQ(schedule.size(), computation_.instruction_count()); + CHECK_EQ(scheduled_instructions_.size(), computation_.instruction_count()); return schedule; } @@ -327,8 +326,8 @@ StatusOr> RunDFSMemoryScheduler( total_sizes[hlo] += total_sizes[operand]; } } - CHECK_EQ(extra_users.size(), computation.instructions().size()); - CHECK_EQ(total_sizes.size(), computation.instructions().size()); + CHECK_EQ(extra_users.size(), computation.instruction_count()); + CHECK_EQ(total_sizes.size(), computation.instruction_count()); // Construct a total order based on DFS post-order, visiting operands in // decreasing cumulative extra user order, and next by cumulative size, with a @@ -349,7 +348,7 @@ StatusOr> RunDFSMemoryScheduler( } return a->name() < b->name(); })); - CHECK_EQ(sequence.size(), computation.instructions().size()); + CHECK_EQ(sequence.size(), computation.instruction_count()); return sequence; } @@ -411,11 +410,8 @@ CreateMemoryMinimizingSequence( SequentialHloOrdering::HloModuleSequence sequence; TF_ASSIGN_OR_RETURN(std::unique_ptr points_to_analysis, TuplePointsToAnalysis::Run(&module)); - for (const auto& computation : module.computations()) { - if (computation->IsFusionComputation()) { - continue; - } - TF_ASSIGN_OR_RETURN(sequence[computation.get()], + for (const auto* computation : module.MakeNonfusionComputations()) { + TF_ASSIGN_OR_RETURN(sequence[computation], CreateMemoryMinimizingSequence( *computation, *points_to_analysis, size_function)); } diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc new file mode 100644 index 0000000000000000000000000000000000000000..0d019d22f5d4cd401c0fc5572f99636dec4f7383 --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding.cc @@ -0,0 +1,232 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_sharding.h" + +#include "tensorflow/core/lib/core/errors.h" + +namespace xla { + +using ::tensorflow::strings::StrCat; + +HloSharding HloSharding::AssignDevice(int64 device_id) { + return HloSharding(device_id); +} + +HloSharding HloSharding::Tile1D(const Shape& input_shape, int64 num_tiles) { + CHECK_EQ(1, ShapeUtil::Rank(input_shape)); + CHECK_GT(num_tiles, 1); + std::vector dimensions(1, num_tiles); + Shape tile_shape = input_shape; + auto& tile_dimension = (*tile_shape.mutable_dimensions())[0]; + tile_dimension = CeilOfRatio(static_cast(tile_dimension), num_tiles); + Array assignment(dimensions); + std::iota(assignment.begin(), assignment.end(), 0); + return HloSharding(tile_shape, assignment); +} + +string HloSharding::ToString() const { + string result = StrCat("{", (replicated_ ? " replicated" : ""), + (maximal_ ? " maximal" : "")); + + if (replicated_) { + return "{replicated}"; + } else if (maximal_) { + return StrCat( + "{maximal device=", static_cast(*tile_assignment_.begin()), "}"); + } else { + return StrCat("{", ShapeUtil::HumanString(tile_shape_), " ", + "devices=", VectorString(tile_assignment_), "}"); + } +} + +bool HloSharding::UsesDevice(int64 device) const { + const auto& devices = tile_assignment_; + return replicated_ || + std::find(devices.begin(), devices.end(), device) != devices.end(); +} + +std::vector HloSharding::TileIndexForDevice(int64 device) const { + CHECK(!ShapeUtil::IsTuple(tile_shape_)); + CHECK(!maximal_); + std::vector ret_index; + tile_assignment_.Each([&](tensorflow::gtl::ArraySlice index, int64 d) { + if (d == device) { + ret_index = {index.begin(), index.end()}; + } + }); + CHECK(!ret_index.empty()); + return ret_index; +} + +int64 HloSharding::DeviceForTileIndex( + tensorflow::gtl::ArraySlice index) const { + CHECK(!replicated_); + if (maximal_) { + return *tile_assignment_.begin(); + } + CHECK_EQ(ShapeUtil::Rank(tile_shape_), tile_assignment_.dimensions().size()); + return tile_assignment_(index); +} + +std::vector HloSharding::TileOffsetForDevice(int64 device) const { + CHECK(!ShapeUtil::IsTuple(tile_shape_)); + + std::vector index = TileIndexForDevice(device); + if (maximal_) { + // Index will always be all zeroes if we're maximal, and tile_shape_ is not + // valid. + return index; + } + for (int64 i = 0; i < index.size(); ++i) { + index[i] *= tile_shape_.dimensions(i); + } + return index; +} + +std::vector HloSharding::TileLimitForDevice(int64 device) const { + CHECK(!ShapeUtil::IsTuple(tile_shape_)); + CHECK(!maximal_); // Maximal shardings do not have a valid tile shape. + + std::vector index = TileIndexForDevice(device); + for (int64 i = 0; i < index.size(); ++i) { + index[i] = (index[i] + 1) * tile_shape_.dimensions(i); + } + return index; +} + +StatusOr HloSharding::UniqueDevice() const { + if (!replicated_ && maximal_) { + return static_cast(*tile_assignment_.begin()); + } + return tensorflow::errors::InvalidArgument( + "UniqueDevice() called on sharding that executes on multiple devices"); +} + +Status HloSharding::Validate(const Shape& shape, int64 num_devices) const { + if (replicated_) { + return Status::OK(); + } + + // All tile assignments must be less than the number of available cores and + // unique. + Status status = Status::OK(); + std::set seen_cores; + tile_assignment_.Each( + [&](tensorflow::gtl::ArraySlice indices, uint32 core) { + // Don't overwrite a bad status, so we report the first error. + if (status.ok()) { + if (core >= num_devices) { + status = + tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat( + "core ", core, " > ", num_devices, " in tile assignment")); + } else if (seen_cores.count(core) != 0) { + status = + tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat( + "core ", core, " is not unique in tile assignment")); + } + } + seen_cores.insert(core); + }); + if (!status.ok()) { + return status; + } + + if (IsTileMaximal()) { + return Status::OK(); + } + + // The tile rank must be the same as the input rank. + if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) { + return tensorflow::errors::InvalidArgument( + "Tile rank is different to the input rank"); + } + + // The tile shape must not be the same as the input shape without maximal_ + // also set. If this is the case, we're not actually sharded and the correct + // constructor should have been used. + if (ShapeUtil::Equal(shape, tile_shape_)) { + return tensorflow::errors::InvalidArgument( + "Tile shape is the same as the input shape. If a replicated sharding " + "was intended, use HloSharding::Replicated(). If a device placement " + "was intended, use HloSharding::AssignDevice()"); + } + + // The tile shape must not be greater than the input shape in any dimension. + for (int64 i = 0, e = ShapeUtil::Rank(shape); i != e; ++i) { + auto tile_dim = tile_shape_.dimensions(i); + auto shape_dim = shape.dimensions(i); + if (tile_dim > shape_dim) { + return tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat( + "Tile is larger than input shape (dimension ", i, ", ", tile_dim, + " > ", shape_dim)); + } + } + + // The tile assignment tensor must be exactly dimensioned to ceil(shape[dim] + // tile[dim]) for every dimension contained within tile. + for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) { + int64 expected_dim = + CeilOfRatio(shape.dimensions(i), tile_shape_.dimensions(i)); + if (tile_assignment_.dimensions()[i] != expected_dim) { + return tensorflow::errors::InvalidArgument(tensorflow::strings::StrCat( + "Tile assignment tensor has incorrect shape. Dimension ", i, + " expected ", expected_dim, " but got ", + tile_assignment_.dimensions()[i])); + } + } + + return Status::OK(); +} + +/*static*/ StatusOr HloSharding::FromProto( + const OpSharding& proto) { + if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) { + return Replicate(); + } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL) { + return HloSharding(proto.tile_assignment_devices(0)); + } + // Some versions of gcc cannot infer the TileAssignment constructor from a + // braced initializer-list, so create one manually. + std::vector devices(proto.tile_assignment_devices().begin(), + proto.tile_assignment_devices().end()); + Array tile_assignment( + std::vector(proto.tile_assignment_dimensions().begin(), + proto.tile_assignment_dimensions().end())); + std::copy(proto.tile_assignment_devices().begin(), + proto.tile_assignment_devices().end(), tile_assignment.begin()); + return HloSharding(proto.tile_shape(), tile_assignment); +} + +OpSharding HloSharding::ToProto() const { + OpSharding result; + *result.mutable_tile_shape() = tile_shape_; + for (int64 dim : tile_assignment_.dimensions()) { + result.add_tile_assignment_dimensions(dim); + } + for (auto device : tile_assignment_) { + result.add_tile_assignment_devices(device); + } + if (IsReplicated()) { + result.set_type(OpSharding::Type::OpSharding_Type_REPLICATED); + } else if (IsTileMaximal()) { + result.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); + } else { + result.set_type(OpSharding::Type::OpSharding_Type_OTHER); + } + return result; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_sharding.h b/tensorflow/compiler/xla/service/hlo_sharding.h new file mode 100644 index 0000000000000000000000000000000000000000..d7ada30c70bc3b41b3117375380eac2e883d9a9d --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding.h @@ -0,0 +1,165 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// HLO shardings describe how an HLO instruction is split across multiple +// computations. + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ + +#include + +#include "tensorflow/compiler/xla/array.h" +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/protobuf_util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/lib/hash/hash.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +// HLO shardings describe how an HLO instruction is split across multiple +// computations. +class HloSharding { + public: + // Creates a trivial sharding that replicates a maximal tile across all + // devices. + static HloSharding Replicate() { return HloSharding(); } + + // Creates a sharding that emulates device placement; a tile shape equal to + // the input shape (one tile) assigned to a single device. + static HloSharding AssignDevice(int64 device_id); + + // Creates a new sharding which splits a shape into tiles each with shape + // `tile_shape`. Each tile is assigned to one device, which is specified by + // `tile_assignment`. Any tensor not a multiple of the tile size in any + // dimension is implicitly padded to the tile size. + // + // e.g. Tile({2, 2}, {0, 1}) on a tensor of shape {3, 2} would look like: + // 2 1 padding + // <------><-> + // +----+----+ + // | 0 | 1 | + // +----+----+ + // + // Split into two tiles, one of which is implicitly padded by one. + static HloSharding Tile(const Shape& tile_shape, + const Array& tile_assignment) { + return HloSharding(tile_shape, tile_assignment); + } + + // Creates a new sharding which splits a one-dimensional input shape into + // `num_tiles` tiles. + static HloSharding Tile1D(const Shape& input_shape, int64 num_tiles); + + // Create a new sharding from a protobuf OpSharding. + static StatusOr FromProto(const OpSharding& proto); + + OpSharding ToProto() const; + string ToString() const; + + // Validate that this sharding can be applied to a tensor with shape `shape`. + Status Validate(const Shape& shape, int64 num_devices) const; + + // Returns true if the sharding is trivial: replicate on all devices. + bool IsReplicated() const { return replicated_; } + + // Returns true if the tile size is the same as the input size. + bool IsTileMaximal() const { return maximal_; } + + // Returns true if the sharding defines an operation on the given device. + bool UsesDevice(int64 device) const; + + // Returns the tile that should be executed on the given device. + std::vector TileIndexForDevice(int64 device) const; + + // Returns the device that should execute the given tile. + // It is an error to call this if is_replicated() is true. + int64 DeviceForTileIndex(tensorflow::gtl::ArraySlice index) const; + + // Given a device ID, returns the offset within the input space of the + // tile that should be executed on the given core. This returns the lower + // extent of the tile in the input space. + std::vector TileOffsetForDevice(int64 device) const; + + // Given a device ID, returns the limit within the input space of the + // tile that should be executed on the given core. This returns the upper + // extent of the tile in the input space. + std::vector TileLimitForDevice(int64 device) const; + + // Returns the single device this op operates on. + // Requires !Replicated() && IsTileMaximal(). + StatusOr UniqueDevice() const; + + // Returns true if this op only uses a single device. + bool HasUniqueDevice() const { return !IsReplicated() && IsTileMaximal(); } + + bool operator==(const HloSharding& other) const { + return replicated_ == other.replicated_ && maximal_ == other.maximal_ && + protobuf_util::ProtobufEquals(tile_shape_, other.tile_shape_) && + tile_assignment_ == other.tile_assignment_; + } + bool operator!=(const HloSharding& other) const { return !(*this == other); } + + size_t Hash() const { + if (replicated_) { + return 0; + } + size_t h = 0; + for (uint32 v : tile_assignment_) { + h = tensorflow::Hash64Combine(h, std::hash{}(v)); + } + for (uint32 v : tile_shape_.dimensions()) { + h = tensorflow::Hash64Combine(h, std::hash{}(v)); + } + return h; + } + + // Gets the tile shape. + // It is an error to call this if IsTileMaximal() is true. + const Shape& tile_shape() const { return tile_shape_; } + // Gets the tile assignment tensor. + // It is an error to call this if IsReplicated() is true. + const Array& tile_assignment() const { return tile_assignment_; } + + private: + HloSharding() + : replicated_(true), + maximal_(true), + tile_shape_(), + tile_assignment_({0}) {} + explicit HloSharding(int64 device_id) + : replicated_(false), + maximal_(true), + tile_shape_(), + tile_assignment_({1}, device_id) {} + HloSharding(const Shape& tile_shape, const Array& tile_assignment) + : replicated_(false), + maximal_(false), + tile_shape_(tile_shape), + tile_assignment_(tile_assignment) {} + + bool replicated_; + bool maximal_; + Shape tile_shape_; + Array tile_assignment_; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_SHARDING_H_ diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d0a20471a0f22a5fa414b71bb5160eed7cdc431b --- /dev/null +++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc @@ -0,0 +1,190 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/hlo_sharding.h" + +#include +#include +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/test_helpers.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/util.h" + +namespace xla { +namespace { + +Array MakeArray(tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice contents) { + Array a(dimensions); + std::copy(contents.begin(), contents.end(), a.begin()); + return a; +} + +class HloShardingTest : public HloTestBase {}; + +TEST_F(HloShardingTest, Replicate) { + Shape tile_shape = ShapeUtil::MakeShape(U32, {4}); + HloSharding sharding = HloSharding::Replicate(); + EXPECT_TRUE(sharding.IsReplicated()); + EXPECT_TRUE(sharding.IsTileMaximal()); + EXPECT_TRUE(sharding.UsesDevice(0)); + EXPECT_TRUE(sharding.UsesDevice(65535)); + + HloSharding other = HloSharding::Replicate(); + EXPECT_EQ(other, sharding); + + EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}), + /*num_devices=*/2)); + EXPECT_IS_NOT_OK(sharding.UniqueDevice()); +} + +TEST_F(HloShardingTest, DevicePlacement) { + HloSharding sharding = HloSharding::AssignDevice(5); + EXPECT_FALSE(sharding.IsReplicated()); + EXPECT_TRUE(sharding.IsTileMaximal()); + EXPECT_FALSE(sharding.UsesDevice(0)); + EXPECT_TRUE(sharding.UsesDevice(5)); + EXPECT_EQ(5, sharding.UniqueDevice().ValueOrDie()); + + HloSharding other = HloSharding::Replicate(); + EXPECT_NE(other, sharding); + + EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4}), + /*num_devices=*/6)); + EXPECT_IS_NOT_OK( + sharding.Validate(ShapeUtil::MakeShape(U32, {4}), /*num_devices=*/5)); +} + +TEST_F(HloShardingTest, Tile) { + { + // Test should fail because of a duplicate tile assignment. + Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); + HloSharding sharding = + HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 0, 2, 3})); + EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {4, 6}), + /*num_devices=*/4)); + } + + { + // Test should pass. + Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); + HloSharding sharding = + HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); + EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(U32, {4, 6}), + /*num_devices=*/2)); + } + + { + // Test should fail due to the tile being larger than the input space. + Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); + HloSharding sharding = + HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); + EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {2, 2}), + /*num_devices=*/4)); + } + + { + // Test should fail due to the tile not dividing the input space into 4 + // sections (even with padding). + Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); + HloSharding sharding = + HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3})); + EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {6, 3}), + /*num_devices=*/4)); + } + + { + // Test should pass. + Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); + HloSharding sharding = + HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); + EXPECT_IS_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {3, 5}), + /*num_devices=*/5)); + + EXPECT_EQ(0, sharding.DeviceForTileIndex({0, 0})); + EXPECT_EQ(3, sharding.DeviceForTileIndex({0, 1})); + EXPECT_EQ(2, sharding.DeviceForTileIndex({1, 0})); + EXPECT_EQ(1, sharding.DeviceForTileIndex({1, 1})); + + EXPECT_EQ(sharding.TileOffsetForDevice(0), (std::vector{0, 0})); + EXPECT_EQ(sharding.TileOffsetForDevice(3), (std::vector{0, 3})); + EXPECT_EQ(sharding.TileOffsetForDevice(2), (std::vector{2, 0})); + EXPECT_EQ(sharding.TileOffsetForDevice(1), (std::vector{2, 3})); + + EXPECT_IS_NOT_OK(sharding.UniqueDevice()); + } +} + +TEST_F(HloShardingTest, Hash) { + auto hash_compare_equal = [](const HloSharding& a, const HloSharding& b) { + if (a.Hash() != b.Hash()) { + return false; + } + return a == b; + }; + + { + HloSharding sharding1 = HloSharding::Replicate(); + HloSharding sharding2 = HloSharding::Replicate(); + EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); + } + + { + HloSharding sharding1 = HloSharding::AssignDevice(1); + HloSharding sharding2 = HloSharding::AssignDevice(1); + EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); + } + + { + HloSharding sharding1 = HloSharding::AssignDevice(1); + HloSharding sharding2 = HloSharding::AssignDevice(2); + EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); + } + + { + Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); + HloSharding sharding1 = + HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); + HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), + MakeArray({2, 2}, {0, 3, 2, 1})); + EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); + } + + { + Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); + HloSharding sharding1 = + HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); + HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), + MakeArray({2, 2}, {0, 3, 2, 1})); + EXPECT_TRUE(hash_compare_equal(sharding1, sharding2)); + } + + { + Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3}); + HloSharding sharding1 = + HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 3, 2, 1})); + HloSharding sharding2 = HloSharding::Tile(ShapeUtil::MakeShape(U32, {2, 3}), + MakeArray({2, 2}, {0, 3, 1, 2})); + EXPECT_FALSE(hash_compare_equal(sharding1, sharding2)); + } +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.cc index 460dc5cf640edaa647a0627fe4bf9359845ac41d..8b332f23ae98d480d272190ca01cc5270033b0b2 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification.cc +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification.cc @@ -25,10 +25,10 @@ StatusOr HloSubcomputationUnification::Run(HloModule* module) { std::unordered_map canon; const auto& computations = module->computations(); for (auto i = computations.begin(); i != computations.end(); ++i) { - for (auto j = computations.begin(); j < i; ++j) { + for (auto j = computations.begin(); j != i; ++j) { // Do not waste time comparing `*i` with `*j` if `*j` is not canonical. - if (canon.find(j->get()) == canon.end() && **i == **j) { - canon[i->get()] = j->get(); + if (canon.find(*j) == canon.end() && **i == **j) { + canon[*i] = *j; break; } } diff --git a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc index 33b3634cfc12f4d117e2db1b8a6640b1c979f538..7b601f9a9578cfa6b293cf7f002255f7db8b1257 100644 --- a/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc +++ b/tensorflow/compiler/xla/service/hlo_subcomputation_unification_test.cc @@ -85,7 +85,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { module->AddEntryComputation(builder.Build()); - EXPECT_EQ(3, module->computations().size()); + EXPECT_EQ(3, module->computation_count()); EXPECT_NE(x->to_apply(), y->to_apply()); if (VLOG_IS_ON(1)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), @@ -98,7 +98,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyIdentities) { "after unification", module->config().debug_options()); } - EXPECT_EQ(2, module->computations().size()); + EXPECT_EQ(2, module->computation_count()); EXPECT_EQ(x->to_apply(), y->to_apply()); } @@ -124,7 +124,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { module->AddEntryComputation(builder.Build()); - EXPECT_EQ(3, module->computations().size()); + EXPECT_EQ(3, module->computation_count()); EXPECT_NE(x->to_apply(), y->to_apply()); if (VLOG_IS_ON(1)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), @@ -137,7 +137,7 @@ TEST_F(HloSubcomputationUnificationTest, UnifyAdditions) { "after unification", module->config().debug_options()); } - EXPECT_EQ(2, module->computations().size()); + EXPECT_EQ(2, module->computation_count()); EXPECT_EQ(x->to_apply(), y->to_apply()); } @@ -164,7 +164,7 @@ TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { module->AddEntryComputation(builder.Build()); - EXPECT_EQ(3, module->computations().size()); + EXPECT_EQ(3, module->computation_count()); EXPECT_NE(x->to_apply(), y->to_apply()); if (VLOG_IS_ON(1)) { hlo_graph_dumper::DumpGraph(*module->entry_computation(), @@ -177,7 +177,7 @@ TEST_F(HloSubcomputationUnificationTest, DifferentParameterShapes) { "after unification", module->config().debug_options()); } - EXPECT_EQ(3, module->computations().size()); + EXPECT_EQ(3, module->computation_count()); EXPECT_NE(x->to_apply(), y->to_apply()); } @@ -201,8 +201,8 @@ TEST_F(HloSubcomputationUnificationTest, TwoIdenticalComputations) { } EXPECT_TRUE(HloSubcomputationUnification().Run(module.get()).ValueOrDie()); - EXPECT_EQ(1, module->computations().size()); - EXPECT_EQ(module->computations().front().get(), module->entry_computation()); + EXPECT_EQ(1, module->computation_count()); + EXPECT_EQ(*module->computations().begin(), module->entry_computation()); } } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc index 5a4c93b59a6810b962e3c8f54b2964dffa8ecd6d..06abe007477dbcd00bcdc7f2656c4dece6d1cf74 100644 --- a/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc +++ b/tensorflow/compiler/xla/service/hlo_tfgraph_builder.cc @@ -56,6 +56,8 @@ TensorShapeProto GetTensorShape(const HloInstruction* instruction) { return tensor_shape; } +string GetDeviceName(int device) { return StrCat("/device/XLA:", device); } + } // namespace void CleanNodeName(string* name) { @@ -71,12 +73,12 @@ void CleanNodeName(string* name) { Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) { VLOG(2) << "Adding computation " << computation.name(); for (auto embedded : computation.MakeEmbeddedComputationsList()) { - for (auto& instruction : embedded->instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(instruction.get())); + for (auto* instruction : embedded->instructions()) { + TF_RETURN_IF_ERROR(AddInstruction(instruction)); } } - for (auto& instruction : computation.instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(instruction.get())); + for (auto* instruction : computation.instructions()) { + TF_RETURN_IF_ERROR(AddInstruction(instruction)); } return Status::OK(); } @@ -178,6 +180,10 @@ void HloTfGraphBuilder::SetNodeAttrs(const HloInstruction* instruction, case HloOpcode::kCustomCall: attrs["custom_call_target"].set_s(instruction->custom_call_target()); break; + case HloOpcode::kSend: + case HloOpcode::kRecv: + attrs["channel_id"].set_i(instruction->channel_id()); + break; default: break; } @@ -192,10 +198,15 @@ Status HloTfGraphBuilder::AddInstruction(const HloInstruction* instruction) { NodeDef* node_def = graph_def_.add_node(); node_def->set_name(GetNodeNameForInstruction(instruction)); node_def->set_op(GetOpDefName(instruction)); + if (instruction->has_sharding() && + instruction->sharding().HasUniqueDevice()) { + TF_ASSIGN_OR_RETURN(int64 device, instruction->sharding().UniqueDevice()); + node_def->set_device(GetDeviceName(device)); + } SetNodeAttrs(instruction, node_def); if (instruction->opcode() == HloOpcode::kFusion) { - for (auto& fused_instruction : instruction->fused_instructions()) { - TF_RETURN_IF_ERROR(AddInstruction(fused_instruction.get())); + for (auto* fused_instruction : instruction->fused_instructions()) { + TF_RETURN_IF_ERROR(AddInstruction(fused_instruction)); } } // Add all edges including control edges. diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 2405d4477832d2717e3b4a15ece94596370858b5..c1aa655401a2be68af943e2ed29c4ab99d341383 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -40,22 +40,17 @@ class ShapeVerifier : public DfsHloVisitor { return CheckBinaryShape(hlo); } - Status HandleClamp(HloInstruction* clamp, HloInstruction* min, - HloInstruction* arg, HloInstruction* max) override { + Status HandleClamp(HloInstruction* clamp) override { return CheckTernaryShape(clamp); } - Status HandleSelect(HloInstruction* select, HloInstruction* pred, - HloInstruction* on_true, - HloInstruction* on_false) override { + Status HandleSelect(HloInstruction* select) override { return CheckTernaryShape(select); } - Status HandleConcatenate( - HloInstruction* concatenate, - tensorflow::gtl::ArraySlice operands) override { + Status HandleConcatenate(HloInstruction* concatenate) override { std::vector operand_shapes; - for (const HloInstruction* operand : operands) { + for (const HloInstruction* operand : concatenate->operands()) { operand_shapes.push_back(&operand->shape()); } return CheckShape( @@ -64,6 +59,10 @@ class ShapeVerifier : public DfsHloVisitor { } Status HandleConvert(HloInstruction* convert) override { + if (ShapeUtil::ElementIsComplex(convert->operand(0)->shape())) { + TF_RET_CHECK(ShapeUtil::ElementIsComplex(convert->shape())) + << "Unsupported complex->real kConvert"; + } return CheckShape(convert, ShapeInference::InferConvertShape( convert->operand(0)->shape(), convert->shape().element_type())); @@ -73,17 +72,17 @@ class ShapeVerifier : public DfsHloVisitor { return CheckUnaryShape(copy); } - Status HandleDot(HloInstruction* dot, HloInstruction* lhs, - HloInstruction* rhs) override { + Status HandleDot(HloInstruction* dot) override { return CheckBinaryShape(dot); } - Status HandleConvolution(HloInstruction* convolution, HloInstruction* lhs, - HloInstruction* rhs, const Window& window) override { - TF_ASSIGN_OR_RETURN(const Shape expected, - ShapeInference::InferConvolveShape( - lhs->shape(), rhs->shape(), window, - convolution->convolution_dimension_numbers())); + Status HandleConvolution(HloInstruction* convolution) override { + TF_ASSIGN_OR_RETURN( + const Shape expected, + ShapeInference::InferConvolveShape( + convolution->operand(0)->shape(), convolution->operand(1)->shape(), + convolution->window(), + convolution->convolution_dimension_numbers())); return CheckShape(convolution, expected); } @@ -100,47 +99,40 @@ class ShapeVerifier : public DfsHloVisitor { reduce_precision->mantissa_bits())); } - Status HandleInfeed(HloInstruction* infeed) override { + Status HandleInfeed(HloInstruction*) override { return tensorflow::Status::OK(); } - Status HandleOutfeed(HloInstruction* outfeed) override { + Status HandleOutfeed(HloInstruction*) override { return tensorflow::Status::OK(); } - Status HandleRng(HloInstruction* random, - RandomDistribution distribution) override { + Status HandleRng(HloInstruction*) override { return tensorflow::Status::OK(); } - Status HandleReverse(HloInstruction* reverse, - HloInstruction* operand) override { + Status HandleReverse(HloInstruction* reverse) override { return CheckShape( reverse, ShapeInference::InferReverseShape(reverse->operand(0)->shape(), reverse->dimensions())); } - Status HandleSort(HloInstruction* sort, HloInstruction* operand) override { + Status HandleSort(HloInstruction* sort) override { return CheckUnaryShape(sort); } - Status HandleConstant(HloInstruction* constant, - const Literal& literal) override { - return CheckShape(constant, literal.shape()); + Status HandleConstant(HloInstruction* constant) override { + return CheckShape(constant, constant->literal().shape()); } - Status HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* operand) override { + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override { return CheckShape(get_tuple_element, ShapeInference::InferGetTupleElementShape( get_tuple_element->operand(0)->shape(), get_tuple_element->tuple_index())); } - Status HandleReduce(HloInstruction* reduce, HloInstruction* arg, - HloInstruction* init_value, - tensorflow::gtl::ArraySlice dimensions, - HloComputation* function) override { + Status HandleReduce(HloInstruction* reduce) override { return CheckShape( reduce, ShapeInference::InferReduceShape( @@ -183,11 +175,11 @@ class ShapeVerifier : public DfsHloVisitor { transpose->dimensions())); } - Status HandleParameter(HloInstruction* parameter) override { + Status HandleParameter(HloInstruction*) override { return tensorflow::Status::OK(); } - Status HandleFusion(HloInstruction* fusion) override { + Status HandleFusion(HloInstruction*) override { return tensorflow::Status::OK(); } @@ -196,32 +188,26 @@ class ShapeVerifier : public DfsHloVisitor { return CheckShape(call, call->to_apply()->ComputeProgramShape().result()); } - Status HandleCustomCall(HloInstruction* custom_call, - tensorflow::gtl::ArraySlice operands, - tensorflow::StringPiece custom_call_target) override { + Status HandleCustomCall(HloInstruction*) override { return tensorflow::Status::OK(); } - Status HandleSlice(HloInstruction* slice, HloInstruction* operand) override { + Status HandleSlice(HloInstruction* slice) override { return CheckShape(slice, ShapeInference::InferSliceShape( slice->operand(0)->shape(), slice->slice_starts(), slice->slice_limits(), slice->slice_strides())); } - Status HandleDynamicSlice(HloInstruction* dynamic_slice, - HloInstruction* operand, - HloInstruction* start_indices) override { + Status HandleDynamicSlice(HloInstruction* dynamic_slice) override { return CheckShape(dynamic_slice, ShapeInference::InferDynamicSliceShape( dynamic_slice->operand(0)->shape(), dynamic_slice->operand(1)->shape(), dynamic_slice->dynamic_slice_sizes())); } - Status HandleDynamicUpdateSlice(HloInstruction* dynamic_update_slice, - HloInstruction* operand, - HloInstruction* update, - HloInstruction* start_indices) override { + Status HandleDynamicUpdateSlice( + HloInstruction* dynamic_update_slice) override { return CheckShape(dynamic_update_slice, ShapeInference::InferDynamicUpdateSliceShape( dynamic_update_slice->operand(0)->shape(), @@ -229,29 +215,29 @@ class ShapeVerifier : public DfsHloVisitor { dynamic_update_slice->operand(2)->shape())); } - Status HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) override { + Status HandleTuple(HloInstruction* tuple) override { return CheckVariadicShape(tuple); } - Status HandleMap( - HloInstruction* map, - tensorflow::gtl::ArraySlice operands, - HloComputation* function, - tensorflow::gtl::ArraySlice static_operands) override { + Status HandleMap(HloInstruction* map) override { std::vector operand_shapes; - for (const HloInstruction* operand : operands) { + int64 max_operand_rank = 0; + for (const HloInstruction* operand : map->operands()) { operand_shapes.push_back(&operand->shape()); + max_operand_rank = + std::max(max_operand_rank, ShapeUtil::Rank(operand->shape())); } + // TODO(b/65689298) Remove code below once Map is generalized to accept + // arbitrary map dimensions. + std::vector map_dims(max_operand_rank); + std::iota(map_dims.begin(), map_dims.end(), 0); return CheckShape( - map, ShapeInference::InferMapShape( - operand_shapes, map->to_apply()->ComputeProgramShape())); + map, + ShapeInference::InferMapShape( + operand_shapes, map->to_apply()->ComputeProgramShape(), map_dims)); } - Status HandleReduceWindow(HloInstruction* reduce_window, - HloInstruction* operand, const Window& window, - HloComputation* function) override { + Status HandleReduceWindow(HloInstruction* reduce_window) override { return CheckShape( reduce_window, ShapeInference::InferReduceWindowShape( @@ -284,11 +270,11 @@ class ShapeVerifier : public DfsHloVisitor { pad->padding_config())); } - Status HandleSend(HloInstruction* send) override { + Status HandleSend(HloInstruction*) override { return tensorflow::Status::OK(); } - Status HandleRecv(HloInstruction* recv) override { + Status HandleRecv(HloInstruction*) override { return tensorflow::Status::OK(); } @@ -323,7 +309,7 @@ class ShapeVerifier : public DfsHloVisitor { batch_norm_grad->feature_index())); } - Status FinishVisit(HloInstruction* root) override { + Status FinishVisit(HloInstruction*) override { return tensorflow::Status::OK(); } @@ -347,7 +333,10 @@ class ShapeVerifier : public DfsHloVisitor { Status CheckShape(const HloInstruction* instruction, const StatusOr& expected_shape_status) { if (!expected_shape_status.ok()) { - return expected_shape_status.status(); + Status s = expected_shape_status.status(); + tensorflow::errors::AppendToMessage(&s, ", for instruction ", + instruction->ToString()); + return s; } return CheckShape(instruction, expected_shape_status.ValueOrDie()); } @@ -407,8 +396,8 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { fusion->fused_parameters(); const HloInstruction* fused_root = fusion->fused_expression_root(); std::vector parameter_owned(fused_parameters.size(), false); - for (auto& instruction : fused_computation->instructions()) { - if (fused_root == instruction.get()) { + for (auto* instruction : fused_computation->instructions()) { + if (fused_root == instruction) { if (root_owned) { return FailedPrecondition("Root appears more than once in %s.", fusion->ToString().c_str()); @@ -416,7 +405,7 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { root_owned = true; } for (int i = 0; i < fused_parameters.size(); ++i) { - if (fused_parameters[i] == instruction.get()) { + if (fused_parameters[i] == instruction) { if (parameter_owned[i]) { return FailedPrecondition("Parameter appears more than once in %s.", fusion->ToString().c_str()); @@ -445,9 +434,9 @@ Status HloVerifier::CheckFusionInstruction(HloInstruction* fusion) const { // All uses of fused instructions must be in the fusion computation, and every // non-root instruction must have at least one use. - for (auto& instruction : + for (auto* instruction : fusion->fused_instructions_computation()->instructions()) { - if (instruction.get() != fused_root) { + if (instruction != fused_root) { if (instruction->user_count() == 0) { return FailedPrecondition( "Non-root instruction %s in %s must have users.", @@ -511,11 +500,11 @@ StatusOr HloVerifier::Run(HloModule* module) { tensorflow::gtl::FlatMap instructions; ShapeVerifier shape_verifier(shape_size_fn_); - for (auto& computation : module->computations()) { + for (auto* computation : module->computations()) { for (const auto& instruction : computation->instructions()) { - TF_RET_CHECK(instruction->parent() == computation.get()); + TF_RET_CHECK(instruction->parent() == computation); if (instruction->opcode() == HloOpcode::kFusion) { - TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction.get())); + TF_RETURN_IF_ERROR(CheckFusionInstruction(instruction)); TF_RET_CHECK( ContainersEqual(instruction->called_computations(), {instruction->fused_instructions_computation()})) @@ -532,10 +521,9 @@ StatusOr HloVerifier::Run(HloModule* module) { instruction->fused_instructions_computation()) << "Fused HLO was missing a parent: " << fused->ToString() << " parent: " << fused->parent() - << " computation: " << computation.get(); + << " computation: " << computation; } - } - if (instruction->opcode() == HloOpcode::kBroadcast) { + } else if (instruction->opcode() == HloOpcode::kBroadcast) { // If you see this failure then someone has confused the difference // between the HLO broadcast op, and the UserComputation broadcast // op. See https://groups.google.com/forum/#!topic/xla-dev/9LqijHmTt_I @@ -543,6 +531,40 @@ StatusOr HloVerifier::Run(HloModule* module) { TF_RET_CHECK(instruction->dimensions().size() == ShapeUtil::Rank(instruction->operand(0)->shape())) << "Broadcast HLO has invalid number of dimensions."; + } else if (instruction->opcode() == HloOpcode::kWhile) { + auto* while_cond = instruction->while_condition(); + auto* while_body = instruction->while_body(); + TF_RET_CHECK(while_cond->num_parameters() == 1) + << "While condition must have exactly 1 parameter; had " + << while_cond->num_parameters() << ": " << while_cond->ToString(); + TF_RET_CHECK(while_body->num_parameters() == 1) + << "While body must have exactly 1 parameter; had " + << while_body->num_parameters() << ": " << while_body->ToString(); + TF_RET_CHECK(instruction->operand_count() == 1) + << "While loop must have exactly one operand; had " + << instruction->operand_count() << ": " << instruction->ToString(); + + auto* init = instruction->operand(0); + auto* cond_param = while_cond->parameter_instruction(0); + TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), cond_param->shape())) + << "While condition's parameter must have the same shape as the " + "loop's 'init'. init: " + << init->ToString() << ", param: " << cond_param->ToString(); + auto* cond_root = while_cond->root_instruction(); + TF_RET_CHECK(ShapeUtil::Compatible(cond_root->shape(), + ShapeUtil::MakeShape(PRED, {}))) + << "While condition should have shape PRED: " + << cond_root->ToString(); + + auto* body_param = while_body->parameter_instruction(0); + TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), body_param->shape())) + << "While body's parameter must have the same shape as the loop's " + "'init'. init: " + << init->ToString() << ", param: " << body_param->ToString(); + auto* body_root = while_body->root_instruction(); + TF_RET_CHECK(ShapeUtil::Compatible(init->shape(), body_root->shape())) + << "While body should have same shape as the loop's 'init'. init: " + << init->ToString() << ", body: " << body_root->ToString(); } auto previous = instructions.find(instruction->name()); @@ -553,7 +575,7 @@ StatusOr HloVerifier::Run(HloModule* module) { << "\nPrevious HLO with same name:\n" << previous->second->ToString() << " in computation: " << previous->second->parent()->name(); - instructions[instruction->name()] = instruction.get(); + instructions[instruction->name()] = instruction; } TF_RETURN_IF_ERROR(computation->Accept(&shape_verifier)); diff --git a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc index d620f45d27eba706fbd7fc30d3b27b0d963475d4..b7c40fdeeb157fc74900bd9cf9d68a06a2cb1d56 100644 --- a/tensorflow/compiler/xla/service/human_readable_profile_builder.cc +++ b/tensorflow/compiler/xla/service/human_readable_profile_builder.cc @@ -68,12 +68,20 @@ string HumanReadableProfileBuilder::ToString() const { }; float optimal_seconds_sum = 0.0; + int64 total_flops = 0.; + int64 total_transcendentals = 0.; + int64 total_bytes = 0; for (const auto& op : op_infos_) { optimal_seconds_sum += op.optimal_seconds; + total_flops += op.flop_count; + total_transcendentals += op.transcendental_count; + total_bytes += op.bytes_accessed; } - append_op({"[total]", "[total]", /*category=*/"", total_cycles_, -1, -1, -1, - optimal_seconds_sum}); + VLOG(1) << "Total floating point ops: " << total_flops; + + append_op({"[total]", "[total]", /*category=*/"", total_cycles_, total_flops, + total_transcendentals, total_bytes, optimal_seconds_sum}); // Sort ops in decreasing order of cycles. std::vector sorted_ops(op_infos_); diff --git a/tensorflow/compiler/xla/service/inliner.cc b/tensorflow/compiler/xla/service/inliner.cc index 382ebd8008c2b3e1bb016005f198d36771b1ae60..5c193fceb984448cf0532d7e1010281268614293 100644 --- a/tensorflow/compiler/xla/service/inliner.cc +++ b/tensorflow/compiler/xla/service/inliner.cc @@ -43,11 +43,7 @@ class InlinerVisitor : public DfsHloVisitorWithDefault { return Status::OK(); } - Status HandleMap( - HloInstruction* map, - tensorflow::gtl::ArraySlice operands, - HloComputation* function, - tensorflow::gtl::ArraySlice static_operands) override; + Status HandleMap(HloInstruction* map) override; // Runs the visitor on a computation. StatusOr Run(HloComputation* computation); @@ -67,18 +63,14 @@ StatusOr InlinerVisitor::Run(HloComputation* computation) { return changed_; } -Status InlinerVisitor::HandleMap( - HloInstruction* map, tensorflow::gtl::ArraySlice operands, - HloComputation* function, - tensorflow::gtl::ArraySlice /*static_operands*/) { +Status InlinerVisitor::HandleMap(HloInstruction* map) { + HloComputation* function = map->to_apply(); HloInstruction& root = *function->root_instruction(); // TODO(b/29249531): Add DCE pass to remove unused HloComputations. // Only inlining functions that are simply a single operation until a better // profitability model for inlining is defined. if (hlo_query::AllOperandsAreParameters(root)) { - if (root.opcode() == HloOpcode::kUpdate || - root.opcode() == HloOpcode::kFusion || - root.opcode() == HloOpcode::kIndex || + if (root.opcode() == HloOpcode::kFusion || root.opcode() == HloOpcode::kParameter || root.opcode() == HloOpcode::kTrace) { // Cloning not supported for these instructions. @@ -90,8 +82,12 @@ Status InlinerVisitor::HandleMap( // different than the map shape. Hence, a broadcast is needed, else the // cloned operand with new shape and operands work. if (root.opcode() != HloOpcode::kConstant) { + std::vector params; + for (int64 o = 0; o < root.operands().size(); o++) { + params.push_back(map->operands()[root.operand(o)->parameter_number()]); + } HloInstruction* placed_instruction = computation_->AddInstruction( - root.CloneWithNewOperands(map->shape(), operands)); + root.CloneWithNewOperands(map->shape(), params)); TF_RETURN_IF_ERROR( computation_->ReplaceInstruction(map, placed_instruction)); } else { @@ -113,10 +109,8 @@ Status InlinerVisitor::HandleMap( StatusOr Inliner::Run(HloModule* module) { InlinerVisitor visitor(/*computation=*/nullptr); bool changed = false; - for (const std::unique_ptr& computation : - module->computations()) { - TF_ASSIGN_OR_RETURN(bool computation_changed, - visitor.Run(computation.get())); + for (HloComputation* computation : module->computations()) { + TF_ASSIGN_OR_RETURN(bool computation_changed, visitor.Run(computation)); changed |= computation_changed; } return changed; diff --git a/tensorflow/compiler/xla/service/inliner_test.cc b/tensorflow/compiler/xla/service/inliner_test.cc index 9d845c5545680f1a5389fa89af02eb723203a5f7..7aa1c7c8358318d02a000d968a2672123400ad6e 100644 --- a/tensorflow/compiler/xla/service/inliner_test.cc +++ b/tensorflow/compiler/xla/service/inliner_test.cc @@ -108,5 +108,44 @@ TEST_F(InlinerTest, MapConstant) { LiteralTestUtil::ExpectEqual(*result, *expected); } +TEST_F(InlinerTest, MapSubtractOppositeOrder) { + Shape r0f32 = ShapeUtil::MakeShape(F32, {}); + + // Note that the parameter ordinals are in the opposite order to their + // position as operands + auto max_builder = HloComputation::Builder(TestName()); + auto param1 = max_builder.AddInstruction( + HloInstruction::CreateParameter(1, r0f32, "x")); + auto param2 = max_builder.AddInstruction( + HloInstruction::CreateParameter(0, r0f32, "y")); + max_builder.AddInstruction(HloInstruction::CreateBinary( + param1->shape(), HloOpcode::kSubtract, param1, param2)); + auto max_f32 = max_builder.Build(); + + auto builder = HloComputation::Builder("MapSubFunction"); + auto lhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({1, 2, 3, 4}))); + auto rhs = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({4, 3, 2, 1}))); + builder.AddInstruction( + HloInstruction::CreateMap(lhs->shape(), {lhs, rhs}, max_f32.get())); + + auto computation = builder.Build(); + auto hlo_module = CreateNewModule(); + hlo_module->AddEmbeddedComputation(std::move(max_f32)); + hlo_module->AddEntryComputation(std::move(computation)); + + Inliner inliner; + EXPECT_TRUE(inliner.Run(hlo_module.get()).ValueOrDie()); + EXPECT_THAT(hlo_module->entry_computation()->root_instruction(), + op::Subtract(rhs, lhs)); + + // Verify execution on CPU. + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + auto expected = Literal::CreateR1({3, 1, -1, -3}); + LiteralTestUtil::ExpectEqual(*result, *expected); +} + + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 573c0d16bc6b7da271bf509389d1544ebd046adb..0d1b7bc109c56bc4290ede09284c6d20142bdb08 100644 --- a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -32,17 +32,16 @@ namespace xla { const HloInstruction& instruction) { switch (instruction.opcode()) { // Cheap instructions. - case HloOpcode::kAbs: case HloOpcode::kAdd: case HloOpcode::kBitcast: case HloOpcode::kBroadcast: case HloOpcode::kCeil: case HloOpcode::kClamp: + case HloOpcode::kComplex: case HloOpcode::kConcatenate: case HloOpcode::kConstant: case HloOpcode::kConvert: case HloOpcode::kCopy: - case HloOpcode::kCos: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: case HloOpcode::kEq: @@ -50,12 +49,13 @@ namespace xla { case HloOpcode::kGe: case HloOpcode::kGetTupleElement: case HloOpcode::kGt: + case HloOpcode::kImag: case HloOpcode::kInfeed: case HloOpcode::kIsFinite: case HloOpcode::kLe: - case HloOpcode::kLogicalAnd: - case HloOpcode::kLogicalNot: - case HloOpcode::kLogicalOr: + case HloOpcode::kAnd: + case HloOpcode::kNot: + case HloOpcode::kOr: case HloOpcode::kLt: case HloOpcode::kMaximum: case HloOpcode::kMinimum: @@ -64,20 +64,30 @@ namespace xla { case HloOpcode::kNegate: case HloOpcode::kOutfeed: case HloOpcode::kPad: + case HloOpcode::kReal: case HloOpcode::kReducePrecision: case HloOpcode::kReshape: case HloOpcode::kReverse: case HloOpcode::kRoundNearestAfz: case HloOpcode::kSelect: - case HloOpcode::kSign: - case HloOpcode::kSin: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: case HloOpcode::kSlice: case HloOpcode::kSubtract: case HloOpcode::kTranspose: case HloOpcode::kTuple: return false; + // Cheap instructions for reals, but expensive for complex. + case HloOpcode::kAbs: + case HloOpcode::kCos: + case HloOpcode::kSign: + case HloOpcode::kSin: + return ShapeUtil::ElementIsComplex(instruction.shape()); + // Expensive instructions. + case HloOpcode::kAtan2: case HloOpcode::kBatchNormTraining: case HloOpcode::kBatchNormInference: case HloOpcode::kBatchNormGrad: @@ -89,7 +99,6 @@ namespace xla { case HloOpcode::kDot: case HloOpcode::kExp: case HloOpcode::kFusion: - case HloOpcode::kIndex: case HloOpcode::kLog: case HloOpcode::kMap: case HloOpcode::kParameter: @@ -102,7 +111,6 @@ namespace xla { case HloOpcode::kSort: case HloOpcode::kTanh: case HloOpcode::kTrace: - case HloOpcode::kUpdate: case HloOpcode::kWhile: case HloOpcode::kSend: case HloOpcode::kRecv: @@ -203,16 +211,12 @@ bool InstructionFusion::CanFuseOnAllPaths( } StatusOr InstructionFusion::Run(HloModule* module) { + VLOG(2) << "Before instruction fusion:"; + XLA_VLOG_LINES(2, module->ToString()); + bool changed = false; module_ = module; - std::vector computations; - for (auto& computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } - computations.push_back(computation.get()); - } - for (auto& computation : computations) { + for (auto* computation : module->MakeNonfusionComputations()) { CHECK(!computation->IsFusionComputation()); computation_ = computation; @@ -378,6 +382,10 @@ StatusOr InstructionFusion::Run(HloModule* module) { } } } + + VLOG(2) << "After instruction fusion:"; + XLA_VLOG_LINES(2, module->ToString()); + return changed; } diff --git a/tensorflow/compiler/xla/service/interpreter/BUILD b/tensorflow/compiler/xla/service/interpreter/BUILD index 40d6040b307605f9f697260f5ef240944937ca72..2704a805a91b93c69b751cdb61305ea7780f0ef2 100644 --- a/tensorflow/compiler/xla/service/interpreter/BUILD +++ b/tensorflow/compiler/xla/service/interpreter/BUILD @@ -7,6 +7,22 @@ load( "if_static", ) +cc_library( + name = "interpreter_transfer_manager", + srcs = ["interpreter_transfer_manager.cc"], + hdrs = ["interpreter_transfer_manager.h"], + deps = [ + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:generic_transfer_manager", + "//tensorflow/compiler/xla/service:transfer_manager", + "//tensorflow/compiler/xla/service/interpreter:platform_id", + "//tensorflow/core:lib", + "//tensorflow/core:stream_executor_no_cuda", + ], + alwayslink = True, # Contains per-platform transfer manager registration +) + cc_library( name = "compiler", srcs = ["compiler.cc"], @@ -36,8 +52,8 @@ cc_library( "//tensorflow/compiler/xla/service:inliner", "//tensorflow/compiler/xla/service:layout_assignment", "//tensorflow/compiler/xla/service:reshape_mover", + "//tensorflow/compiler/xla/service:while_loop_simplifier", "//tensorflow/core:lib", - "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/stream_executor", ], alwayslink = True, # Contains compiler registration diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.cc b/tensorflow/compiler/xla/service/interpreter/compiler.cc index c8d02834f43a747980d084be37602bc56db74b98..6d5796a24b5209355debd80b912b7fa62d40837c 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.cc +++ b/tensorflow/compiler/xla/service/interpreter/compiler.cc @@ -32,6 +32,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/interpreter/executable.h" #include "tensorflow/compiler/xla/service/layout_assignment.h" #include "tensorflow/compiler/xla/service/reshape_mover.h" +#include "tensorflow/compiler/xla/service/while_loop_simplifier.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" @@ -56,6 +57,7 @@ Status InterpreterCompiler::RunHloOptimization(HloModule* hlo_module) { pipeline.AddPass>( false, [](const Shape&, const Shape&) { return false; }); + pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(); pipeline.AddPass(true); @@ -88,7 +90,7 @@ StatusOr> InterpreterCompiler::Compile( StatusOr>> InterpreterCompiler::Compile( std::vector> /*hlo_modules*/, - std::vector /*stream_execs*/) { + std::vector> /*stream_execs*/) { return tensorflow::errors::Unimplemented( "Compilation of multiple HLO modules is not supported on Interpreter."); } diff --git a/tensorflow/compiler/xla/service/interpreter/compiler.h b/tensorflow/compiler/xla/service/interpreter/compiler.h index 13db38ab60a07bdf476227c9b9e818dfe2cdcc6c..cfdc9b6256569b0137784b0d1db846a5f2339a5d 100644 --- a/tensorflow/compiler/xla/service/interpreter/compiler.h +++ b/tensorflow/compiler/xla/service/interpreter/compiler.h @@ -49,7 +49,8 @@ class InterpreterCompiler : public Compiler { StatusOr>> Compile( std::vector> hlo_modules, - std::vector stream_exec) override; + std::vector> + stream_exec) override; StatusOr>> CompileAheadOfTime(std::vector> hlo_modules, diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc index 989fc4e0313e9390ee54f40912f24036433d1f36..86dee8462fd4fdda580ada892e244f19177fb3e5 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.cc +++ b/tensorflow/compiler/xla/service/interpreter/executable.cc @@ -41,7 +41,7 @@ namespace se = ::perftools::gputools; namespace sep = ::perftools::gputools::interpreter; InterpreterExecutable::InterpreterExecutable( - std::unique_ptr hlo_module) + std::unique_ptr hlo_module) : Executable(std::move(hlo_module)) {} InterpreterExecutable::~InterpreterExecutable() {} diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h index 2881d6697e23645917a1813ab2a43fee5e710d57..c69b0d036d1058a6b24ee609a9923895d3246eec 100644 --- a/tensorflow/compiler/xla/service/interpreter/executable.h +++ b/tensorflow/compiler/xla/service/interpreter/executable.h @@ -40,7 +40,7 @@ namespace interpreter { // buffer allocation. Refer to interpreter/README.md for more. class InterpreterExecutable : public Executable { public: - InterpreterExecutable(std::unique_ptr hlo_module); + InterpreterExecutable(std::unique_ptr hlo_module); ~InterpreterExecutable() override; StatusOr ExecuteOnStream( diff --git a/tensorflow/compiler/xla/service/interpreter_transfer_manager.cc b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc similarity index 86% rename from tensorflow/compiler/xla/service/interpreter_transfer_manager.cc rename to tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc index 1864dcdf0367f36eac0367a39cdd3e0ec014be63..cf98ecd7749d61261bf072cdb1882c7603f39556 100644 --- a/tensorflow/compiler/xla/service/interpreter_transfer_manager.cc +++ b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/xla/service/interpreter_transfer_manager.h" +#include "tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h" #include @@ -26,7 +26,8 @@ namespace sei = ::perftools::gputools::interpreter; namespace xla { InterpreterTransferManager::InterpreterTransferManager() - : GenericTransferManager(sei::kInterpreterPlatformId) {} + : GenericTransferManager(sei::kInterpreterPlatformId, + /*pointer_size=*/sizeof(void*)) {} } // namespace xla diff --git a/tensorflow/compiler/xla/service/interpreter_transfer_manager.h b/tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h similarity index 100% rename from tensorflow/compiler/xla/service/interpreter_transfer_manager.h rename to tensorflow/compiler/xla/service/interpreter/interpreter_transfer_manager.h diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index 57c15ef48e18a32edc1ad340083f1840fe6572f8..7eda7c2284c2457703fcfcd4226172e41dd4ae01 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -98,7 +98,7 @@ string ResultLayoutConstraint::ToString() const { LayoutConstraints::LayoutConstraints( const TuplePointsToAnalysis& points_to_analysis, - const HloComputation* computation) + HloComputation* computation) : points_to_analysis_(points_to_analysis), computation_(computation) { // Gather all array-shaped logical buffers into unconstrained_buffer_ids. for (LogicalBuffer::Id id = 0; id < points_to_analysis_.num_logical_buffers(); @@ -376,7 +376,7 @@ Status LayoutAssignment::AddMandatoryConstraints( // Constrain layouts of instructions which define values with pre-existing // layouts. - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { Shape const* shape_with_layout = nullptr; if (instruction->opcode() == HloOpcode::kInfeed) { // Infeed layouts must match the layout of the original inserted @@ -384,13 +384,13 @@ Status LayoutAssignment::AddMandatoryConstraints( // TODO(b/31425034): Change infeeds to be more like parameters, with // shapes in the ComputationLayout. DCHECK(!LayoutUtil::IsPadded(instruction->shape())); - TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(instruction->shape(), - instruction.get())); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(instruction->shape(), instruction)); } else if (instruction->opcode() == HloOpcode::kOutfeed) { // Constrain the input to the Outfeed instruction to be the expected // layout of the Outfeed. TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - instruction->outfeed_shape(), instruction.get(), 0, + instruction->outfeed_shape(), instruction, 0, /*mandatory=*/true)); } else if (instruction->opcode() == HloOpcode::kParameter) { // Parameter layouts must match the respective layout in @@ -400,8 +400,8 @@ Status LayoutAssignment::AddMandatoryConstraints( .shape(); } if (shape_with_layout != nullptr) { - TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(*shape_with_layout, - instruction.get())); + TF_RETURN_IF_ERROR( + constraints->SetInstructionLayout(*shape_with_layout, instruction)); } } @@ -409,21 +409,20 @@ Status LayoutAssignment::AddMandatoryConstraints( // already been assigned layouts. Instructions which call computations in a // parallel element-wise context (eg, map or reduce) do not need layout // constraints because they operate on scalars. - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kCall) { // kCall instruction operands and output must match the ComputationLayout // of the called computation. const ComputationLayout& called_computation_layout = FindOrDie(computation_layouts_, instruction->to_apply()); TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( - called_computation_layout.result_layout().shape(), - instruction.get())); + called_computation_layout.result_layout().shape(), instruction)); TF_RET_CHECK(instruction->operand_count() == called_computation_layout.parameter_count()); for (int64 i = 0; i < instruction->operand_count(); ++i) { TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - called_computation_layout.parameter_layout(i).shape(), - instruction.get(), i, /*mandatory=*/true)); + called_computation_layout.parameter_layout(i).shape(), instruction, + i, /*mandatory=*/true)); } } else if (instruction->opcode() == HloOpcode::kWhile) { // Layout of input and output of kWhile instruction must be equal and must @@ -472,9 +471,9 @@ Status LayoutAssignment::AddMandatoryConstraints( // Constrain the output and the operand of the while instruction to match // the computations. TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( - body_layout.result_shape(), instruction.get())); + body_layout.result_shape(), instruction)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - body_layout.result_shape(), instruction.get(), 0, + body_layout.result_shape(), instruction, 0, /*mandatory=*/true)); } else if (instruction->opcode() == HloOpcode::kCustomCall) { // Add constraints for kCustomCall instruction operands and instructions. @@ -489,7 +488,7 @@ Status LayoutAssignment::AddMandatoryConstraints( Shape result_shape(row_major_shape(instruction->shape())); TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(result_shape, instruction.get())); + constraints->SetInstructionLayout(result_shape, instruction)); for (int64 i = 0; i < instruction->operand_count(); ++i) { const Shape& operand_shape = instruction->operand(i)->shape(); // Opaque operands don't get a layout constraint. @@ -499,7 +498,7 @@ Status LayoutAssignment::AddMandatoryConstraints( Shape row_major_operand_shape(row_major_shape(operand_shape)); TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - row_major_operand_shape, instruction.get(), i, /*mandatory=*/true)); + row_major_operand_shape, instruction, i, /*mandatory=*/true)); } } } @@ -609,11 +608,8 @@ Status CheckLayouts( const std::map& computation_layouts) { TF_ASSIGN_OR_RETURN(auto points_to_analysis, TuplePointsToAnalysis::Run(module)); - for (auto& computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } - for (auto& instruction : computation->instructions()) { + for (auto* computation : module->MakeNonfusionComputations()) { + for (auto* instruction : computation->instructions()) { // Verify every instruction has a layout and the layout is valid for the // shape. TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape())); @@ -623,7 +619,7 @@ Status CheckLayouts( // output of the instruction matches the layout of the logical buffer // which could be the source of the subshape value. const PointsToSet& points_to_set = - points_to_analysis->GetPointsToSet(instruction.get()); + points_to_analysis->GetPointsToSet(instruction); TF_RETURN_IF_ERROR(points_to_set.ForEachElementWithStatus( [&instruction](ShapeIndex index, const PointsToSet::BufferList& buffers) -> Status { @@ -652,26 +648,26 @@ Status CheckLayouts( switch (instruction->opcode()) { case HloOpcode::kCall: TF_RETURN_IF_ERROR(CheckCallLayout( - instruction.get(), + instruction, FindOrDie(computation_layouts, instruction->to_apply()))); break; case HloOpcode::kCustomCall: - TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction.get())); + TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); break; case HloOpcode::kFusion: - TF_RETURN_IF_ERROR(CheckFusionLayout(instruction.get())); + TF_RETURN_IF_ERROR(CheckFusionLayout(instruction)); break; case HloOpcode::kParameter: TF_RETURN_IF_ERROR(CheckParameterLayout( - instruction.get(), + instruction, FindOrDie(computation_layouts, instruction->parent()))); break; case HloOpcode::kConstant: - TF_RETURN_IF_ERROR(CheckConstantLayout(instruction.get())); + TF_RETURN_IF_ERROR(CheckConstantLayout(instruction)); break; case HloOpcode::kWhile: TF_RETURN_IF_ERROR(CheckWhileLayout( - instruction.get(), + instruction, FindOrDie(computation_layouts, instruction->while_condition()), FindOrDie(computation_layouts, instruction->while_body()))); break; @@ -736,7 +732,8 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( // dimension bound is 1 in the operand shape, there may be several such // layouts. So if 'output_layout' is the default layout, try if the // reshape is a bitcast when using the same layout. This may avoid copy - // operations. + // operations. For similar reasons, if the operand and output have the same + // rank, try to match the operand's layout to the output. if (ShapeUtil::TrueRank(operand->shape()) == 1 && ShapeUtil::Rank(instruction->shape()) == 1) { // Don't assign a layout in case of R1 -> effective R1 reshape. @@ -752,6 +749,13 @@ std::unique_ptr LayoutAssignment::ChooseOperandLayoutFromOutputLayout( if (ShapeUtil::ReshapeIsBitcast(operand_shape, output_shape_with_layout)) { return MakeUnique(operand_shape.layout()); } + if (ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape)) { + *operand_shape.mutable_layout() = output_layout; + if (ShapeUtil::ReshapeIsBitcast(operand_shape, + output_shape_with_layout)) { + return MakeUnique(output_layout); + } + } auto aligned_operand_shape = ShapeUtil::AlignLayouts(output_shape_with_layout, operand_shape); if (aligned_operand_shape) { @@ -800,7 +804,8 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( // dimension bound is 1 in the user shape, there may be several such // layouts. So if 'operand_layout' is the default layout, try if the // reshape is a bitcast when using the same layout. This may avoid copy - // operations. + // operations. For similar reasons, if the operand and output have the same + // rank, try to match the outputs's layout to the operand. if (ShapeUtil::Rank(operand->shape()) == 1 && ShapeUtil::TrueRank(user->shape()) == 1) { // Don't assign a layout in case of R1 -> effective R1 reshape. @@ -816,6 +821,13 @@ std::unique_ptr LayoutAssignment::ChooseOutputLayoutFromOperandLayout( if (ShapeUtil::ReshapeIsBitcast(output_shape, operand_shape_with_layout)) { return MakeUnique(output_shape.layout()); } + if (ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape)) { + *output_shape.mutable_layout() = operand_layout; + if (ShapeUtil::ReshapeIsBitcast(output_shape, + operand_shape_with_layout)) { + return MakeUnique(operand_layout); + } + } auto aligned_user_shape = ShapeUtil::AlignLayouts(operand_shape_with_layout, output_shape); if (aligned_user_shape) { @@ -1184,11 +1196,9 @@ Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout, // to match the layout of its corresponding fusion instruction operand. Also, // set the layout of the fused root to match the layout of the fusion // instruction itself. -// Fused GetTupleElement requires a layout so that TBAA metadata for the tuple -// element array pointer load can be added. Status SetFusionLayouts(HloInstruction* fusion) { TF_RET_CHECK(fusion->opcode() == HloOpcode::kFusion); - for (auto& fused_instruction : fusion->fused_instructions()) { + for (auto* fused_instruction : fusion->fused_instructions()) { if (fused_instruction->opcode() == HloOpcode::kParameter) { const HloInstruction* fusion_operand = fusion->operand(fused_instruction->parameter_number()); @@ -1196,7 +1206,7 @@ Status SetFusionLayouts(HloInstruction* fusion) { fused_instruction->shape())); TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes( fusion_operand->shape(), fused_instruction->mutable_shape())); - } else if (fused_instruction.get() == fusion->fused_expression_root()) { + } else if (fused_instruction == fusion->fused_expression_root()) { // The layout of the root of the fused expression must match the fusion // instruction layout. DCHECK( diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 118d68dc476c23c6945ec38e061617e7e29f357c..0b97fba744923b8afc3fb539566b68f1bca47d38 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -121,10 +121,11 @@ class ResultLayoutConstraint : public LayoutConstraint { class LayoutConstraints { public: LayoutConstraints(const TuplePointsToAnalysis& points_to_analysis, - const HloComputation* computation); + HloComputation* computation); ~LayoutConstraints() = default; const HloComputation* computation() const { return computation_; } + HloComputation* computation() { return computation_; } const TuplePointsToAnalysis& points_to_analysis() const { return points_to_analysis_; } @@ -211,7 +212,7 @@ class LayoutConstraints { // Array-shaped buffers which have not yet been constrained. std::set unconstrained_buffer_ids_; - const HloComputation* computation_; + HloComputation* computation_; }; // HLO pass which assigns layouts to all instructions in the HLO module while diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD index 86817b05f52f254184c53d9bcb6dd8ca14a7d39b..075d4a1ab5e5f39394ade393d21525ca3e97136e 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/BUILD +++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD @@ -45,6 +45,7 @@ cc_library( "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:types", + "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", @@ -92,7 +93,6 @@ cc_library( deps = [ ":ir_array", ":llvm_loop", - ":ops", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", @@ -111,7 +111,7 @@ cc_library( ":ir_array", ":llvm_util", ":loop_emitter", - ":ops", + ":tuple_ops", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:util", @@ -127,6 +127,23 @@ cc_library( name = "ops", srcs = ["ops.cc"], hdrs = ["ops.h"], + deps = [ + ":fused_ir_emitter", + ":ir_array", + ":llvm_util", + ":loop_emitter", + "//tensorflow/compiler/xla/service:buffer_assignment", + "//tensorflow/compiler/xla/service:elemental_ir_emitter", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/gpu:parallel_loop_emitter", + "//tensorflow/compiler/xla/service/gpu:partition_assignment", + ], +) + +cc_library( + name = "tuple_ops", + srcs = ["tuple_ops.cc"], + hdrs = ["tuple_ops.h"], deps = [ ":ir_array", ":llvm_util", diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc index 5e28e37600c18a351e8647d48119f073277f56e1..bdddc232ef74dfa37e2d5cc780b0fe11e7bc8e76 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc @@ -92,7 +92,16 @@ void AliasAnalysis::AddAliasingInformationToIrArray(const HloInstruction& hlo, llvm::MDNode* AliasAnalysis::GetAliasDomain() { llvm::MDBuilder metadata_builder(*context_); if (alias_domain_ == nullptr) { - alias_domain_ = metadata_builder.createAnonymousAliasScopeDomain(); + // We use createAliasScopeDomain rather than createAnonymousAliasScopeDomain + // so that when functions get inlined, we continue using the one domain, + // rather than duplicating it (and thus having two AA domains in one + // function). + // + // A side-effect of this is that if you ever compile two HLO modules in the + // same LLVM module, they'll have the same alias scope domain. This isn't a + // problem because the two HLO modules will never interact with one another. + alias_domain_ = + metadata_builder.createAliasScopeDomain("XLA global AA domain"); } return alias_domain_; } diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc index 7d1fad753e0d94f7d88b824ed57d52890a48b1dd..bc683a1880b010d57e83aa6e9ffa95fda299e1a0 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.cc @@ -22,7 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" @@ -72,10 +72,10 @@ Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) { return Status::OK(); } -Status FusedIrEmitter::HandleConstant(HloInstruction* constant, - const Literal& literal) { +Status FusedIrEmitter::HandleConstant(HloInstruction* constant) { + const Literal& literal = constant->literal(); llvm::Constant* initializer = - llvm_ir::ConvertLiteralToIrConstant(literal, ir_builder_); + llvm_ir::ConvertLiteralToIrConstant(literal, module_); llvm::GlobalVariable* global = new llvm::GlobalVariable( *ir_builder_->GetInsertBlock()->getModule(), initializer->getType(), /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, initializer, @@ -88,9 +88,10 @@ Status FusedIrEmitter::HandleConstant(HloInstruction* constant, return Status::OK(); } -Status FusedIrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* operand) { +Status FusedIrEmitter::HandleGetTupleElement( + HloInstruction* get_tuple_element) { // Lookup ir value for 'operand'. + auto operand = get_tuple_element->operand(0); auto it = gte_values_.find(operand); if (it == gte_values_.end()) { return Unimplemented( @@ -101,7 +102,7 @@ Status FusedIrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element, // Emit code to lookup tuple element pointer, and store it in 'gte_values_'. llvm::Value* tuple_element_ptr = llvm_ir::EmitGetTupleElement( get_tuple_element->shape(), get_tuple_element->tuple_index(), - /*alignment=*/1, it->second, ir_builder_); + /*alignment=*/1, it->second, ir_builder_, module_); gte_values_.insert(std::make_pair(get_tuple_element, tuple_element_ptr)); // Emit code to read base tuple element array (if non-tuple shaped). if (!ShapeUtil::IsTuple(get_tuple_element->shape())) { @@ -128,13 +129,12 @@ Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) { return Status::OK(); } -Status FusedIrEmitter::HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) { +Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) { + tensorflow::gtl::ArraySlice operands(tuple->operands()); std::vector operand_elemental_ir_types; for (HloInstruction* operand : operands) { operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType( - operand->shape().element_type(), ir_builder_)); + operand->shape().element_type(), module_)); } generators_[tuple] = [=](const IrArray::Index& index) -> StatusOr { diff --git a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h index a24e104067f19e45ab2566beedbb8217913bad12..9ad7cd82cb8ca862fd7acec3dfb12c9fd61f6e27 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h +++ b/tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h @@ -42,22 +42,19 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { ElementalIrEmitter* elemental_emitter) : parameter_arrays_(parameter_arrays), elemental_emitter_(elemental_emitter), - ir_builder_(elemental_emitter->ir_builder()) {} + ir_builder_(elemental_emitter->ir_builder()), + module_(elemental_emitter->module()) {} Status DefaultAction(HloInstruction* hlo) override; - Status HandleConstant(HloInstruction* constant, - const Literal& literal) override; + Status HandleConstant(HloInstruction* constant) override; - Status HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* operand) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleParameter(HloInstruction* parameter) override; // Emits the ir value for each element in the tuple. - Status HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) override; + Status HandleTuple(HloInstruction* tuple) override; Status FinishVisit(HloInstruction* root) override; @@ -85,6 +82,7 @@ class FusedIrEmitter : public DfsHloVisitorWithDefault { // Borrowed llvm::IRBuilder<>* ir_builder_; + llvm::Module* module_; // Map from instruction pointers to functions to generate elements of their // outputs diff --git a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc index e36c791c1a52f4e0699cc15ef913fbd2bdcca557..e3f98ac13e76f0df465066422ca7918a0f218b60 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ir_array.cc @@ -229,9 +229,11 @@ llvm::Value* IrArray::EmitArrayElementAddress( } if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) { + llvm::Module* module = + ir_builder->GetInsertBlock()->getParent()->getParent(); return ir_builder->CreateInBoundsGEP( ir_builder->CreateBitCast( - base_ptr_, PrimitiveTypeToIrType(shape_->element_type(), ir_builder) + base_ptr_, PrimitiveTypeToIrType(shape_->element_type(), module) ->getPointerTo()), {index.linear()}, llvm_ir::AsStringRef(name)); } @@ -268,8 +270,6 @@ llvm::Value* IrArray::EmitReadArrayElement(const Index& index, llvm::Value* element_address = EmitArrayElementAddress(index, ir_builder, name); llvm::LoadInst* load = ir_builder->CreateLoad(element_address); - llvm_ir::SetTbaaForInstruction(load, GetShape(), - /*is_pointer_to=*/false); AnnotateLoadStoreInstructionWithMetadata(load); return load; } @@ -278,14 +278,13 @@ void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value, llvm::IRBuilder<>* ir_builder) const { llvm::Value* element_address = EmitArrayElementAddress(index, ir_builder); llvm::StoreInst* store = ir_builder->CreateStore(value, element_address); - llvm_ir::SetTbaaForInstruction(store, GetShape(), - /*is_pointer_to=*/false); AnnotateLoadStoreInstructionWithMetadata(store); } IrArray IrArray::CastToShape(const Shape& new_shape, llvm::IRBuilder<>* ir_builder) const { - llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, ir_builder); + llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent(); + llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module); return IrArray( ir_builder->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), new_shape); diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc index 51c4ac9be14f66aec1b14a93bcc39637fc6eacba..956c0d5f05288e32c626f247ce8356c60d17808d 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.cc @@ -16,8 +16,10 @@ limitations under the License. #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" #include +#include #include +#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/MDBuilder.h" #include "llvm/IR/Operator.h" #include "llvm/Target/TargetOptions.h" @@ -25,15 +27,31 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/core/casts.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/types.h" namespace xla { namespace llvm_ir { +namespace { + +// Note, this function is only useful in an insertion context; in a global +// (e.g. constants) context it will CHECK fail. +llvm::Module* ModuleFromIRBuilder(llvm::IRBuilder<>* ir_builder) { + auto block = CHECK_NOTNULL(ir_builder->GetInsertBlock()); + auto fn = CHECK_NOTNULL(block->getParent()); + auto module = CHECK_NOTNULL(fn->getParent()); + return module; +} + +} // namespace + string AsString(const std::string& str) { return string(str.data(), str.length()); } @@ -59,7 +77,7 @@ llvm::Value* EmitCallToIntrinsic( for (auto type : overloaded_types) { types.push_back(type); } - llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent(); + llvm::Module* module = ModuleFromIRBuilder(ir_builder); llvm::Function* intrinsic = llvm::Intrinsic::getDeclaration(module, intrinsic_id, types); std::vector operands_vec; @@ -115,38 +133,53 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index, } llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, - llvm::IRBuilder<>* ir_builder) { + llvm::Module* module) { switch (element_type) { case PRED: case S8: case U8: - return ir_builder->getInt8Ty(); + return llvm::Type::getInt8Ty(module->getContext()); case S16: case U16: - return ir_builder->getInt16Ty(); + return llvm::Type::getInt16Ty(module->getContext()); case S32: case U32: - return ir_builder->getInt32Ty(); + return llvm::Type::getInt32Ty(module->getContext()); case S64: case U64: - return ir_builder->getInt64Ty(); + return llvm::Type::getInt64Ty(module->getContext()); case F32: - return ir_builder->getFloatTy(); + return llvm::Type::getFloatTy(module->getContext()); case F64: - return ir_builder->getDoubleTy(); + return llvm::Type::getDoubleTy(module->getContext()); + case C64: { + auto cplx_t = module->getTypeByName("complex64"); + if (cplx_t == nullptr) { + // C++ standard dictates the memory layout of std::complex is contiguous + // real followed by imaginary. C++11 section 26.4 [complex.numbers]: + // If z is an lvalue expression of type cv std::complex then the + // expression reinterpret_cast(z) shall be well-formed, + // reinterpret_cast(z)[0] shall designate the real part of + // z, and reinterpret_cast(z)[1] shall designate the + // imaginary part of z. + return llvm::StructType::create( + "complex64", llvm::Type::getFloatTy(module->getContext()), + llvm::Type::getFloatTy(module->getContext())); + } + return cplx_t; + } // A Tuple contains an array of pointers. Use i8*. case TUPLE: // An Opaque is like a void*, use i8*. case OPAQUE: - return ir_builder->getInt8PtrTy(); + return llvm::Type::getInt8PtrTy(module->getContext()); default: LOG(FATAL) << "unsupported type " << element_type; } } -llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder) { - llvm::Type* result_type = - PrimitiveTypeToIrType(shape.element_type(), ir_builder); +llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module) { + llvm::Type* result_type = PrimitiveTypeToIrType(shape.element_type(), module); if (ShapeUtil::IsTuple(shape)) { // A tuple buffer is an array of pointers. result_type = llvm::ArrayType::get(result_type, shape.tuple_shapes_size()); @@ -193,10 +226,10 @@ namespace { // value down to zero). llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, std::vector* multi_index, - llvm::IRBuilder<>* ir_builder) { + llvm::Module* module) { const Shape& shape = literal.shape(); llvm::Type* ir_element_type = - llvm_ir::PrimitiveTypeToIrType(shape.element_type(), ir_builder); + llvm_ir::PrimitiveTypeToIrType(shape.element_type(), module); if (dimension_index == -1) { // Base case of the recursion. Index into the data field of the protobuf // with the multi index. @@ -234,6 +267,16 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, value = llvm::ConstantFP::get(ir_element_type, literal.Get(*multi_index)); break; + case C64: { + complex64 x = literal.Get(*multi_index); + value = llvm::ConstantStruct::get( + static_cast(ir_element_type), + llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module), + x.real()), + llvm::ConstantFP::get(llvm_ir::PrimitiveTypeToIrType(F32, module), + x.imag())); + break; + } default: LOG(FATAL) << "unsupported type " << shape.element_type(); } @@ -252,8 +295,8 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, std::vector elements; for (int64 i = 0; i < shape.dimensions(dimension); ++i) { (*multi_index)[dimension] = i; - elements.push_back(LiteralToConstant(literal, dimension_index - 1, - multi_index, ir_builder)); + elements.push_back( + LiteralToConstant(literal, dimension_index - 1, multi_index, module)); } llvm::Type* element_type; @@ -275,11 +318,11 @@ llvm::Constant* LiteralToConstant(const Literal& literal, int64 dimension_index, } // namespace llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, - llvm::IRBuilder<>* ir_builder) { + llvm::Module* module) { std::vector multi_index(ShapeUtil::Rank(literal.shape()), 0); llvm::Constant* value = LiteralToConstant( literal, /*dimension_index=*/ShapeUtil::Rank(literal.shape()) - 1, - &multi_index, ir_builder); + &multi_index, module); return value; } @@ -329,17 +372,20 @@ LlvmIfData EmitIfThenElse(llvm::Value* condition, tensorflow::StringPiece name, ir_builder) : nullptr; - // There is no reason this function cannot work without a - // terminator, that is just a different case that has not been - // implemented yet. It is a different case because splitBasicBlock - // requires a terminator. - CHECK_NE(nullptr, if_data.if_block->getTerminator()); - if_data.after_block = if_data.if_block->splitBasicBlock( - ir_builder->GetInsertPoint(), - AsStringRef(tensorflow::strings::StrCat(name, "-after"))); - - // splitBasicBlock inserts an unconditional terminator that we have - // to remove as we want a conditional branch there. + // Add a terminator to the if block, if necessary. + if (if_data.if_block->getTerminator() == nullptr) { + ir_builder->SetInsertPoint(if_data.if_block); + if_data.after_block = CreateBasicBlock( + nullptr, tensorflow::strings::StrCat(name, "-after"), ir_builder); + ir_builder->CreateBr(if_data.after_block); + } else { + if_data.after_block = if_data.if_block->splitBasicBlock( + ir_builder->GetInsertPoint(), + AsStringRef(tensorflow::strings::StrCat(name, "-after"))); + } + + // Our basic block should now end with an unconditional branch. Remove it; + // we're going to replace it with a conditional branch. if_data.if_block->getTerminator()->eraseFromParent(); ir_builder->SetInsertPoint(if_data.if_block); @@ -373,7 +419,8 @@ llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate, // comparison_result is i1, but the NVPTX codegen incorrectly lowers i1 // arrays. So we extend it to i8 so that it's addressable. return ir_builder->CreateZExt( - comparison_result, llvm_ir::PrimitiveTypeToIrType(PRED, ir_builder)); + comparison_result, + llvm_ir::PrimitiveTypeToIrType(PRED, ModuleFromIRBuilder(ir_builder))); } // Internal helper that is called from emitted code to log an int64 value with a @@ -395,13 +442,6 @@ void EmitLogging(const char* tag, llvm::Value* value, {ir_builder->getInt64(tensorflow::bit_cast(tag)), value}); } -void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape, - bool is_pointer_to) { - // TODO(b/62903316): TBAA metadata causes LLVM to miscompile generated code, - // most likely because the generated metadata is incorrect. Disable TBAA - // metadata while we resolve this. -} - void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment) { llvm::LLVMContext& context = load->getContext(); llvm::Type* int64_ty = llvm::Type::getInt64Ty(context); @@ -515,8 +555,9 @@ int64 ByteSizeOf(const Shape& shape, const llvm::DataLayout& data_layout) { llvm::FastMathFlags GetFastMathFlags(bool fast_math_enabled) { llvm::FastMathFlags flags; if (fast_math_enabled) { - // UnsafeAlgebra implies NoInfs, NoNaNs, NoSignedZeros, and AllowReciprocal. - flags.setUnsafeAlgebra(); + // Fast implies AllowReassoc, NoInfs, NoNaNs, NoSignedZeros, + // AllowReciprocal, AllowContract, and ApproxFunc. + flags.setFast(); } return flags; } @@ -579,5 +620,23 @@ std::map MergeMetadata( return result; } +Status DumpIRToDirectory(const string& directory_name, + const string& hlo_module_name, + const llvm::Module& llvm_module, bool optimized) { + string safe_file_name_base = SanitizeFileName(hlo_module_name); + string ir_file_name = tensorflow::io::JoinPath( + directory_name, + tensorflow::strings::StrCat("ir-", safe_file_name_base, "-", + optimized ? "with" : "no", "-opt.ll")); + + std::unique_ptr f; + TF_RETURN_IF_ERROR( + tensorflow::Env::Default()->RecursivelyCreateDir(directory_name)); + TF_RETURN_IF_ERROR( + tensorflow::Env::Default()->NewWritableFile(ir_file_name, &f)); + TF_RETURN_IF_ERROR(f->Append(DumpModuleToString(llvm_module))); + return f->Close(); +} + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h index ab8ac5e745dae3649b9d1cc62424aaaac50b6360..304192b58e9331c2544f973bf65299111122aea8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h +++ b/tensorflow/compiler/xla/service/llvm_ir/llvm_util.h @@ -127,11 +127,11 @@ llvm::Value* EmitBufferIndexingGEP(llvm::Value* array, int64 index, // Returns the LLVM type which represents the given XLA primitive type. llvm::Type* PrimitiveTypeToIrType(PrimitiveType element_type, - llvm::IRBuilder<>* ir_builder); + llvm::Module* module); // Returns the LLVM type which represents the given XLA shape. For example, // if "shape" is [5 x [10 x f32]], the function returns [5 x [10 x float]]. -llvm::Type* ShapeToIrType(const Shape& shape, llvm::IRBuilder<>* ir_builder); +llvm::Type* ShapeToIrType(const Shape& shape, llvm::Module* module); // Returns a value that represents a pointer to a global string constant that // encodes the shape as a serialized protobuf. @@ -149,7 +149,7 @@ StatusOr DecodeSelfDescribingShapeConstant(const void* shape_ptr, // Converts a given literal to an IR Constant. Literals have known constant // values at IR emission time. llvm::Constant* ConvertLiteralToIrConstant(const Literal& literal, - llvm::IRBuilder<>* ir_builder); + llvm::Module* module); // Inserts an allocate of the requested type at the entry point of the // function that the builder is currently building. The insert point @@ -227,12 +227,6 @@ llvm::Value* EmitComparison(llvm::CmpInst::Predicate predicate, void EmitLogging(const char* tag, llvm::Value* value, llvm::IRBuilder<>* ir_builder); -// Adds TBAA metadata to a load or store instruction using the given shape as -// it's type. The is_pointer_to parameter is used to indicate whether or not -// this instruction loads or stores a pointer to an array. -void SetTbaaForInstruction(llvm::Instruction* instruction, Shape shape, - bool is_pointer_to); - // Adds alignment metadata to a load instruction using the given alignment. // The alignment refers to the result of the load, not the load itself. void SetAlignmentMetadataForLoad(llvm::LoadInst* load, uint64_t alignment); @@ -273,6 +267,15 @@ std::map MergeMetadata( llvm::LLVMContext* context, const std::map& a, const std::map& b); +// Dumps out `llvm_module` to a file in the directory named `directory_name`, +// creating the directory if necessary. A sanitized version of +// `hlo_module_name` is incorporated into the file name. If `optimized` is true +// then a suffix of "-with-opt.ll" is used, else a suffix of "-no-opt.ll" is +// used. +Status DumpIRToDirectory(const string& directory_name, + const string& hlo_module_name, + const llvm::Module& llvm_module, bool optimized); + } // namespace llvm_ir } // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc index 8bba1776d19005292da48705df2436b6f30e0f2d..6fa4cd08c9e0ac30b83c0e2b49d98d930c2e15df 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/loop_emitter.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" -#include "tensorflow/compiler/xla/service/llvm_ir/ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.cc b/tensorflow/compiler/xla/service/llvm_ir/ops.cc index ac562e231c8f56184363d6e186c18847d67435ce..34899b7400464e4f4f97d301f35ed3b7b083bca1 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.cc +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.cc @@ -14,86 +14,167 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/xla/service/llvm_ir/ops.h" - -#include -#include -#include - -#include "llvm/IR/Instructions.h" +#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/types.h" -#include "tensorflow/compiler/xla/xla_data.pb.h" -#include "tensorflow/core/lib/strings/stringprintf.h" -#include "tensorflow/core/platform/logging.h" +#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h" namespace xla { namespace llvm_ir { -void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, - llvm::Value* on_false, llvm::IRBuilder<>* ir_builder) { - CHECK(ShapeUtil::IsScalar(pred.GetShape())); - - llvm::LoadInst* pred_value = - ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value"); - llvm::Value* pred_cond = ir_builder->CreateICmpNE( - pred_value, - llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, ir_builder), 0), - "boolean_predicate"); - - VLOG(2) << "HandleSelect for tuple:"; - VLOG(2) << " pred_value: " << DumpToString(*pred_value); - VLOG(2) << " pred_cond: " << DumpToString(*pred_cond); - - for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) { - std::vector element_index = {ir_builder->getInt64(0), - ir_builder->getInt64(i)}; - llvm::Value* on_true_element_address = - ir_builder->CreateInBoundsGEP(on_true, element_index); - llvm::Value* on_true_element = ir_builder->CreateLoad( - on_true_element_address, - tensorflow::strings::Printf("on_true_element_%d", i).c_str()); - llvm::Value* on_false_element_address = - ir_builder->CreateInBoundsGEP(on_false, element_index); - llvm::Value* on_false_element = ir_builder->CreateLoad( - on_false_element_address, - tensorflow::strings::Printf("on_false_element_%d", i).c_str()); - - llvm::Value* output_element_address = - ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index); - ir_builder->CreateStore( - ir_builder->CreateSelect( - pred_cond, on_true_element, on_false_element, - tensorflow::strings::Printf("select_output_element_%d", i).c_str()), - output_element_address); - } +bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, + const BufferAssignment& assignment) { + CHECK_EQ(HloOpcode::kDynamicUpdateSlice, dynamic_update_slice->opcode()); + const HloInstruction* operand = dynamic_update_slice->operand(0); + return assignment.HasTopLevelAllocation(dynamic_update_slice) && + assignment.HasTopLevelAllocation(operand) && + assignment.SharesTopLevelSlice(dynamic_update_slice, operand); } -void EmitTuple(IrArray tuple, - tensorflow::gtl::ArraySlice operands, - llvm::IRBuilder<>* ir_builder) { - for (size_t i = 0; i < operands.size(); ++i) { - ir_builder->CreateStore( - ir_builder->CreatePointerCast(operands[i], - PrimitiveTypeToIrType(TUPLE, ir_builder)), - ir_builder->CreateInBoundsGEP( - tuple.GetBasePointer(), - {ir_builder->getInt64(0), ir_builder->getInt64(i)})); +// Shared implementation of EmitDynamicUpdateSliceInPlace and +// EmitFusedDynamicUpdateSliceInPlace. +// +// Emits a sequential loop if launch_dimensions is null. +static Status EmitDynamicUpdateSliceInPlaceImpl( + const Shape& update_shape, const ElementGenerator& start_indices_generator, + ElementGenerator update_array_generator, const IrArray& output_array, + const gpu::LaunchDimensions* launch_dimensions, + tensorflow::StringPiece name, llvm::IRBuilder<>* ir_builder) { + const Shape& output_shape = output_array.GetShape(); + + // Read start indices from start_indices_generator. + const int64 rank = ShapeUtil::Rank(output_shape); + IrArray::Index start_index(rank); + for (int64 i = 0; i < rank; ++i) { + IrArray::Index dim_index({ir_builder->getInt64(i)}); + TF_ASSIGN_OR_RETURN(start_index[i], start_indices_generator(dim_index)); } + + auto loop_body_emitter = [&](const IrArray::Index& update_index) -> Status { + // Calculate output_index, where we'll write the value from update. For + // each dimension, + // + // output_index[dim] = (start_index[dim] + update_index[dim]) % dim_size. + // + IrArray::Index output_index(rank); + for (int64 i = 0; i < rank; ++i) { + llvm::Value* dim_size = llvm::ConstantInt::get( + update_index[i]->getType(), output_shape.dimensions(i)); + llvm::Value* start_index0 = ir_builder->CreateZExtOrBitCast( + start_index[i], update_index[i]->getType()); + output_index[i] = ir_builder->CreateURem( + ir_builder->CreateAdd(start_index0, update_index[i]), dim_size); + } + + // Do output[output_index] = update[update_index]. + TF_ASSIGN_OR_RETURN(llvm::Value * update_data, + update_array_generator(update_index)); + output_array.EmitWriteArrayElement(output_index, update_data, ir_builder); + return Status::OK(); + }; + + if (launch_dimensions != nullptr) { + return gpu::ParallelLoopEmitter(loop_body_emitter, update_shape, + *launch_dimensions, ir_builder) + .EmitLoop(name); + } + return LoopEmitter(loop_body_emitter, update_shape, ir_builder) + .EmitLoop(name); +} + +Status EmitDynamicUpdateSliceInPlace( + tensorflow::gtl::ArraySlice operand_arrays, + const IrArray& output_array, tensorflow::StringPiece name, + llvm::IRBuilder<>* ir_builder) { + VLOG(2) << "EmitDynamicUpdateSliceInPlace for " << name; + + // No need to use operand_arrays[0], the input array of the + // dynamic-update-slice, because we know it aliases the op's output. + IrArray update_array = operand_arrays[1]; + IrArray start_indices_array = operand_arrays[2]; + Shape output_shape = output_array.GetShape(); + Shape update_shape = update_array.GetShape(); + + ElementGenerator start_indices_generator = [&](const IrArray::Index& index) { + return start_indices_array.EmitReadArrayElement(index, ir_builder); + }; + ElementGenerator update_array_generator = [&](const IrArray::Index& index) { + return update_array.EmitReadArrayElement(index, ir_builder); + }; + + return EmitDynamicUpdateSliceInPlaceImpl( + update_shape, start_indices_generator, update_array_generator, + output_array, /*launch_dimensions=*/nullptr, name, ir_builder); +} + +// Shared implementation for EmitFusedDynamicUpdateSliceInPlace and +// EmitParallelFusedDynamicUpdateSliceInPlace. +// +// Emits a sequential loop if launch_dimensions is null. +static Status EmitFusedDynamicUpdateSliceInPlaceImpl( + HloInstruction* fusion, + tensorflow::gtl::ArraySlice fusion_operand_arrays, + const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + const gpu::LaunchDimensions* launch_dimensions, + llvm::IRBuilder<>* ir_builder) { + CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); + VLOG(2) << "EmitFusedDynamicUpdateSliceInPlace for " + << fusion->ToShortString(); + + auto* dynamic_update_slice = fusion->fused_expression_root(); + + const auto* update = dynamic_update_slice->operand(1); + const auto* start_indices = dynamic_update_slice->operand(2); + Shape update_shape = update->shape(); + + // Our in-place dynamic-update-slice implementation emits a loop over + // update_shape. To emit a cache-friendly loop, we need to know that shape's + // layout. + // + // update_shape is inside a fusion node -- it's never materialized in memory + // and thus doesn't have a layout. In this case we use the layout of the + // fusion node for iteration, since that corresponds to the order in memory of + // the buffer we'll be writing to. + // + // (This isn't necessarily optimal; in some cases it might be faster to peek + // through the chain of ops that gives us the update operand and use the + // layout of its source buffer(s). But this is no worse than we do with + // fusion elsewhere.) + TF_RETURN_IF_ERROR( + LayoutUtil::CopyLayoutBetweenShapes(fusion->shape(), &update_shape)); + + // Create element generators for update and start_indices. + FusedIrEmitter fused_emitter(fusion_operand_arrays, elemental_emitter); + TF_RETURN_IF_ERROR(dynamic_update_slice->Accept(&fused_emitter)); + ElementGenerator update_array_generator = fused_emitter.GetGenerator(update); + ElementGenerator start_indices_generator = + fused_emitter.GetGenerator(start_indices); + + return EmitDynamicUpdateSliceInPlaceImpl( + update_shape, start_indices_generator, update_array_generator, + fusion_output_array, launch_dimensions, IrName(fusion), ir_builder); +} + +Status EmitFusedDynamicUpdateSliceInPlace( + HloInstruction* fusion, + tensorflow::gtl::ArraySlice fusion_operand_arrays, + const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + llvm::IRBuilder<>* ir_builder) { + return EmitFusedDynamicUpdateSliceInPlaceImpl( + fusion, fusion_operand_arrays, fusion_output_array, elemental_emitter, + /*launch_dimensions=*/nullptr, ir_builder); } -llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, - int alignment, llvm::Value* operand, - llvm::IRBuilder<>* ir_builder) { - llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP( - operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)}); - llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr); - SetTbaaForInstruction(src_buffer, target_shape, /*is_pointer_to=*/true); - SetAlignmentMetadataForLoad(src_buffer, alignment); - llvm::Type* element_type = ShapeToIrType(target_shape, ir_builder); - llvm::Value* ret_val = - ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo()); - return ret_val; +Status EmitParallelFusedDynamicUpdateSliceInPlace( + HloInstruction* fusion, + tensorflow::gtl::ArraySlice fusion_operand_arrays, + const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + const gpu::LaunchDimensions& launch_dimensions, + llvm::IRBuilder<>* ir_builder) { + return EmitFusedDynamicUpdateSliceInPlaceImpl( + fusion, fusion_operand_arrays, fusion_output_array, elemental_emitter, + &launch_dimensions, ir_builder); } } // namespace llvm_ir diff --git a/tensorflow/compiler/xla/service/llvm_ir/ops.h b/tensorflow/compiler/xla/service/llvm_ir/ops.h index 4e1d9d1080b3a5c8d8a09145f68bcff9d329c929..11e84d9cb5defbcb87a8f696d56c139686c960d8 100644 --- a/tensorflow/compiler/xla/service/llvm_ir/ops.h +++ b/tensorflow/compiler/xla/service/llvm_ir/ops.h @@ -13,67 +13,68 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ -#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ -#include "llvm/IR/IRBuilder.h" -#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/buffer_assignment.h" +#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h" +#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" -#include "tensorflow/core/lib/gtl/array_slice.h" -#include "tensorflow/core/platform/types.h" + +// Utilities related to emitting LLVM IR for various HLO ops. namespace xla { namespace llvm_ir { -// Selection among tuples is special in how it's lowered, because a tuple is not -// an HLO array. -// -// tuple_on_true tuple_on_false -// | | -// V V -// ------------------------ ------------------------ -// | address of element 0 | | address of element 0 | -// |----------------------| |----------------------| -// | address of element 1 | | address of element 1 | -// |----------------------| |----------------------| -// | address of element 2 | | address of element 2 | -// ------------------------ ------------------------ -// \ / -// \ / -// ---------- -// pred ---------> | select | -// ---------- -// | -// V -// output ----> ------------------------ -// | address of element 0 | -// |----------------------| -// | address of element 1 | -// |----------------------| -// | address of element 2 | -// ------------------------ +// Checks if we can emit code for the given DynamicUpdateSlice node that updates +// its input in place. Returns true if the dynamic-update-slice's +// array-to-be-updated and output share the same BufferAllocation::Slice. // -// Only the addresses are copied to the output. For each element, we emit a copy -// of the address from the corresponding element in either -// tuple_on_true or tuple_on_false: -// output[i] = pred ? tuple_on_true[i] : tuple_on_false[i] -void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, - llvm::Value* on_false, llvm::IRBuilder<>* ir_builder); +// dynamic_update_slice must be a DynamicUpdateSlice op. +bool CanUpdateDynamicSliceInPlace(HloInstruction* dynamic_update_slice, + const BufferAssignment& assignment); + +// Checks if the given fusion node is amenable to being implemented by +// EmitFusedDynamicUpdateSliceInPlace. +inline bool CanEmitFusedDynamicUpdateSliceInPlace( + HloInstruction* fusion, const BufferAssignment& assignment) { + CHECK_EQ(fusion->opcode(), HloOpcode::kFusion); + return fusion->fusion_kind() == HloInstruction::FusionKind::kLoop && + fusion->fused_expression_root()->opcode() == + HloOpcode::kDynamicUpdateSlice && + CanUpdateDynamicSliceInPlace(fusion->fused_expression_root(), + assignment); +} + +// Emits IR for running the given dynamic-update-slice op in-place -- that is, +// where the input and output buffers share the same slice, so we can simply +// modify the input/output buffer without touching any of the other elements. +Status EmitDynamicUpdateSliceInPlace( + tensorflow::gtl::ArraySlice operand_arrays, + const IrArray& output_array, tensorflow::StringPiece name, + llvm::IRBuilder<>* ir_builder); + +// Given a loop-fusion node whose root is a dynamic-update-slice op whose +// array-to-be-updated and output share the same buffer slice, emits +// (sequential) code for a fusion node that does the dynamic-update-slice in +// place. +Status EmitFusedDynamicUpdateSliceInPlace( + HloInstruction* fusion, + tensorflow::gtl::ArraySlice fusion_operand_arrays, + const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + llvm::IRBuilder<>* ir_builder); -// A tuple is an array of pointers, one for each operand. Each pointer points to -// the output buffer of its corresponding operand. -void EmitTuple(IrArray tuple, - tensorflow::gtl::ArraySlice operands, - llvm::IRBuilder<>* ir_builder); +// Same as EmitFusedDynamicUpdateSliceInPlace, except emits a parallel loop with +// the given launch dimensions. +Status EmitParallelFusedDynamicUpdateSliceInPlace( + HloInstruction* fusion, + tensorflow::gtl::ArraySlice fusion_operand_arrays, + const IrArray& fusion_output_array, ElementalIrEmitter* elemental_emitter, + const gpu::LaunchDimensions& launch_dimensions, + llvm::IRBuilder<>* ir_builder); -// A tuple is an array of pointers, one for each operand. Each pointer points to -// the output buffer of its corresponding operand. A GetTupleElement instruction -// forwards the pointer to underlying tuple element buffer at the given index. -// Returns an llvm value representing a pointer to the tuple element buffer. -llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, - int alignment, llvm::Value* operand, - llvm::IRBuilder<>* ir_builder); } // namespace llvm_ir } // namespace xla -#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_OPS_H_ diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..3a21eda35757aa706565ee4a5286eee1acea117b --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.cc @@ -0,0 +1,110 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h" + +#include +#include +#include + +#include "llvm/IR/Instructions.h" +#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/logging.h" + +namespace xla { +namespace llvm_ir { + +void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, + llvm::Value* on_false, llvm::IRBuilder<>* ir_builder, + llvm::Module* module) { + CHECK(ShapeUtil::IsScalar(pred.GetShape())); + + llvm::LoadInst* pred_value = + ir_builder->CreateLoad(pred.GetBasePointer(), "load_predicate_value"); + llvm::Value* pred_cond = ir_builder->CreateICmpNE( + pred_value, + llvm::ConstantInt::get(PrimitiveTypeToIrType(PRED, module), 0), + "boolean_predicate"); + + VLOG(2) << "HandleSelect for tuple:"; + VLOG(2) << " pred_value: " << DumpToString(*pred_value); + VLOG(2) << " pred_cond: " << DumpToString(*pred_cond); + + for (int i = 0; i < ShapeUtil::TupleElementCount(select.GetShape()); ++i) { + std::vector element_index = {ir_builder->getInt64(0), + ir_builder->getInt64(i)}; + llvm::Value* on_true_element_address = + ir_builder->CreateInBoundsGEP(on_true, element_index); + llvm::Value* on_true_element = ir_builder->CreateLoad( + on_true_element_address, + tensorflow::strings::Printf("on_true_element_%d", i).c_str()); + llvm::Value* on_false_element_address = + ir_builder->CreateInBoundsGEP(on_false, element_index); + llvm::Value* on_false_element = ir_builder->CreateLoad( + on_false_element_address, + tensorflow::strings::Printf("on_false_element_%d", i).c_str()); + + llvm::Value* output_element_address = + ir_builder->CreateInBoundsGEP(select.GetBasePointer(), element_index); + ir_builder->CreateStore( + ir_builder->CreateSelect( + pred_cond, on_true_element, on_false_element, + tensorflow::strings::Printf("select_output_element_%d", i).c_str()), + output_element_address); + } +} + +void EmitTuple(IrArray tuple, + tensorflow::gtl::ArraySlice operands, + llvm::IRBuilder<>* ir_builder, llvm::Module* module) { + for (size_t i = 0; i < operands.size(); ++i) { + auto* store = ir_builder->CreateStore( + ir_builder->CreatePointerCast(operands[i], + PrimitiveTypeToIrType(TUPLE, module)), + ir_builder->CreateInBoundsGEP( + tuple.GetBasePointer(), + {ir_builder->getInt64(0), ir_builder->getInt64(i)})); + tuple.AnnotateLoadStoreInstructionWithMetadata(store); + } +} + +llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, + int alignment, llvm::Value* operand, + llvm::IRBuilder<>* ir_builder, + llvm::Module* module) { + llvm::Value* element_ptr = ir_builder->CreateInBoundsGEP( + operand, {ir_builder->getInt64(0), ir_builder->getInt64(index)}); + llvm::LoadInst* src_buffer = ir_builder->CreateLoad(element_ptr); + + // Mark the loaded pointer as dereferenceable if we know its shape. + if (!ShapeUtil::IsOpaque(target_shape)) { + SetDereferenceableMetadataForLoad( + src_buffer, + ByteSizeOf(target_shape, src_buffer->getModule()->getDataLayout())); + } + SetAlignmentMetadataForLoad(src_buffer, alignment); + + llvm::Type* element_type = ShapeToIrType(target_shape, module); + llvm::Value* ret_val = + ir_builder->CreateBitCast(src_buffer, element_type->getPointerTo()); + return ret_val; +} + +} // namespace llvm_ir +} // namespace xla diff --git a/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h new file mode 100644 index 0000000000000000000000000000000000000000..dbf9a140068b60505f6798360438f709bfd3feba --- /dev/null +++ b/tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h @@ -0,0 +1,83 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_ + +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/Value.h" +#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" +#include "tensorflow/core/lib/gtl/array_slice.h" +#include "tensorflow/core/platform/types.h" + +// Utilities for emitting LLVM IR related to HLO tuples. + +namespace xla { +namespace llvm_ir { + +// Selection among tuples is special in how it's lowered, because a tuple is not +// an HLO array. +// +// tuple_on_true tuple_on_false +// | | +// V V +// ------------------------ ------------------------ +// | address of element 0 | | address of element 0 | +// |----------------------| |----------------------| +// | address of element 1 | | address of element 1 | +// |----------------------| |----------------------| +// | address of element 2 | | address of element 2 | +// ------------------------ ------------------------ +// \ / +// \ / +// ---------- +// pred ---------> | select | +// ---------- +// | +// V +// output ----> ------------------------ +// | address of element 0 | +// |----------------------| +// | address of element 1 | +// |----------------------| +// | address of element 2 | +// ------------------------ +// +// Only the addresses are copied to the output. For each element, we emit a copy +// of the address from the corresponding element in either +// tuple_on_true or tuple_on_false: +// output[i] = pred ? tuple_on_true[i] : tuple_on_false[i] +void EmitTupleSelect(IrArray select, IrArray pred, llvm::Value* on_true, + llvm::Value* on_false, llvm::IRBuilder<>* ir_builder, + llvm::Module* module); + +// A tuple is an array of pointers, one for each operand. Each pointer points to +// the output buffer of its corresponding operand. +void EmitTuple(IrArray tuple, + tensorflow::gtl::ArraySlice operands, + llvm::IRBuilder<>* ir_builder, llvm::Module* module); + +// A tuple is an array of pointers, one for each operand. Each pointer points to +// the output buffer of its corresponding operand. A GetTupleElement instruction +// forwards the pointer to underlying tuple element buffer at the given index. +// Returns an llvm value representing a pointer to the tuple element buffer. +llvm::Value* EmitGetTupleElement(const Shape& target_shape, int64 index, + int alignment, llvm::Value* operand, + llvm::IRBuilder<>* ir_builder, + llvm::Module* module); +} // namespace llvm_ir +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_TUPLE_OPS_H_ diff --git a/tensorflow/compiler/xla/service/local_service.cc b/tensorflow/compiler/xla/service/local_service.cc index 3235081f83f53e2850efa2c6ccd221318fa0c58b..d4d35da9d636e6e204f36850e7987327ab258696 100644 --- a/tensorflow/compiler/xla/service/local_service.cc +++ b/tensorflow/compiler/xla/service/local_service.cc @@ -91,7 +91,7 @@ int64 RequiredSpace(const Shape& shape, bool allocate_space_for_deep_copy, StatusOr> LocalService::CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, - const Shape* result_layout, int device_ordinal, bool has_hybrid_result) { + const Shape* result_layout, int device_ordinal) { TF_ASSIGN_OR_RETURN(UserComputation * user_computation, computation_tracker_.Resolve(computation)); VersionedComputationHandle versioned_handle = @@ -133,8 +133,7 @@ StatusOr> LocalService::CompileExecutable( } TF_ASSIGN_OR_RETURN( std::unique_ptr module_config, - CreateModuleConfig(*program_shape, argument_layouts, &execution_options, - has_hybrid_result)); + CreateModuleConfig(*program_shape, argument_layouts, &execution_options)); TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, execute_backend_->stream_executor(device_ordinal)); diff --git a/tensorflow/compiler/xla/service/local_service.h b/tensorflow/compiler/xla/service/local_service.h index f2bfb960f4307d12556337f76cfd6ea7a38b6e20..52c4346385eb663baa6e7579d7b3883ba084205b 100644 --- a/tensorflow/compiler/xla/service/local_service.h +++ b/tensorflow/compiler/xla/service/local_service.h @@ -45,7 +45,7 @@ class LocalService : public Service { StatusOr> CompileExecutable( const ComputationHandle& computation, const tensorflow::gtl::ArraySlice argument_layouts, - const Shape* result_layout, int device_ordinal, bool has_hybrid_result); + const Shape* result_layout, int device_ordinal); private: explicit LocalService(const ServiceOptions& options, diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index 8041d74baa72905ea1f81e6d67c67fe6430640a1..b92017c6cbc43d78ab4e5b32f25f5980b8d4ae56 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -41,12 +41,9 @@ Status LogicalBufferAnalysis::Analyze() { // We filter out fusion computations, and get to them through fusion // instructions. This is because it's possible to have orphaned (unreachable) // fusion computations, and we don't want to try to assign buffers to those. - for (auto& computation : module_->computations()) { - if (computation->IsFusionComputation()) { - continue; - } + for (auto* computation : module_->MakeNonfusionComputations()) { TF_RETURN_IF_ERROR(computation->Accept(this)); - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { if (instruction->opcode() != HloOpcode::kFusion) { continue; } @@ -89,8 +86,7 @@ Status LogicalBufferAnalysis::DefaultAction(HloInstruction* hlo_instruction) { return Status::OK(); } -Status LogicalBufferAnalysis::HandleGetTupleElement( - HloInstruction* get_tuple_element, HloInstruction* operand) { +Status LogicalBufferAnalysis::HandleGetTupleElement(HloInstruction*) { // GetTupleElement does not create buffers. return Status::OK(); } @@ -102,24 +98,19 @@ Status LogicalBufferAnalysis::HandleCopy(HloInstruction* copy) { return Status::OK(); } -Status LogicalBufferAnalysis::HandleBitcast(HloInstruction* bitcast) { +Status LogicalBufferAnalysis::HandleBitcast(HloInstruction*) { // A kBitcast instruction aliases its operand. That is, the buffer of its // result *is* the buffer of its operand. return Status::OK(); } -Status LogicalBufferAnalysis::HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) { +Status LogicalBufferAnalysis::HandleTuple(HloInstruction* tuple) { // A Tuple instruction only creates the top-level buffer. NewLogicalBuffer(tuple, /*index=*/{}); return Status::OK(); } -Status LogicalBufferAnalysis::HandleSelect(HloInstruction* select, - HloInstruction* /*pred*/, - HloInstruction* on_true, - HloInstruction* on_false) { +Status LogicalBufferAnalysis::HandleSelect(HloInstruction* select) { // Select allocates a new buffer and then shallow copies the on_true or // on_false buffer into this new buffer. NewLogicalBuffer(select, /*index=*/{}); diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.h b/tensorflow/compiler/xla/service/logical_buffer_analysis.h index de9fe1b0a4ed3f6f8c466050520a9c4889793c62..a82e83ec5c3d2b0e011d85f3d03bea8fca870154 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.h +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.h @@ -56,16 +56,11 @@ class LogicalBufferAnalysis : public DfsHloVisitorWithDefault { void NewLogicalBuffer(HloInstruction* instruction, const ShapeIndex& index); Status DefaultAction(HloInstruction* hlo_instruction) override; - Status HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) override; - Status HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* operand) override; + Status HandleTuple(HloInstruction* tuple) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleCopy(HloInstruction* copy) override; - Status HandleSelect(HloInstruction* select, HloInstruction* pred, - HloInstruction* on_true, - HloInstruction* on_false) override; + Status HandleSelect(HloInstruction* select) override; // A map from the buffer ID to the logical buffer std::vector> logical_buffers_; diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc index 069f85af721228c8f5d40cf243eea7f1e5173c62..a0d08c288dbcc45e83a36ce7b094b04a9dbae532 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.cc +++ b/tensorflow/compiler/xla/service/name_uniquer.cc @@ -23,7 +23,24 @@ namespace xla { string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { string root = prefix.empty() ? "name" : prefix.ToString(); - int* count = &(generated_names_[root]); + + // Strip away numeric suffix (if any). Only recognize separator if it is in + // the middle of the name. + size_t separator_index = root.rfind(separator_); + if (separator_index != string::npos && (separator_index > 0) && + (separator_index < root.size() - 1)) { + string after_suffix = root.substr(separator_index + 1); + int64 numeric_suffix; + if (tensorflow::strings::safe_strto64(after_suffix, &numeric_suffix)) { + // Remove numeric suffix from root. + root = root.substr(0, separator_index); + // Update count to at least the numeric suffix value to avoid future + // colisions with this name. + generated_names_[root] = std::max(generated_names_[root], numeric_suffix); + } + } + + int64* count = &(generated_names_[root]); if (*count == 0) { *count = 1; return root; @@ -31,9 +48,6 @@ string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) { tensorflow::strings::StrAppend(&root, separator_, *count); // Increment lookup under old 'root' name. (*count)++; - // Initialize count under new 'root' name. - count = &(generated_names_[root]); - *count = 1; return root; } } diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h index b0944adbc1d98fd88c550cc8b53cf399e43535e6..ed379b52258463b960dea788721c2c4325ef0260 100644 --- a/tensorflow/compiler/xla/service/name_uniquer.h +++ b/tensorflow/compiler/xla/service/name_uniquer.h @@ -43,7 +43,7 @@ class NameUniquer { // Map from name prefix to the number of names generated using that prefix // so far. - std::unordered_map generated_names_; + std::unordered_map generated_names_; TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer); }; diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..9f0747a6e2175a968d8f3661ac51512009e86f29 --- /dev/null +++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc @@ -0,0 +1,72 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/name_uniquer.h" + +#include +#include +#include + +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace { + +class NameUniquerTest : public ::testing::Test {}; + +TEST_F(NameUniquerTest, SimpleUniquer) { + NameUniquer uniquer; + + EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo__1", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo__2", uniquer.GetUniqueName("foo")); + EXPECT_EQ("bar", uniquer.GetUniqueName("bar")); + EXPECT_EQ("foo__3", uniquer.GetUniqueName("foo")); + EXPECT_EQ("bar__1", uniquer.GetUniqueName("bar")); + EXPECT_EQ("qux", uniquer.GetUniqueName("qux")); +} + +TEST_F(NameUniquerTest, DifferentSeparator) { + NameUniquer uniquer("."); + + EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.1", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.2", uniquer.GetUniqueName("foo")); + EXPECT_EQ("bar", uniquer.GetUniqueName("bar")); + EXPECT_EQ("foo.3", uniquer.GetUniqueName("foo")); + EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar")); +} + +TEST_F(NameUniquerTest, NumericSuffixes) { + NameUniquer uniquer("."); + + EXPECT_EQ("foo", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54")); + EXPECT_EQ("foo.55", uniquer.GetUniqueName("foo")); + EXPECT_EQ("foo.55.1", uniquer.GetUniqueName("foo.55.1")); + EXPECT_EQ("foo.55.2", uniquer.GetUniqueName("foo.55.1")); + EXPECT_EQ("bar", uniquer.GetUniqueName("bar.-1000")); + EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000")); + EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1")); + + // Separator is only recognized in the middle of the prefix. + EXPECT_EQ(".10", uniquer.GetUniqueName(".10")); + EXPECT_EQ(".10.1", uniquer.GetUniqueName(".10")); + EXPECT_EQ("foobar.", uniquer.GetUniqueName("foobar.")); + EXPECT_EQ("foobar..1", uniquer.GetUniqueName("foobar.")); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/platform_util.cc b/tensorflow/compiler/xla/service/platform_util.cc index 4f915a0c2eeaca0fe077a907571c8379992185eb..3a1818de82d3fd305e2c6b3bd1f2cf8125806a75 100644 --- a/tensorflow/compiler/xla/service/platform_util.cc +++ b/tensorflow/compiler/xla/service/platform_util.cc @@ -84,15 +84,6 @@ PlatformUtil::GetSupportedPlatforms() { return NotFound("no platforms found"); } else if (platforms.size() == 1) { return platforms[0]; - } else if (platforms.size() == 2) { - // In the service we always link the cpu backend for ComputeConstant. So if - // one of the two platforms is CPU then pick the other (non-cpu) platform as - // the default. - if (platforms[0]->id() == se::host::kHostPlatformId) { - return platforms[1]; - } else if (platforms[1]->id() == se::host::kHostPlatformId) { - return platforms[0]; - } } // Multiple platforms present and we can't pick a reasonable default. diff --git a/tensorflow/compiler/xla/service/platform_util.h b/tensorflow/compiler/xla/service/platform_util.h index fe0281a69a441b5462470e88bd3ad73784a8da35..eac573703085aca2801885cd9abbe0022f1c029e 100644 --- a/tensorflow/compiler/xla/service/platform_util.h +++ b/tensorflow/compiler/xla/service/platform_util.h @@ -36,12 +36,7 @@ class PlatformUtil { // Convenience function which returns the default supported platform. If // exactly one supported platform is present, then this platform is the - // default platform. If exactly two supported platforms are present and one - // platform is CPU (host) then the non-CPU platform is default. This logic is - // used because the XLA service always links in the CPU backend to run - // ComputeConstant, so if exactly one other platform is linked in, we assume - // the intent is to execute on that non-CPU platform. If none of these - // conditions are met the function returns an error. + // default platform. Otherwise returns an error. static StatusOr GetDefaultPlatform(); // Returns a vector of StreamExecutors for the given platform. The vector is diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc index 8275531111ce10e05d81a77c739757a649f97a1c..e2c07e38271df8b8875b2c9291f18ba41a9e6acd 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.cc @@ -29,27 +29,27 @@ std::vector ReducePrecisionInsertion::instructions_to_modify( case HloReducePrecisionOptions::OP_INPUTS: case HloReducePrecisionOptions::OP_OUTPUTS: case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS: - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { VLOG(4) << "Visited instruction: " << instruction->ToString(); - if (instruction_filter_function_(instruction.get())) { - instruction_list.push_back(instruction.get()); + if (instruction_filter_function_(instruction)) { + instruction_list.push_back(instruction); } } break; case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT: case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT: - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { VLOG(4) << "Visited instruction: " << instruction->ToString(); if (instruction->opcode() != HloOpcode::kFusion) { continue; } - for (auto& fused_instruction : + for (auto* fused_instruction : instruction->fused_instructions_computation()->instructions()) { VLOG(4) << "Checking sub-instruction: " << fused_instruction->ToString(); - if (instruction_filter_function_(fused_instruction.get())) { - instruction_list.push_back(instruction.get()); + if (instruction_filter_function_(fused_instruction)) { + instruction_list.push_back(instruction); break; } } @@ -96,8 +96,7 @@ StatusOr ReducePrecisionInsertion::insert_after( HloInstruction* reduced = instruction->parent()->AddInstruction( HloInstruction::CreateReducePrecision(instruction->shape(), instruction, exponent_bits_, mantissa_bits_)); - TF_RETURN_IF_ERROR( - instruction->parent()->ReplaceUsesOfInstruction(instruction, reduced)); + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(reduced)); return true; } @@ -198,24 +197,20 @@ StatusOr ReducePrecisionInsertion::Run(HloModule* module) { bool changed = false; VLOG(1) << "Running ReducePrecisionInsertion pass on " << module->name(); - for (auto& computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } - + for (auto* computation : module->MakeNonfusionComputations()) { StatusOr computation_changed; switch (location_) { case HloReducePrecisionOptions::OP_INPUTS: case HloReducePrecisionOptions::FUSION_INPUTS_BY_CONTENT: computation_changed = ReducePrecisionInsertion::insert_on_inputs( - instructions_to_modify(computation.get())); + instructions_to_modify(computation)); break; case HloReducePrecisionOptions::FUSION_OUTPUTS_BY_CONTENT: case HloReducePrecisionOptions::OP_OUTPUTS: case HloReducePrecisionOptions::UNFUSED_OP_OUTPUTS: computation_changed = ReducePrecisionInsertion::insert_on_outputs( - instructions_to_modify(computation.get())); + instructions_to_modify(computation)); break; default: break; diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc index a62560be59964a50ba1db40301fe6a94216997a5..69e4b534bd8e3aeab8b729f3e594a10b4368f15f 100644 --- a/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc +++ b/tensorflow/compiler/xla/service/reduce_precision_insertion_test.cc @@ -381,7 +381,7 @@ TEST_F(ReducePrecisionInsertionTest, IgnoreOpsInsideFusionNode) { // Manually fuse the kCos operation into a fusion operation. HloInstruction* z = computation->AddInstruction(HloInstruction::CreateFusion( shape, HloInstruction::FusionKind::kLoop, y)); - EXPECT_IS_OK(computation->ReplaceUsesOfInstruction(y, z)); + EXPECT_IS_OK(y->ReplaceAllUsesWith(z)); EXPECT_IS_OK(computation->RemoveInstruction(y)); // Confirm expected graph before adding reduce-precision ops. @@ -417,7 +417,7 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInHeadOfFusionNode) { // Manually fuse the kCos operation into a fusion operation. HloInstruction* z = computation->AddInstruction(HloInstruction::CreateFusion( shape, HloInstruction::FusionKind::kLoop, y)); - EXPECT_IS_OK(computation->ReplaceUsesOfInstruction(y, z)); + EXPECT_IS_OK(y->ReplaceAllUsesWith(z)); EXPECT_IS_OK(computation->RemoveInstruction(y)); // Confirm expected graph before adding reduce-precision ops. @@ -464,7 +464,7 @@ TEST_F(ReducePrecisionInsertionTest, OpGetsInsertedInTailOfFusionNode) { // Manually fuse the kCos operation into a fusion operation. HloInstruction* z = computation->AddInstruction(HloInstruction::CreateFusion( shape, HloInstruction::FusionKind::kLoop, y)); - EXPECT_IS_OK(computation->ReplaceUsesOfInstruction(y, z)); + EXPECT_IS_OK(y->ReplaceAllUsesWith(z)); EXPECT_IS_OK(computation->RemoveInstruction(y)); // Confirm expected graph before adding reduce-precision ops. diff --git a/tensorflow/compiler/xla/service/reshape_mover.cc b/tensorflow/compiler/xla/service/reshape_mover.cc index a480236cebd9b020436b495df24a25421cebf174..0fb90230f2f39a841973361f63d17af579a1342b 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.cc +++ b/tensorflow/compiler/xla/service/reshape_mover.cc @@ -48,23 +48,28 @@ namespace xla { namespace { -// Checks if an instruction can change its shape simply by adjusting metadata. -// This is the case if it is: -// -// - an instruction does not have any producers like Constants -// or Rng instruction, or is a scalar. -// -// Or -// -// - an reshape/transpose instruction with an operand that can trivially change -// its shape. -bool InstructionCanTriviallyChangeShape(const HloInstruction* instruction) { - // Reshape/Transposes are only trivial if their operand is trivial. - if (instruction->opcode() == HloOpcode::kReshape || - instruction->opcode() == HloOpcode::kTranspose) { - CHECK_EQ(instruction->operand_count(), 1); - return InstructionCanTriviallyChangeShape(instruction->operand(0)); - } +bool IsReshapeOrTranspose(const HloInstruction* instruction) { + return instruction->opcode() == HloOpcode::kReshape || + instruction->opcode() == HloOpcode::kTranspose; +} + +// Returns true iff `instruction` can change its shape simply by adjusting +// metadata. +bool CanTriviallyChangeShape(const HloInstruction* instruction) { + // NOTE: Technically a sequence of reshape(reshape(constant)) is also + // trivially reshapable, so we might be tempted to simply recurse if + // IsReshapeOrTranspose(instruction)==true. + // + // But it's not that simple. E.g. reshape(reshape(rng)) is only trivially + // reshapable if *all* instructions in the chain have user_count == 1. And + // reshape(scalar) isn't trivial at all if the reshape itself isn't scalar; we + // rely on implicit scalar broadcast for scalars to be trivial. In addition, + // these cases make it harder to maintain correctness of the UpdateOperand + // logic below. + // + // So don't handle these chains, unless you update the tests and code to deal + // with these properly. One idea is to add a pass immediately beforehand that + // collapses trivial runs of reshapes / transposes. // Scalars can operate with any shape. if (ShapeUtil::IsScalar(instruction->shape())) { @@ -93,9 +98,8 @@ HloInstruction* FirstNonScalarAndNonTrivialReshapeOperand( const HloInstruction* hlo) { for (HloInstruction* operand : hlo->operands()) { if (!ShapeUtil::IsScalar(operand->shape()) && - ((operand->opcode() == HloOpcode::kReshape || - operand->opcode() == HloOpcode::kTranspose) && - !InstructionCanTriviallyChangeShape(operand->operand(0)))) { + IsReshapeOrTranspose(operand) && + !CanTriviallyChangeShape(operand->operand(0))) { VLOG(5) << "Found first non-scalar and non-trivial reshape operand of " << hlo->ToStringNoMetadata() << ":\n\t" << operand->ToStringNoMetadata(); @@ -122,28 +126,15 @@ bool AreEquivalentReshapes(const HloInstruction* a, const HloInstruction* b) { } } -// Returns true if an elementwise operation has all operands that can easily -// change shape. Operands can easily change shape if they are all -// reshapes/transposes to and from the same shape. Additionally, operands like -// constant, rng, and any scalar change shape with only an adjustment of -// metadata. -bool IsElementwiseOfEquivalentReshapesOrTransposes( - const HloInstruction* instruction) { - const auto& operands = instruction->operands(); - HloInstruction* first_reshape_operand = - FirstNonScalarAndNonTrivialReshapeOperand(instruction); - // If there are no non-trivial reshapes or transposes, then there is nothing - // to sink below the elementwise operation. - if (!first_reshape_operand) { - return false; - } - VLOG(3) << "** Checking whether instruction is an elementwise operation of " - "equivalent reshapes/transposes: " +// Returns true if all operands of `instruction` can easily change shape. +// Operands can easily change shape if they are all reshapes/transposes to and +// from the same shape. Additionally, operands like constant, rng, and any +// scalar change shape with only an adjustment of metadata. +bool AllOperandsHaveEasyShapeChanges( + const HloInstruction* instruction, + const HloInstruction* first_reshape_operand) { + VLOG(3) << "** Checking whether all operands have easy shape changes: " << instruction->ToStringNoMetadata(); - bool result = (instruction->user_count() > 0 || - instruction == instruction->parent()->root_instruction()) && - instruction->IsElementwise() && !operands.empty(); - // Check whether all operands: // 0. Have the same dimensions as the output -- if not, it may be // implicitly broadcast, which can confound the movement's @@ -155,66 +146,117 @@ bool IsElementwiseOfEquivalentReshapesOrTransposes( // or // 2. Are one of kConstant, kRng, and scalars that can change shape // trivially, - if (result) { - for (auto& operand : operands) { - if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { - VLOG(5) << "Operand shape differs from output shape; may be " - "implicitly broadcast, so preventing " - "movement\n\toperand: " - << operand->ToStringNoMetadata() - << "\n\tinstruction: " << instruction->ToStringNoMetadata(); - result = false; - break; - } - - if (AreEquivalentReshapes(first_reshape_operand, operand)) { - VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " - << first_reshape_operand->ToStringNoMetadata() - << "\n\toperand: " << operand->ToStringNoMetadata(); - continue; - } + for (const HloInstruction* operand : instruction->operands()) { + if (!ShapeUtil::SameDimensions(operand->shape(), instruction->shape())) { + VLOG(5) << "Operand shape differs from output shape; may be " + "implicitly broadcast, so preventing " + "movement\n\toperand: " + << operand->ToStringNoMetadata() + << "\n\tinstruction: " << instruction->ToStringNoMetadata(); + return false; + } - if (InstructionCanTriviallyChangeShape(operand)) { - VLOG(5) << "Operand can trivially change shape: " - << operand->ToStringNoMetadata(); - continue; - } + if (AreEquivalentReshapes(first_reshape_operand, operand)) { + VLOG(5) << "Are equivalent reshapes:\n\tfirst_reshape_operand: " + << first_reshape_operand->ToStringNoMetadata() + << "\n\toperand: " << operand->ToStringNoMetadata(); + continue; + } - // TODO(someone): Look into supporting general ops for the operands as - // well. - VLOG(5) << "Operand is neither equalivant to the first Reshape operand" - "nor can trivially change shape: " + if (CanTriviallyChangeShape(operand)) { + VLOG(5) << "Operand can trivially change shape: " << operand->ToStringNoMetadata(); - result = false; - break; + continue; } + + // TODO(someone): Look into supporting general ops for the operands as + // well. + VLOG(5) << "Operand is neither equalivant to the first Reshape operand" + "nor can trivially change shape: " + << operand->ToStringNoMetadata(); + return false; } - VLOG(3) << "ElementwiseOfEquivalentReshapesOrTransposes result for " - << instruction->ToStringNoMetadata() << ": " << result; - return result; + VLOG(3) << "All operands have easy shape changes: " + << instruction->ToStringNoMetadata(); + return true; +} + +// This function is called once we've decided to sink reshape/transpose operands +// across an instruction. It returns an updated `operand` with a shape that +// plays nicely with `new_operand_shape`; either it has the same shape (of the +// correct type), or it is a scalar that may be implicitly broadcast. +HloInstruction* UpdateOperand(HloComputation* computation, + const HloInstruction* first_reshape_operand, + const Shape& new_operand_shape, + HloInstruction* operand) { + const PrimitiveType element_type = operand->shape().element_type(); + const Shape new_shape = + ShapeUtil::ChangeElementType(new_operand_shape, element_type); + + switch (operand->opcode()) { + case HloOpcode::kConstant: { + if (first_reshape_operand->opcode() == HloOpcode::kReshape) { + VLOG(5) << "Adding reshape to kConstant operand"; + return computation->AddInstruction( + HloInstruction::CreateReshape(new_shape, operand)); + } else { + CHECK(first_reshape_operand->opcode() == HloOpcode::kTranspose); + VLOG(5) << "Adding transpose to kConstant operand"; + std::vector inverse_permutation = + InversePermutation(first_reshape_operand->dimensions()); + return computation->AddInstruction(HloInstruction::CreateTranspose( + new_shape, operand, inverse_permutation)); + } + } + case HloOpcode::kRng: { + CHECK_EQ(operand->user_count(), 1); + VLOG(5) << "Cloning kRng operand with new shape"; + return computation->AddInstruction( + operand->CloneWithNewOperands(new_shape, operand->operands())); + } + case HloOpcode::kReshape: + case HloOpcode::kTranspose: { + VLOG(5) << "Using existing operand of kReshape or kTranspose"; + return operand->mutable_operand(0); + } + default: + LOG(FATAL) << "Unexpected operand opcode during update: " << operand; + } } // Try to sink any reshape or transpose operands of `instruction` across it. We // do so if `instruction` is elementwise and all operands are either equivalent -// reshapes/transposes or are trivially reshapable. Note that no move is -// performend if there is no nontrivial reshapes/transposes. +// reshapes/transposes or are trivially reshapable. StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, HloInstruction* instruction) { - if (!IsElementwiseOfEquivalentReshapesOrTransposes(instruction)) { + // Only perform sinks for live elementwise instructions with operands. + const bool is_dead = instruction->user_count() == 0 && + instruction != computation->root_instruction(); + if (!instruction->IsElementwise() || instruction->operands().empty() || + is_dead) { return false; } - HloInstruction* old_reshape = + // Only perform sinks if there are any nontrivial reshape/transpose operands. + const HloInstruction* first_reshape_operand = FirstNonScalarAndNonTrivialReshapeOperand(instruction); - TF_RET_CHECK(old_reshape != nullptr); - Shape new_elementwise_shape = old_reshape->operand(0)->shape(); + if (!first_reshape_operand) { + return false; + } - VLOG(3) << "** Trying to sink reshape or transpose: " - << instruction->ToStringNoMetadata() - << "\n\told reshape: " << old_reshape->ToStringNoMetadata() - << "\n\tnew elementwise shape: " - << ShapeUtil::HumanString(new_elementwise_shape); + // Only perform sinks if all operands can easily change shape. + if (!AllOperandsHaveEasyShapeChanges(instruction, first_reshape_operand)) { + return false; + } + + // At this point we've decided to sink reshape/transpose operands. + const Shape& new_operand_shape = first_reshape_operand->operand(0)->shape(); + VLOG(3) << "** Sinking reshape or transpose: " + << instruction->ToStringNoMetadata() << "\n\tfirst reshape operand: " + << first_reshape_operand->ToStringNoMetadata() + << "\n\tnew operand shape: " + << ShapeUtil::HumanString(new_operand_shape); auto operands = instruction->operands(); for (size_t i = 0; i < operands.size(); ++i) { @@ -224,55 +266,19 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, if (ShapeUtil::IsScalar(operands[i]->shape())) { continue; } - PrimitiveType element_type = operands[i]->shape().element_type(); - switch (operands[i]->opcode()) { - case HloOpcode::kConstant: { - if (old_reshape->opcode() == HloOpcode::kReshape) { - VLOG(3) << "Creating reshape for kConstant operand " << i << ": " - << operands[i]->ToStringNoMetadata(); - operands[i] = instruction->parent()->AddInstruction( - HloInstruction::CreateReshape( - ShapeUtil::ChangeElementType(new_elementwise_shape, - element_type), - operands[i])); - } else { - TF_RET_CHECK(old_reshape->opcode() == HloOpcode::kTranspose); - std::vector inverse_permutation = - InversePermutation(old_reshape->dimensions()); - operands[i] = instruction->parent()->AddInstruction( - HloInstruction::CreateTranspose( - ShapeUtil::ChangeElementType(new_elementwise_shape, - element_type), - operands[i], inverse_permutation)); - } - break; - } - case HloOpcode::kRng: { - CHECK_EQ(operands[i]->user_count(), 1); - operands[i] = instruction->parent()->AddInstruction( - operands[i]->CloneWithNewOperands( - ShapeUtil::ChangeElementType(new_elementwise_shape, - element_type), - operands[i]->operands())); - break; - } - case HloOpcode::kReshape: - case HloOpcode::kTranspose: - operands[i] = operands[i]->mutable_operand(0); - break; - default: - LOG(FATAL) << "Unexpected opcode while trying to sink reshapes or " - "transposes."; - } + VLOG(3) << "Updating operand #" << i << ": " + << operands[i]->ToStringNoMetadata(); + operands[i] = UpdateOperand(computation, first_reshape_operand, + new_operand_shape, operands[i]); } if (HloOpcode::kFusion == instruction->opcode()) { // Here we already know `instruction` is elementwise, and no operand is - // implicit broadcast as if it were the operands would not be equivalent - // reshapes, so all the fused instructions have the same dimensions. + // implicit broadcast as if it were the operands would not have easy shape + // changes, so all the fused instructions have the same dimensions. for (const auto& fused_instruction : instruction->fused_instructions()) { Shape* shape = fused_instruction->mutable_shape(); - *shape->mutable_dimensions() = new_elementwise_shape.dimensions(); - *shape->mutable_layout() = new_elementwise_shape.layout(); + *shape->mutable_dimensions() = new_operand_shape.dimensions(); + *shape->mutable_layout() = new_operand_shape.layout(); } } HloInstruction* new_elementwise = @@ -284,12 +290,12 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, // // In this case, convert' should have the same element type as // `convert` and the same dimensions as operands[0]. - ShapeUtil::ChangeElementType(new_elementwise_shape, + ShapeUtil::ChangeElementType(new_operand_shape, instruction->shape().element_type()), operands)); std::unique_ptr new_reshape; - switch (old_reshape->opcode()) { + switch (first_reshape_operand->opcode()) { case HloOpcode::kReshape: VLOG(3) << "Creating new reshape for new elementwise op: " << new_elementwise->ToStringNoMetadata(); @@ -297,8 +303,9 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, HloInstruction::CreateReshape(instruction->shape(), new_elementwise); break; case HloOpcode::kTranspose: - new_reshape = HloInstruction::CreateTranspose( - instruction->shape(), new_elementwise, old_reshape->dimensions()); + new_reshape = + HloInstruction::CreateTranspose(instruction->shape(), new_elementwise, + first_reshape_operand->dimensions()); break; default: LOG(FATAL) << "Bad opcode"; @@ -312,20 +319,17 @@ StatusOr TrySinkReshapeOrTranspose(HloComputation* computation, StatusOr ReshapeMover::Run(HloModule* module) { bool changed = false; - std::vector computations; - for (auto& computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } - computations.push_back(computation.get()); - } - for (const auto& comp : computations) { + VLOG(2) << "Pre ReshapeMover HLO:"; + XLA_VLOG_LINES(2, module->ToString()); + for (auto* comp : module->MakeNonfusionComputations()) { for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) { TF_ASSIGN_OR_RETURN(bool did_change, TrySinkReshapeOrTranspose(comp, instruction)); changed |= did_change; } } + VLOG(2) << "Post ReshapeMover HLO:"; + XLA_VLOG_LINES(2, module->ToString()); return changed; } diff --git a/tensorflow/compiler/xla/service/reshape_mover.h b/tensorflow/compiler/xla/service/reshape_mover.h index b7e0a46939a10b3376758109214c9722976f50e0..1f59e3b3147facb6f2fae00d6c810bf54d560e5c 100644 --- a/tensorflow/compiler/xla/service/reshape_mover.h +++ b/tensorflow/compiler/xla/service/reshape_mover.h @@ -26,7 +26,7 @@ namespace xla { // them inputward also. class ReshapeMover : public HloPassInterface { public: - tensorflow::StringPiece name() const override { return "reshape-motion"; } + tensorflow::StringPiece name() const override { return "reshape-mover"; } StatusOr Run(HloModule* module) override; }; diff --git a/tensorflow/compiler/xla/service/reshape_mover_test.cc b/tensorflow/compiler/xla/service/reshape_mover_test.cc index a81d3f4eb344510b3973aa46a576d11f613bc404..aac8638a54f744f0c230ec6c5ca071c1daf45ab2 100644 --- a/tensorflow/compiler/xla/service/reshape_mover_test.cc +++ b/tensorflow/compiler/xla/service/reshape_mover_test.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/test.h" #include "tensorflow/compiler/xla/test_helpers.h" -#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/strings/str_util.h" @@ -34,7 +34,7 @@ namespace op = xla::testing::opcode_matchers; namespace xla { namespace { -using ReshapeMoverTest = HloTestBase; +using ReshapeMoverTest = HloVerifiedTestBase; TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { HloComputation::Builder builder(TestName()); @@ -50,13 +50,12 @@ TEST_F(ReshapeMoverTest, ReshapesWithDifferentInputShapesNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); @@ -89,13 +88,12 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd1ReshapesOnRngNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, const1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(rng0), const1)); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(rng0), const1)); @@ -115,13 +113,12 @@ TEST_F(ReshapeMoverTest, ScalarReshapesNotMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), @@ -142,12 +139,11 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Add(param0, param1))); @@ -193,21 +189,19 @@ TEST_F(ReshapeMoverTest, 1ConstantAnd2ReshapesMoved) { builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param2)); builder.AddInstruction(HloInstruction::CreateTernary( - ShapeUtil::MakeShape(PRED, {2, 3}), HloOpcode::kSelect, const0, reshape1, - reshape2)); + root_shape, HloOpcode::kSelect, const0, reshape1, reshape2)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(const0, reshape1, reshape2)); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Select(op::Reshape(const0), param1, param2))); - EXPECT_EQ(const0->shape().DebugString(), + EXPECT_EQ(root_shape.DebugString(), computation->root_instruction()->shape().DebugString()); } @@ -228,17 +222,16 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { 0, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param0")); auto reshape0 = builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); - auto param1 = builder.AddInstruction(HloInstruction::CreateParameter( - 1, ShapeUtil::MakeShape(F32, {1, 8, 1, 7}), "param1")); + auto param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, root_shape, "param1")); builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, param1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), param1)); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), param1)); @@ -260,7 +253,7 @@ TEST_F(ReshapeMoverTest, 1ParameterAnd1ReshapeNotMoved) { // trivial reshapes. TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { HloComputation::Builder builder(TestName()); - auto root_shape = ShapeUtil::MakeShape(F32, {2, 3}); + auto root_shape = ShapeUtil::MakeShape(F32, {3, 2}); auto const0 = builder.AddInstruction(HloInstruction::CreateConstant( Literal::CreateR2({{1, 2, 3}, {4, 5, 6}}))); auto reshape0 = @@ -272,18 +265,17 @@ TEST_F(ReshapeMoverTest, 2TrivialConstantReshapeNotMoved) { builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1)); auto pred = builder.AddInstruction(HloInstruction::CreateParameter( - 0, ShapeUtil::MakeShape(PRED, {1, 3, 1, 2}), "pred")); + 0, ShapeUtil::MakeShape(PRED, {3, 2}), "pred")); builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, pred, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(pred, op::Reshape(const0), op::Reshape(const1))); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Select(pred, op::Reshape(const0), op::Reshape(const1))); @@ -323,13 +315,12 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, const1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Add(op::Reshape(param0), const1)); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Add(param0, op::Reshape(const1)))); @@ -337,6 +328,48 @@ TEST_F(ReshapeMoverTest, 1NonTrivialReshapeMoved) { computation->root_instruction()->shape().DebugString()); } +// For a graph that looks like: +// +// +- reshape0 - param0 (shape A) +// | +// +- reshape1 - const1 (shape B) +// | +// add +// +// There is 1 non-trivial reshape (reshape0). It's not clear whether reshape1 +// should be trivial or not; conceptually it's trivial, but handling it would +// complicate the rest of our logic. +// +// For now we treat it as non-trivial, so we verify that we don't sink the +// reshapes in this case. +TEST_F(ReshapeMoverTest, 1NonTrivialReshapeWith1ReshapedConstNotMoved) { + HloComputation::Builder builder(TestName()); + auto root_shape = ShapeUtil::MakeShape(F32, {1, 1, 3}); + auto param0 = builder.AddInstruction(HloInstruction::CreateParameter( + 0, ShapeUtil::MakeShape(F32, {1, 3}), "param0")); + auto const1 = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({9, 8, 7}))); + auto reshape0 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, param0)); + auto reshape1 = + builder.AddInstruction(HloInstruction::CreateReshape(root_shape, const1)); + + builder.AddInstruction(HloInstruction::CreateBinary( + root_shape, HloOpcode::kAdd, reshape0, reshape1)); + + auto computation = module().AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), op::Reshape(const1))); + + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); + + EXPECT_THAT(computation->root_instruction(), + op::Add(op::Reshape(param0), op::Reshape(const1))); + EXPECT_EQ(root_shape.DebugString(), + computation->root_instruction()->shape().DebugString()); +} + TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { HloComputation::Builder builder(TestName()); auto root_shape = ShapeUtil::MakeShape(F32, {8, 7}); @@ -351,15 +384,14 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossFusion) { auto add = builder.AddInstruction(HloInstruction::CreateBinary( root_shape, HloOpcode::kAdd, reshape0, reshape1)); - HloModule module(TestName()); - auto computation = module.AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); computation->CreateFusionInstruction({add}, HloInstruction::FusionKind::kLoop); EXPECT_THAT(computation->root_instruction(), op::Fusion(op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(&module).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Fusion(param0, param1))); @@ -386,14 +418,13 @@ TEST_F(ReshapeMoverTest, EquivalentReshapesMovedAcrossSelect) { builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, reshape_pred, reshape0, reshape1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), op::Select(op::Reshape(pred), op::Reshape(param0), op::Reshape(param1))); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Reshape(op::Select(pred, param0, param1))); @@ -416,12 +447,11 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) { auto select = builder.AddInstruction(HloInstruction::CreateTernary( root_shape, HloOpcode::kSelect, reshape_pred, param0, param1)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Select(op::Reshape(pred), param0, param1)); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Select(op::Reshape(pred), param0, param1)); @@ -468,12 +498,11 @@ TEST_F(ReshapeMoverTest, ImplicitlyBroadcastReshapeIsNotMovedBug37787999) { auto multiply = builder.AddInstruction(HloInstruction::CreateBinary( constant->shape(), HloOpcode::kMultiply, constant, reshape)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Constant(), op::Reshape(param0))); - EXPECT_FALSE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_FALSE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT(computation->root_instruction(), op::Multiply(op::Constant(), op::Reshape(param0))); @@ -517,15 +546,14 @@ TEST_F(ReshapeMoverTest, MultiplePasses) { builder.AddInstruction(HloInstruction::CreateBinary(shape3, HloOpcode::kAdd, reshape2, reshape3)); - auto module = CreateNewModule(); - auto computation = module->AddEntryComputation(builder.Build()); + auto computation = module().AddEntryComputation(builder.Build()); EXPECT_THAT( computation->root_instruction(), op::Add(op::Reshape(param2), op::Reshape(op::Add(op::Reshape(param0), op::Reshape(param1))))); - EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie()); + EXPECT_TRUE(ReshapeMover().Run(&module()).ValueOrDie()); EXPECT_THAT( computation->root_instruction(), diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 049ae91e9308dc8ab89db0328cf8098ca54ef098..71afbee456b0f5eb67cb092d84f8e95ea1038c54 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -153,7 +153,7 @@ int ServiceOptions::intra_op_parallelism_threads() const { Service::Service(const ServiceOptions& options, std::unique_ptr execute_backend) : options_(options), execute_backend_(std::move(execute_backend)) { - CHECK(options_.number_of_replicas() > 0); + CHECK_GT(options_.number_of_replicas(), 0); if (execute_backend_) { if (execute_backend_->device_count() > 0) { CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas()) @@ -187,8 +187,9 @@ tensorflow::Status Service::Computation(const ComputationRequest* arg, *result->mutable_computation() = computation_tracker_.NewComputation(arg->name()); - VLOG(1) << Printf("Created new computation %s on service %p", - result->computation().ShortDebugString().c_str(), this); + VLOG(1) << Printf("Created new computation %s on service %p, name %s", + result->computation().ShortDebugString().c_str(), this, + arg->name().c_str()); return tensorflow::Status::OK(); } @@ -268,7 +269,7 @@ StatusOr> Service::ResolveAndValidateArguments( StatusOr> Service::CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options, bool has_hybrid_result) { + const ExecutionOptions* execution_options) { auto config = MakeUnique(program_shape); auto* computation_layout = config->mutable_entry_computation_layout(); @@ -305,7 +306,6 @@ StatusOr> Service::CreateModuleConfig( } config->set_replica_count(options_.number_of_replicas()); - config->set_has_hybrid_result(has_hybrid_result); if (execution_options != nullptr) { config->set_seed(execution_options->seed()); config->set_debug_options(execution_options->debug_options()); @@ -338,7 +338,7 @@ StatusOr>> Service::BuildExecutables( std::vector versioned_handles, std::vector> module_configs, Backend* backend, - std::vector executors) { + std::vector> executors) { VLOG(1) << Printf("BuildExecutable on service %p", this); // Dump computation proto state if flag is set. @@ -490,14 +490,20 @@ Service::ExecuteParallelAndRegisterResult( std::vector> arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, - tensorflow::gtl::ArraySlice result_tags) { + tensorflow::gtl::ArraySlice result_tags, + ExecutionProfile* profile) { // Streams where the computation are launched, so we can wait on the streams // to complete. std::vector::SmartPtr> streams; + std::vector> timers; // Global data handles for the computation results, one for each computation. std::vector result_handles; + // Device ID to stream executor, populated only with devices that are being + // profiled. + std::map index_to_profiled_streams; + TF_ASSIGN_OR_RETURN(DeviceAssignment device_assignment, backend->computation_placer()->AssignDevices( options_.number_of_replicas(), executables.size())); @@ -510,6 +516,21 @@ Service::ExecuteParallelAndRegisterResult( backend->BorrowStream(replicas[replica])); streams.push_back(std::move(stream)); + if (replica == 0 && profile != nullptr) { + timers.emplace_back( + new perftools::gputools::Timer(streams.back()->parent())); + streams.back() + ->InitTimer(timers.back().get()) + .ThenStartTimer(timers.back().get()); + CHECK(timers.front() != nullptr); + } + + if (replica == 0 && + executables[i]->module_config().debug_options().xla_hlo_profile() && + executables[i]->hlo_profiling_enabled()) { + index_to_profiled_streams[i] = streams.back().get(); + } + // Set up run options. ExecutableRunOptions options; options.set_stream(streams.back().get()); @@ -526,6 +547,10 @@ Service::ExecuteParallelAndRegisterResult( perftools::gputools::DeviceMemoryBase result, executables[i]->ExecuteAsyncOnStream(&run_options, arguments[i])); + if (replica == 0 && profile != nullptr) { + streams.back()->ThenStopTimer(timers.back().get()); + } + // All replicas share the same device address for the result allocation, // so only one of the replicas need to register the result handle. if (replica == 0) { @@ -543,6 +568,69 @@ Service::ExecuteParallelAndRegisterResult( } } + // For every stream that had profiling enabled, obtain and debug-dump the HLO + // profile. + for (auto& index_to_profiled_stream : index_to_profiled_streams) { + int64 device = index_to_profiled_stream.first; + se::Stream* stream = index_to_profiled_stream.second; + HloExecutionProfile hlo_profile; + TF_RETURN_IF_ERROR(executables[device]->PopulateExecutionProfile( + &hlo_profile, stream->parent())); + + std::unordered_set profiled_computations = + hlo_profile.profiled_computations(); + // To ensure we have print the profiles in a stable order, iterate over the + // computations in post order. + auto& module = executables[device]->module(); + std::list all_computations = + module.MakeComputationPostOrder(); + for (xla::HloComputation* computation : all_computations) { + if (profiled_computations.count(computation) > 0) { + string profile_string = hlo_profile.ToString( + *computation, streams[0]->parent()->GetDeviceDescription(), + executables[device]->CreateCostAnalysis().get()); + if (!profile_string.empty()) { + LOG(INFO) << "HLO profile for execution on device " << device + << ":\n"; + XLA_LOG_LINES(tensorflow::INFO, profile_string); + } + } + } + hlo_graph_dumper::MaybeDumpHloModule(module, "Service::Execute", + &hlo_profile); + } + + if (profile != nullptr) { + CHECK(!timers.empty()); + std::vector timer_nanoseconds; + timer_nanoseconds.reserve(timers.size()); + for (auto& timer : timers) { + timer_nanoseconds.push_back(timer->Nanoseconds()); + } + uint64 nanoseconds = + *std::max_element(timer_nanoseconds.begin(), timer_nanoseconds.end()); + + // Merge in run-time profile information from execution_profile on the + // zeroth device. + profile->MergeFrom(executables[0]->execution_profile()); + + // Overall execution time (in nanoseconds) from the executor timer. + profile->set_compute_and_transfer_time_ns(nanoseconds); + + // TODO(b/28123297): On GPU we end up including transfer time in + // the compute time this way. Instead, we should get the correct + // value by measuring it. Setting the field here at least lets + // benchmarks provide *some* value for GPU computations. + // + // TODO(b/28447609): The value in compute_and_transfer_time_ns is actually + // the compute time without the transfer time, so this way we get the + // correct compute time. We should instead have the correct value for + // compute_and_transfer_time and set compute_time to the compute time. + if (profile->compute_time_ns() == 0) { + profile->set_compute_time_ns(profile->compute_and_transfer_time_ns()); + } + } + return result_handles; } @@ -615,31 +703,41 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, VLOG(1) << "running execute-parallel request: " << arg->ShortDebugString(); std::vector> all_arguments; - std::vector executors; + std::vector> all_executors; std::vector versioned_handles; std::vector> module_configs; std::vector computation_names; std::vector device_handles; - if (arg->requests_size() * options_.number_of_replicas() > + int num_requested_devices = + std::accumulate(arg->requests().begin(), arg->requests().end(), 0, + [](int a, const ExecuteRequest& r) -> int { + return a + r.execution_options().device_handles_size(); + }); + if (num_requested_devices * options_.number_of_replicas() > execute_backend_->device_count()) { return FailedPrecondition( "there are not enough stream executors to execute %d computations", - arg->requests_size()); + num_requested_devices); } for (int64 i = 0; i < arg->requests_size(); ++i) { // Get the stream executor for the i'th computation. This stream executor // is one of the executors to run the replicated computation. - if (!arg->requests(i).has_device_handle()) { + const ExecutionOptions& execution_options = + arg->requests(i).execution_options(); + if (execution_options.device_handles().empty()) { return FailedPrecondition( "device handles must be given to execute parallel computations"); } - TF_ASSIGN_OR_RETURN( - auto replicas, - Replicas(*execute_backend_, arg->requests(i).device_handle())); - se::StreamExecutor* executor = replicas[0]; - CHECK(executor != nullptr); + std::vector executors; + for (const auto& device_handle : execution_options.device_handles()) { + TF_ASSIGN_OR_RETURN(auto replicas, + Replicas(*execute_backend_, device_handle)); + se::StreamExecutor* executor = replicas[0]; + CHECK(executor != nullptr); + executors.push_back(executor); + } // Resolve the UserComputation object associated with the requested // computation and compute the program shape. @@ -658,10 +756,12 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // Resolve the allocations for the arguments of the computation, and create // a vector of device memory offsets for the arguments from the allocations. + // In the case of partitioned computations, assume all arguments go on the + // zeroth core. TF_ASSIGN_OR_RETURN( std::vector arg_allocations, ResolveAndValidateArguments(request.arguments(), execute_backend_.get(), - executor->device_ordinal())); + executors[0]->device_ordinal())); std::vector arguments; arguments.reserve(arg_allocations.size()); for (const Allocation* allocation : arg_allocations) { @@ -678,11 +778,15 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // Adds to the vectors to build and execute the computations after the loop. all_arguments.push_back(arguments); + all_arguments.insert(all_arguments.end(), executors.size() - 1, {}); versioned_handles.push_back(versioned_handle); module_configs.push_back(std::move(module_config)); - computation_names.push_back(user_computation->name()); - executors.push_back(executor); - device_handles.push_back(arg->requests(i).device_handle()); + computation_names.insert(computation_names.end(), executors.size(), + user_computation->name()); + all_executors.push_back(executors); + device_handles.insert(device_handles.end(), + execution_options.device_handles().begin(), + execution_options.device_handles().end()); } // Build the user computations into HloModules and compile to generate the @@ -690,7 +794,7 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, TF_ASSIGN_OR_RETURN( std::vector> executables, BuildExecutables(versioned_handles, std::move(module_configs), - execute_backend_.get(), executors)); + execute_backend_.get(), all_executors)); std::vector executable_ptrs; executable_ptrs.reserve(executables.size()); for (const auto& executable : executables) { @@ -699,14 +803,16 @@ tensorflow::Status Service::ExecuteParallel(const ExecuteParallelRequest* arg, // Execute the generated executables in parallel and return the device // handles for each computation's output. + ExecutionProfile profile; TF_ASSIGN_OR_RETURN( std::vector outputs, ExecuteParallelAndRegisterResult(executable_ptrs, all_arguments, execute_backend_.get(), device_handles, - computation_names)); + computation_names, &profile)); for (const GlobalDataHandle& output : outputs) { ExecuteResponse response; *response.mutable_output() = output; + *response.mutable_profile() = profile; *result->add_responses() = response; } @@ -752,6 +858,17 @@ tensorflow::Status Service::Execute(const ExecuteRequest* arg, return InvalidArgument("computations may not be empty"); } + // If we received multiple device handles, we must partition the module. + if (arg->execution_options().device_handles_size() > 1) { + ExecuteParallelRequest parallel_arg; + *parallel_arg.add_requests() = *arg; + ExecuteParallelResponse parallel_result; + TF_RETURN_IF_ERROR(ExecuteParallel(¶llel_arg, ¶llel_result)); + TF_RET_CHECK(parallel_result.responses_size() > 0); + *result = parallel_result.responses(0); + return Status::OK(); + } + TF_ASSIGN_OR_RETURN( std::shared_ptr program_shape, user_computation->ComputeProgramShape(versioned_handle.version)); @@ -1055,8 +1172,9 @@ tensorflow::Status Service::IsConstant(const IsConstantRequest* arg, return InvalidArgument("computations may not be empty"); } - TF_ASSIGN_OR_RETURN(bool is_constant, - user_computation->IsConstant(arg->operand())); + TF_ASSIGN_OR_RETURN( + bool is_constant, + user_computation->IsConstant(arg->operand(), arg->num_parameters())); result->set_is_constant(is_constant); return tensorflow::Status::OK(); @@ -1074,8 +1192,9 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, return InvalidArgument("computations may not be empty"); } - TF_ASSIGN_OR_RETURN(bool is_constant, - user_computation->IsConstant(arg->operand())); + TF_ASSIGN_OR_RETURN( + bool is_constant, + user_computation->IsConstant(arg->operand(), arg->parameters_size())); if (!is_constant) { return InvalidArgument("Operand to ComputeConstant depends on parameter."); } @@ -1114,8 +1233,18 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg, /*include_unreachable_instructions=*/ false)); + std::vector parameters(arg->parameters_size()); + for (int64 i = 0; i < arg->parameters_size(); ++i) { + parameters[i] = Literal(arg->parameters(i)); + } + std::vector parameter_ptrs; + std::transform(parameters.begin(), parameters.end(), + std::back_inserter(parameter_ptrs), + [](const Literal& literal) { return &literal; }); + HloEvaluator evaluator; - TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {})); + TF_ASSIGN_OR_RETURN(auto result_literal, + evaluator.Evaluate(*module, parameter_ptrs)); // Since the shape_with_output_layout option in ExecutionOption is // non-effective to the Evaluator results, explicit relayout here. if (arg->has_output_layout()) { @@ -1388,9 +1517,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { // proto in the above switch statement. TF_ASSIGN_OR_RETURN(ComputationDataHandle handle, handle_status); TF_RETURN_IF_ERROR(computation->SetOpMetadata(handle, arg->metadata())); - TF_RETURN_IF_ERROR( - computation->SetOpDeviceAssignment(handle, arg->device_assignment())); - + if (arg->has_sharding()) { + TF_RETURN_IF_ERROR(computation->SetOpSharding(handle, arg->sharding())); + } return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/service.h b/tensorflow/compiler/xla/service/service.h index bb86a53c62e05bb62b93bbac88c2ca251ad0439a..6646be2e9aa43763b93bcea7a1df9d10580f162c 100644 --- a/tensorflow/compiler/xla/service/service.h +++ b/tensorflow/compiler/xla/service/service.h @@ -277,8 +277,7 @@ class Service : public ServiceInterface { StatusOr> CreateModuleConfig( const ProgramShape& program_shape, tensorflow::gtl::ArraySlice argument_shapes, - const ExecutionOptions* execution_options, - bool has_hybrid_result = false); + const ExecutionOptions* execution_options); // Builds an Executable for the given parameters. StatusOr> BuildExecutable( @@ -294,7 +293,7 @@ class Service : public ServiceInterface { std::vector versioned_handles, std::vector> module_configs, Backend* backend, - std::vector executors); + std::vector> executors); // Similar to BuildExecutable, but look in the compilation cache for the // executable first. If the executable is not in the cache, it is built and @@ -328,7 +327,8 @@ class Service : public ServiceInterface { arguments, Backend* backend, tensorflow::gtl::ArraySlice device_handles, - tensorflow::gtl::ArraySlice result_tags); + tensorflow::gtl::ArraySlice result_tags, + ExecutionProfile* profile); // Convenience function for adding a function to a user computation. template diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc index 5178a750b90a16f8c5674a72072e7d0d9f9302e7..791d17365b1d756714b5feb0439e6919d9f23edc 100644 --- a/tensorflow/compiler/xla/service/shape_inference.cc +++ b/tensorflow/compiler/xla/service/shape_inference.cc @@ -53,14 +53,18 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { return UNOP_EXP; case HloOpcode::kFloor: return UNOP_FLOOR; + case HloOpcode::kImag: + return UNOP_IMAG; case HloOpcode::kIsFinite: return UNOP_IS_FINITE; case HloOpcode::kLog: return UNOP_LOG; - case HloOpcode::kLogicalNot: - return UNOP_LOGICAL_NOT; + case HloOpcode::kNot: + return UNOP_NOT; case HloOpcode::kNegate: return UNOP_NEGATE; + case HloOpcode::kReal: + return UNOP_REAL; case HloOpcode::kRoundNearestAfz: return UNOP_ROUND_NEAREST_AFZ; case HloOpcode::kSign: @@ -81,6 +85,10 @@ UnaryOperation OpcodeToUnaryOperation(HloOpcode opcode) { // opcode. BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { switch (opcode) { + case HloOpcode::kAtan2: + return BINOP_ATAN2; + case HloOpcode::kComplex: + return BINOP_COMPLEX; case HloOpcode::kDot: return BINOP_DOT; case HloOpcode::kMultiply: @@ -89,8 +97,6 @@ BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { return BINOP_ADD; case HloOpcode::kSubtract: return BINOP_SUB; - case HloOpcode::kIndex: - return BINOP_INDEX; case HloOpcode::kDivide: return BINOP_DIV; case HloOpcode::kEq: @@ -113,10 +119,16 @@ BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) { return BINOP_POW; case HloOpcode::kRemainder: return BINOP_REM; - case HloOpcode::kLogicalOr: - return BINOP_LOGICAL_OR; - case HloOpcode::kLogicalAnd: - return BINOP_LOGICAL_AND; + case HloOpcode::kOr: + return BINOP_OR; + case HloOpcode::kAnd: + return BINOP_AND; + case HloOpcode::kShiftLeft: + return BINOP_SHIFT_LEFT; + case HloOpcode::kShiftRightArithmetic: + return BINOP_SHIFT_RIGHT_ARITHMETIC; + case HloOpcode::kShiftRightLogical: + return BINOP_SHIFT_RIGHT_LOGICAL; default: LOG(FATAL) << "unhandled opcode " << opcode; } @@ -130,8 +142,6 @@ TernaryOperation OpcodeToTernaryOperation(HloOpcode opcode) { return TRIOP_CLAMP; case HloOpcode::kSelect: return TRIOP_SELECT; - case HloOpcode::kUpdate: - return TRIOP_UPDATE; default: LOG(FATAL) << "unhandled opcode " << opcode; } @@ -303,30 +313,53 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, switch (operation) { case UNOP_FLOOR: case UNOP_CEIL: + if (!ShapeUtil::ElementIsFloating(arg)) { + return InvalidArgument( + "expected element type in shape to be floating for floor/ceil " + "operation; got %s", + PrimitiveType_Name(arg.element_type()).c_str()); + } + return arg; case UNOP_COS: case UNOP_SIN: case UNOP_EXP: case UNOP_LOG: case UNOP_TANH: - if (!ShapeUtil::ElementIsFloating(arg)) { + if (!ShapeUtil::ElementIsFloating(arg) && + !ShapeUtil::ElementIsComplex(arg)) { return InvalidArgument( - "expected element type in shape to be floating for exp/log/tanh " - "operation; got %s", + "expected element type in shape to be floating or complex for " + "sin/cos/exp/log/tanh operation; got %s", PrimitiveType_Name(arg.element_type()).c_str()); } return arg; + case UNOP_REAL: + case UNOP_IMAG: + if (!ShapeUtil::ElementIsComplex(arg)) { + return InvalidArgument( + "expected element type in shape to be complex for real/imag " + "operation; got %s", + PrimitiveType_Name(arg.element_type()).c_str()); + } + return ShapeUtil::ChangeElementType(arg, F32); case UNOP_ABS: + if (ShapeUtil::ElementIsComplex(arg)) { + return ShapeUtil::ChangeElementType( + arg, primitive_util::ComplexComponentType(arg.element_type())); + } + return arg; case UNOP_NEGATE: case UNOP_ROUND_NEAREST_AFZ: case UNOP_SIGN: case UNOP_SORT: return arg; - case UNOP_LOGICAL_NOT: - if (arg.element_type() != PRED) { + case UNOP_NOT: + if (arg.element_type() != PRED && + !primitive_util::IsIntegralType(arg.element_type())) { return InvalidArgument( - "expected pred element type in argument to logical-not operation; " - "got %s", + "expected pred or an integral element type in argument to not " + "operation; got %s", PrimitiveType_Name(arg.element_type()).c_str()); } return arg; @@ -457,7 +490,10 @@ StatusOr InferWindowOutputShape(const Shape& base_shape, } if (ShapeUtil::Rank(operand_shape) != padding_config.dimensions_size()) { return InvalidArgument( - "the rank of the operand and the padding configuration do not match."); + "The rank of the operand and the padding configuration do not match: " + "%s vs %s", + ShapeUtil::HumanString(operand_shape).c_str(), + padding_config.ShortDebugString().c_str()); } if (operand_shape.element_type() != padding_value_shape.element_type()) { return InvalidArgument( @@ -679,11 +715,15 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( ShapeUtil::HumanString(rhs).c_str()); } - if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs) && - !broadcast_dimensions.empty()) { - return InvalidArgument( - "broadcast dimensions field should not be set on binary " - "operations with operands of the same rank"); + if (ShapeUtil::Rank(lhs) == ShapeUtil::Rank(rhs)) { + std::vector identity_dims(ShapeUtil::Rank(lhs)); + std::iota(identity_dims.begin(), identity_dims.end(), 0); + if (!broadcast_dimensions.empty() && + broadcast_dimensions != identity_dims) { + return InvalidArgument( + "broadcast dimensions field must either be not set or be the " + "identity on binary operations with operands of the same rank"); + } } if (ShapeUtil::Compatible(lhs, rhs)) { @@ -739,24 +779,44 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( case BINOP_MIN: case BINOP_SUB: case BINOP_ADD: + case BINOP_ATAN2: case BINOP_POW: case BINOP_DIV: case BINOP_REM: case BINOP_MUL: + case BINOP_SHIFT_LEFT: + case BINOP_SHIFT_RIGHT_ARITHMETIC: + case BINOP_SHIFT_RIGHT_LOGICAL: return InferElementwiseBinaryOpShape(operation, lhs, rhs, broadcast_dimensions); - case BINOP_LOGICAL_AND: - case BINOP_LOGICAL_OR: - if (lhs.element_type() != PRED) { + case BINOP_COMPLEX: { + if (!ShapeUtil::ElementIsFloating(lhs)) { return InvalidArgument( - "expected pred element type in argument to logical and/or " + "expected element type in shape to be floating for complex compose " "operation; got %s", PrimitiveType_Name(lhs.element_type()).c_str()); } + TF_ASSIGN_OR_RETURN(const Shape& shape, + InferElementwiseBinaryOpShape(operation, lhs, rhs, + broadcast_dimensions)); + if (lhs.element_type() == F32) { + return ShapeUtil::ChangeElementType(shape, C64); + } else { + return Unimplemented("complex component type not supported"); + } + } + case BINOP_AND: + case BINOP_OR: + if (lhs.element_type() != PRED && + !primitive_util::IsIntegralType(lhs.element_type())) { + return InvalidArgument( + "expected pred or integral type in argument to and/or operation; " + "got %s", + PrimitiveType_Name(lhs.element_type()).c_str()); + } return InferElementwiseBinaryOpShape(operation, lhs, rhs, broadcast_dimensions); - case BINOP_EQ: case BINOP_GE: case BINOP_GT: @@ -768,17 +828,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( broadcast_dimensions)); return ShapeUtil::ChangeElementType(shape, PRED); } - case BINOP_INDEX: - if (ShapeUtil::Rank(lhs) > 0 && ShapeUtil::Rank(rhs) == 0) { - tensorflow::gtl::ArraySlice dimensions = - AsInt64Slice(lhs.dimensions()); - dimensions.pop_front(); - return ShapeUtil::MakeShape(lhs.element_type(), dimensions); - } - return Unimplemented("cannot infer shape for operation: %s <%s> %s", - ShapeUtil::HumanString(lhs).c_str(), - BinaryOperation_Name(operation).c_str(), - ShapeUtil::HumanString(rhs).c_str()); default: return Unimplemented( "not yet implemented; infer binary op shape: %s; lhs: %s; rhs: %s", @@ -805,14 +854,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( return InferClampShape(lhs, rhs, ehs); case TRIOP_SELECT: return InferSelectShape(lhs, rhs, ehs); - case TRIOP_UPDATE: - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(lhs, "lhs of ternary operation")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(rhs, "rhs of ternary operation")); - TF_RETURN_IF_ERROR( - ExpectNotTupleOrOpaque(ehs, "ehs of ternary operation")); - return lhs; default: return InvalidArgument("unknown operation %s", TernaryOperation_Name(operation).c_str()); @@ -852,7 +893,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /* static */ StatusOr ShapeInference::InferMapShape( tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply) { + const ProgramShape& to_apply, + tensorflow::gtl::ArraySlice dimensions) { if (arg_shapes.empty()) { return InvalidArgument("Map expects at least one argument"); } @@ -888,6 +930,24 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( tensorflow::str_util::Join(pieces, ", ").c_str()); } + // Check that dimensions.size == arg_shape.dimensions_size() (we currently + // only support mapping across all dimensions: i.e. scalar map functions). + if (dimensions.size() != arg_shape->dimensions_size()) { + return InvalidArgument( + "Map applied to a subset of dimensions currently not supported: " + "arg_dimension_size: %d, requested_map_dimensions_size: %zu", + arg_shape->dimensions_size(), dimensions.size()); + } + + // Check that requested map dimensions numbers are monotonically increasing. + for (int i = 0; i < dimensions.size(); ++i) { + if (dimensions[i] != i) { + return InvalidArgument( + "Map requires monotonically increasing dimension numbers, found: %s ", + tensorflow::str_util::Join(dimensions, ", ").c_str()); + } + } + // The applied function's arity equals the number of arguments. if (arg_shapes.size() != to_apply.parameters_size()) { return InvalidArgument( @@ -1349,14 +1409,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( "Window: %s", window.DebugString().c_str()); } - int num_spatial_dims = dnums.spatial_dimensions_size(); - if (num_spatial_dims < 1) { - return InvalidArgument( - "Convolution requires at least one spatial dimension.\n" - "Window: %s", - window.DebugString().c_str()); - } + const int num_spatial_dims = dnums.spatial_dimensions_size(); if (window.dimensions_size() != num_spatial_dims) { return InvalidArgument( "Window must have same number of dimensions as dimension numbers.\n" @@ -1364,7 +1418,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( window.DebugString().c_str(), dnums.DebugString().c_str()); } - int num_dims = num_spatial_dims + 2; + const int num_dims = num_spatial_dims + 2; if (ShapeUtil::Rank(lhs) != num_dims) { return InvalidArgument( "The LHS argument to a convolution should have rank %d.\n" @@ -1383,8 +1437,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( // Verifies that the input and window dimensions are a permutation of // the dimension numbers. std::vector input_dnums(num_dims); - input_dnums[0] = dnums.batch_dimension(); - input_dnums[1] = dnums.feature_dimension(); + input_dnums[0] = dnums.input_batch_dimension(); + input_dnums[1] = dnums.input_feature_dimension(); std::copy(dnums.spatial_dimensions().begin(), dnums.spatial_dimensions().end(), input_dnums.begin() + 2); std::sort(input_dnums.begin(), input_dnums.end()); @@ -1424,8 +1478,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( for (int i = 0; i < num_spatial_dims; ++i) { input_spatial_dims[i] = lhs.dimensions(dnums.spatial_dimensions(i)); } - const int64 input_features = lhs.dimensions(dnums.feature_dimension()); - const int64 input_batch = lhs.dimensions(dnums.batch_dimension()); + const int64 input_features = lhs.dimensions(dnums.input_feature_dimension()); + const int64 input_batch = lhs.dimensions(dnums.input_batch_dimension()); std::vector kernel_spatial_dims(num_spatial_dims); for (int i = 0; i < num_spatial_dims; ++i) { @@ -1467,8 +1521,8 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( /*allow_negative_padding=*/true)); std::vector dimensions(num_dims); - dimensions[dnums.batch_dimension()] = input_batch; - dimensions[dnums.feature_dimension()] = kernel_output_features; + dimensions[dnums.output_batch_dimension()] = input_batch; + dimensions[dnums.output_feature_dimension()] = kernel_output_features; for (int i = 0; i < num_spatial_dims; ++i) { dimensions[dnums.spatial_dimensions(i)] = window_output_shape.dimensions(i); } @@ -1871,11 +1925,16 @@ ShapeInference::InferDegenerateDimensionBroadcastShape( Shape inferred_shape = ShapeUtil::MakeShape(operand.element_type(), new_sizes); + VLOG(3) << "Reshape inferred shape: " + << ShapeUtil::HumanString(inferred_shape); if (ShapeUtil::ElementsIn(operand) != ShapeUtil::ElementsIn(inferred_shape)) { return InvalidArgument( - "reshape operation has mismatched element counts: from=%lld to=%lld", - ShapeUtil::ElementsIn(operand), ShapeUtil::ElementsIn(inferred_shape)); + "reshape operation has mismatched element counts: from=%lld (%s) " + "to=%lld (%s)", + ShapeUtil::ElementsIn(operand), ShapeUtil::HumanString(operand).c_str(), + ShapeUtil::ElementsIn(inferred_shape), + ShapeUtil::HumanString(inferred_shape).c_str()); } std::vector indices(ShapeUtil::Rank(operand)); diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h index 379feef5e45384a4bc436ae6a1e71930878da121..d5d497176d6c340d8c8f34cdacf6a9e32040c387 100644 --- a/tensorflow/compiler/xla/service/shape_inference.h +++ b/tensorflow/compiler/xla/service/shape_inference.h @@ -78,7 +78,8 @@ class ShapeInference { // to the given operand shapes. static StatusOr InferMapShape( tensorflow::gtl::ArraySlice arg_shapes, - const ProgramShape& to_apply); + const ProgramShape& to_apply, + tensorflow::gtl::ArraySlice dimensions); // Infers the shape produced by InferBatchNormTraining with the given // operands. diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc index 8c731ae2976fd3da275a5c9596a4ac7f738e5fbc..d12f7bd1453890db3280e54719a6ce811006336d 100644 --- a/tensorflow/compiler/xla/service/shape_inference_test.cc +++ b/tensorflow/compiler/xla/service/shape_inference_test.cc @@ -35,6 +35,7 @@ class ShapeInferenceTest : public ::testing::Test { // Some handy scalar shapes. const Shape s32_ = ShapeUtil::MakeShape(S32, {}); const Shape f32_ = ShapeUtil::MakeShape(F32, {}); + const Shape f64_ = ShapeUtil::MakeShape(F64, {}); const Shape pred_ = ShapeUtil::MakeShape(PRED, {}); // Some handy vector and matrix shapes of F32 type. @@ -251,6 +252,44 @@ TEST_F(ShapeInferenceTest, ClampBadShapes) { .ok()); } +TEST_F(ShapeInferenceTest, Complex) { + auto complex_shape = [&](const Shape& lhs, const Shape& rhs, + const tensorflow::gtl::ArraySlice& bcast) { + return ShapeInference::InferBinaryOpShape(BinaryOperation::BINOP_COMPLEX, + lhs, rhs, bcast); + }; + // Inputs must be FP. + ASSERT_FALSE(complex_shape(s32_, s32_, {}).ok()); + ASSERT_FALSE(complex_shape(pred_, pred_, {}).ok()); + // Component types must match. + ASSERT_FALSE(complex_shape(f32_, f64_, {}).ok()); + // Only F32->C64 supported. + ASSERT_FALSE(complex_shape(f64_, f64_, {}).ok()); + // Validate correct uses. + Shape c64_32 = ShapeUtil::MakeShape(C64, {32}); + TF_ASSERT_OK_AND_ASSIGN(Shape result, complex_shape(f32_, f32_, {})); + ASSERT_TRUE(ShapeUtil::Equal(result, ShapeUtil::MakeShape(C64, {}))); + TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {})); + ASSERT_TRUE(ShapeUtil::Equal(result, c64_32)); + TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(f32_, vector_32_, {})); + ASSERT_TRUE(ShapeUtil::Equal(result, c64_32)); + TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(vector_32_, f32_, {})); + ASSERT_TRUE(ShapeUtil::Equal(result, c64_32)); + + Shape c64_32_64 = ShapeUtil::MakeShape(C64, {32, 64}); + TF_ASSERT_OK_AND_ASSIGN(result, + complex_shape(vector_64_, matrix_32_64_, {1})); + ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); + TF_ASSERT_OK_AND_ASSIGN(result, + complex_shape(matrix_32_64_, vector_64_, {1})); + ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); + TF_ASSERT_OK_AND_ASSIGN(result, + complex_shape(matrix_32_64_, matrix_32_64_, {})); + ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); + TF_ASSERT_OK_AND_ASSIGN(result, complex_shape(matrix_32_64_, f32_, {})); + ASSERT_TRUE(ShapeUtil::Equal(result, c64_32_64)); +} + TEST_F(ShapeInferenceTest, VariadicOpTuplify) { StatusOr result = ShapeInference::InferVariadicOpShape( VariadicOperation::VAROP_TUPLE, {&s32_, &f32_}); @@ -352,8 +391,10 @@ TEST_F(ShapeInferenceTest, Convolve) { // Dimension order: batch, feature, x0, x1 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); - dnums.set_batch_dimension(0); - dnums.set_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(1); + dnums.set_output_feature_dimension(1); dnums.add_spatial_dimensions(2); dnums.add_spatial_dimensions(3); @@ -392,8 +433,10 @@ TEST_F(ShapeInferenceTest, ConvolveWithWindowDilation) { // Dimension order: batch, feature, x0, x1 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 103, 4}); - dnums.set_batch_dimension(0); - dnums.set_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(1); + dnums.set_output_feature_dimension(1); dnums.add_spatial_dimensions(2); dnums.add_spatial_dimensions(3); @@ -433,8 +476,10 @@ TEST_F(ShapeInferenceTest, ConvolveWithBaseDilation) { // Dimension order: batch, feature, x0, x1 Shape lhs_shape = ShapeUtil::MakeShape(F32, {10, 11, 3, 4}); - dnums.set_batch_dimension(0); - dnums.set_feature_dimension(1); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); + dnums.set_input_feature_dimension(1); + dnums.set_output_feature_dimension(1); dnums.add_spatial_dimensions(2); dnums.add_spatial_dimensions(3); @@ -475,8 +520,10 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { Shape rhs_shape = ShapeUtil::MakeShape(F32, {12, 11, 3, 2}); ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(3); - dnums.set_feature_dimension(2); + dnums.set_input_batch_dimension(3); + dnums.set_output_batch_dimension(3); + dnums.set_input_feature_dimension(2); + dnums.set_output_feature_dimension(2); dnums.add_spatial_dimensions(0); dnums.add_spatial_dimensions(1); dnums.set_kernel_input_feature_dimension(0); // duplicated with kernel_x0 @@ -505,7 +552,7 @@ TEST_F(ShapeInferenceTest, ConvolveDimensionNumbersOverlapError) { TEST_F(ShapeInferenceTest, MapThatChangesElementType) { Shape arg = ShapeUtil::MakeShape(F32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, s32_); - auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply); + auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0}); EXPECT_IS_OK(inferred_status.status()); Shape expected = ShapeUtil::MakeShape(S32, {20}); EXPECT_TRUE(ShapeUtil::Equal(expected, inferred_status.ValueOrDie())); @@ -514,91 +561,92 @@ TEST_F(ShapeInferenceTest, MapThatChangesElementType) { TEST_F(ShapeInferenceTest, Map) { auto inferred_status_r1f32 = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, - ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); + ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); EXPECT_IS_OK(inferred_status_r1f32.status()); EXPECT_TRUE(ShapeUtil::Equal(vector_32_, inferred_status_r1f32.ValueOrDie())); // It's OK to provide a single argument, as long as the applied arity matches // (this degenerates to a Map). auto inferred_status_r1f32_one = ShapeInference::InferMapShape( - {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_)); + {&vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), {0}); EXPECT_IS_OK(inferred_status_r1f32_one.status()); EXPECT_TRUE( ShapeUtil::Equal(vector_32_, inferred_status_r1f32_one.ValueOrDie())); auto inferred_status_r2s32 = ShapeInference::InferMapShape( {&s32matrix_64_64_, &s32matrix_64_64_, &s32matrix_64_64_}, - ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_)); + ShapeUtil::MakeProgramShape({s32_, s32_, s32_}, s32_), {0, 1}); EXPECT_IS_OK(inferred_status_r2s32.status()); EXPECT_TRUE( ShapeUtil::Equal(s32matrix_64_64_, inferred_status_r2s32.ValueOrDie())); auto no_args_error = ShapeInference::InferMapShape( - {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); + {}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {}); ASSERT_FALSE(no_args_error.ok()); ASSERT_THAT(no_args_error.status().error_message(), HasSubstr("expects at least one argument")); auto args_diff_shapes_error = ShapeInference::InferMapShape( {&vector_32_, &vector_64_}, - ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); + ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); ASSERT_FALSE(args_diff_shapes_error.ok()); ASSERT_THAT(args_diff_shapes_error.status().error_message(), HasSubstr("requires all operands to have the same shape")); auto arity_error = ShapeInference::InferMapShape( - {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_)); + {&vector_32_, &vector_32_}, ShapeUtil::MakeProgramShape({f32_}, f32_), + {0}); ASSERT_FALSE(arity_error.ok()); ASSERT_THAT(arity_error.status().error_message(), HasSubstr("function arity must match")); auto output_shape_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, - ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_)); + ShapeUtil::MakeProgramShape({f32_, f32_}, vector_32_), {0}); ASSERT_FALSE(output_shape_error.ok()); ASSERT_THAT(output_shape_error.status().error_message(), HasSubstr("result has to be a scalar")); auto param_shape_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, - ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_)); + ShapeUtil::MakeProgramShape({vector_32_, f32_}, f32_), {0}); ASSERT_FALSE(param_shape_error.ok()); ASSERT_THAT(param_shape_error.status().error_message(), HasSubstr("parameter has to be a scalar")); auto param_element_type_error = ShapeInference::InferMapShape( {&vector_32_, &vector_32_}, - ShapeUtil::MakeProgramShape({f32_, s32_}, f32_)); + ShapeUtil::MakeProgramShape({f32_, s32_}, f32_), {0}); ASSERT_FALSE(param_element_type_error.ok()); ASSERT_THAT(param_element_type_error.status().error_message(), HasSubstr("parameter type has to match argument")); Shape arg = ShapeUtil::MakeShape(F32, {20}); ProgramShape to_apply = ShapeUtil::MakeProgramShape({f32_}, f32_); - auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply); + auto inferred_status = ShapeInference::InferMapShape({&arg}, to_apply, {0}); EXPECT_IS_OK(inferred_status.status()); EXPECT_TRUE(ShapeUtil::Equal(arg, inferred_status.ValueOrDie())); auto inferred_status_error1 = ShapeInference::InferMapShape( - {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_)); + {&arg}, ShapeUtil::MakeProgramShape({f32_, f32_}, f32_), {0}); ASSERT_FALSE(inferred_status_error1.ok()); ASSERT_THAT(inferred_status_error1.status().error_message(), HasSubstr("arity must match number of arguments")); auto inferred_status_error2 = ShapeInference::InferMapShape( - {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_)); + {&arg}, ShapeUtil::MakeProgramShape({vector_32_}, f32_), {0}); ASSERT_FALSE(inferred_status_error2.ok()); ASSERT_THAT(inferred_status_error2.status().error_message(), HasSubstr("has to be a scalar")); auto inferred_status_error3 = ShapeInference::InferMapShape( - {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_)); + {&arg}, ShapeUtil::MakeProgramShape({f32_}, vector_32_), {0}); ASSERT_FALSE(inferred_status_error3.ok()); ASSERT_THAT(inferred_status_error3.status().error_message(), HasSubstr("has to be a scalar")); auto inferred_status_error5 = ShapeInference::InferMapShape( - {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_)); + {&arg}, ShapeUtil::MakeProgramShape({s32_}, s32_), {0}); ASSERT_FALSE(inferred_status_error5.ok()); ASSERT_THAT(inferred_status_error5.status().error_message(), HasSubstr("parameter type has to match argument")); diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc index 865be1b84f2fd599b68f09fdad0323076e637906..a2a442eb1a33d976a114f68d112a7d8f3b540f4b 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.cc +++ b/tensorflow/compiler/xla/service/shaped_buffer.cc @@ -21,98 +21,61 @@ limitations under the License. #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/transfer_manager.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/platform/logging.h" -namespace xla { +namespace se = ::perftools::gputools; -/* static */ StatusOr> -ShapedBuffer::MakeShapedBuffer(const Shape& shape, - const perftools::gputools::Platform* platform, - int device_ordinal) { - if (!LayoutUtil::HasLayout(shape)) { - return InvalidArgument("Shape must have a layout: %s", - ShapeUtil::HumanStringWithLayout(shape).c_str()); - } - TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); - return WrapUnique(new ShapedBuffer(shape, platform, device_ordinal)); -} +namespace xla { /* static */ StatusOr> -ShapedBuffer::MakeArrayShapedBuffer( - const Shape& shape, const perftools::gputools::Platform* platform, - int device_ordinal, const perftools::gputools::DeviceMemoryBase& buffer) { +ShapedBuffer::MakeArrayShapedBuffer(const Shape& shape, + const se::Platform* platform, + int device_ordinal, + const se::DeviceMemoryBase& buffer) { if (ShapeUtil::IsTuple(shape)) { return InvalidArgument("Shape must be an array: %s", ShapeUtil::HumanStringWithLayout(shape).c_str()); } - TF_ASSIGN_OR_RETURN(std::unique_ptr shaped_buffer, - MakeShapedBuffer(shape, platform, device_ordinal)); + auto shaped_buffer = + MakeUnique(shape, platform, device_ordinal); *shaped_buffer->mutable_shape_index_to_buffer_entry()->mutable_element({}) = 0; *shaped_buffer->mutable_buffers() = {buffer}; return std::move(shaped_buffer); } -/* static */ StatusOr> -ShapedBuffer::MakeUnnestedTupleShapedBuffer( - const Shape& shape, const perftools::gputools::Platform* platform, - int device_ordinal, - const tensorflow::gtl::ArraySlice - buffers) { - if (!ShapeUtil::IsTuple(shape) || ShapeUtil::IsNestedTuple(shape)) { - return InvalidArgument("Shape must be an unnested tuple: %s", - ShapeUtil::HumanStringWithLayout(shape).c_str()); - } - if (buffers.size() != ShapeUtil::TupleElementCount(shape)) { - return InvalidArgument("Tuple has %lld elements, but %zu buffers given", - ShapeUtil::TupleElementCount(shape), buffers.size()); - } - TF_ASSIGN_OR_RETURN(std::unique_ptr shaped_buffer, - MakeShapedBuffer(shape, platform, device_ordinal)); - shaped_buffer->mutable_shape_index_to_buffer_entry()->ForEachMutableElement( - [&shaped_buffer](const ShapeIndex& index, size_t* buffer_element) { - if (ShapeUtil::IsLeafIndex(shaped_buffer->shape(), index)) { - CHECK_EQ(index.size(), 1); - *buffer_element = index[0]; - } - }); - shaped_buffer->mutable_buffers()->reserve(buffers.size()); - for (const perftools::gputools::DeviceMemoryBase& memory_base : buffers) { - shaped_buffer->mutable_buffers()->push_back(memory_base); - } - return std::move(shaped_buffer); -} - -ShapedBuffer::ShapedBuffer(const Shape& shape, - const perftools::gputools::Platform* platform, +ShapedBuffer::ShapedBuffer(const Shape& shape, const se::Platform* platform, int device_ordinal) : shape_(shape), - shape_index_to_buffer_entry_(shape), platform_(platform), - device_ordinal_(device_ordinal) {} + device_ordinal_(device_ordinal), + shape_index_to_buffer_entry_(shape) {} -const perftools::gputools::DeviceMemoryBase& ShapedBuffer::buffer( +void ShapedBuffer::clear() { + for (se::DeviceMemoryBase& memory_base : buffers_) { + // A default constructed DeviceMemoryBase is a null pointer. + memory_base = se::DeviceMemoryBase(); + } +} + +const se::DeviceMemoryBase& ShapedBuffer::buffer( const ShapeIndex& index) const { - // Buffer are only set at the leaves (array elements of the shape). - CHECK(shape_index_to_buffer_entry_.IsLeaf(index)); return buffers_[shape_index_to_buffer_entry_.element(index)]; } -perftools::gputools::DeviceMemoryBase* ShapedBuffer::mutable_buffer( - const ShapeIndex& index) { - // Buffer are only set at the leaves (array elements of the shape). - CHECK(shape_index_to_buffer_entry_.IsLeaf(index)); +se::DeviceMemoryBase* ShapedBuffer::mutable_buffer(const ShapeIndex& index) { return &buffers_[shape_index_to_buffer_entry_.element(index)]; } /* static */ StatusOr> -ScopedShapedBuffer::MakeScopedShapedBuffer(const Shape& shape, - DeviceMemoryAllocator* allocator, - int device_ordinal) { +ScopedShapedBuffer::Allocate(const Shape& shape, + DeviceMemoryAllocator* allocator, + int device_ordinal) { if (!LayoutUtil::HasLayout(shape)) { return InvalidArgument("Shape must have a layout: %s", ShapeUtil::HumanStringWithLayout(shape).c_str()); @@ -121,28 +84,71 @@ ScopedShapedBuffer::MakeScopedShapedBuffer(const Shape& shape, auto shaped_buffer = WrapUnique(new ScopedShapedBuffer(shape, allocator, device_ordinal)); - // Allocate an appropriate sized buffer for each array element in the shape. - TF_RETURN_IF_ERROR( - shaped_buffer->shape_index_to_buffer_entry_ - .ForEachMutableElementWithStatus([&shaped_buffer]( - const ShapeIndex& index, - size_t* buffer_entry) - -> tensorflow::Status { - if (ShapeUtil::IsLeafIndex(shaped_buffer->shape(), index)) { - TF_ASSIGN_OR_RETURN( - perftools::gputools::DeviceMemoryBase memory_base, - shaped_buffer->allocator_->Allocate( - shaped_buffer->device_ordinal(), - ShapeUtil::ByteSizeOf(ShapeUtil::GetSubshape( - shaped_buffer->shape(), index)))); - shaped_buffer->buffers_.push_back(memory_base); - *buffer_entry = shaped_buffer->buffers_.size() - 1; - } - return tensorflow::Status::OK(); - })); + // Allocate an appropriate sized buffer for each element in the shape + // including the tuple pointer arrays. Gather tuple element addresses in + // 'element_addresses'. These will be written in the respective tuple's array + // of pointers on the device. + TF_ASSIGN_OR_RETURN(TransferManager * transfer_manager, + TransferManager::GetForPlatform(allocator->platform())); + ShapeTree> element_addresses(shape); + for (auto& pair : shaped_buffer->shape_index_to_buffer_entry_) { + const ShapeIndex& index = pair.first; + size_t& buffer_entry = pair.second; + TF_ASSIGN_OR_RETURN( + se::DeviceMemoryBase memory_base, + shaped_buffer->allocator_->Allocate( + shaped_buffer->device_ordinal(), + transfer_manager->GetByteSizeRequirement( + ShapeUtil::GetSubshape(shaped_buffer->shape(), index)))); + shaped_buffer->buffers_.push_back(memory_base); + buffer_entry = shaped_buffer->buffers_.size() - 1; + + // If this is a tuple element, then push the address on to the + // vector of tuple element addresses. + if (!index.empty()) { + ShapeIndex parent_index = index; + parent_index.pop_back(); + element_addresses.mutable_element(parent_index)->push_back(memory_base); + } + } + + // Fill in the tuple pointer arrays with the addresses of their respective + // elements. + TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor, + allocator->platform()->ExecutorForDevice( + shaped_buffer->device_ordinal())); + for (const auto& pair : element_addresses) { + const ShapeIndex& index = pair.first; + const std::vector& addresses = pair.second; + const Shape& subshape = ShapeUtil::GetSubshape(shape, index); + + if (addresses.empty()) { + TF_RET_CHECK(!ShapeUtil::IsTuple(subshape) || + ShapeUtil::TupleElementCount(subshape) == 0); + continue; + } + TF_RET_CHECK(ShapeUtil::IsTuple(subshape)); + TF_RETURN_IF_ERROR(transfer_manager->WriteTuplePointersToDevice( + executor, addresses, subshape, shaped_buffer->mutable_buffer(index))); + } + return std::move(shaped_buffer); } +/* static */ +StatusOr> ScopedShapedBuffer::MakeScoped( + ShapedBuffer* shaped_buffer, DeviceMemoryAllocator* allocator) { + auto scoped_buffer = WrapUnique(new ScopedShapedBuffer( + shaped_buffer->shape(), allocator, shaped_buffer->device_ordinal())); + scoped_buffer->buffers_ = shaped_buffer->buffers(); + scoped_buffer->shape_index_to_buffer_entry_ = + shaped_buffer->shape_index_to_buffer_entry(); + + shaped_buffer->clear(); + + return std::move(scoped_buffer); +} + ScopedShapedBuffer::ScopedShapedBuffer(const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal) @@ -154,7 +160,7 @@ ScopedShapedBuffer::~ScopedShapedBuffer() { // in the shape (eg, a tuple with a repeated element) so keep track of what // has been deallocated. std::set deallocated_opaques; - for (perftools::gputools::DeviceMemoryBase& memory_base : buffers_) { + for (se::DeviceMemoryBase& memory_base : buffers_) { if (!memory_base.is_null() && deallocated_opaques.count(memory_base.opaque()) == 0) { deallocated_opaques.insert(memory_base.opaque()); @@ -164,4 +170,17 @@ ScopedShapedBuffer::~ScopedShapedBuffer() { } } +std::unique_ptr ScopedShapedBuffer::release() { + auto shaped_buffer = + MakeUnique(shape(), platform(), device_ordinal()); + + *shaped_buffer->mutable_buffers() = buffers(); + *shaped_buffer->mutable_shape_index_to_buffer_entry() = + shape_index_to_buffer_entry(); + + clear(); + + return shaped_buffer; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h index aa3b932c4ef612d7a69de0aa1573ba3945666ed7..e5ea06fb136fa714eab0f340f98b7191a4c5caa3 100644 --- a/tensorflow/compiler/xla/service/shaped_buffer.h +++ b/tensorflow/compiler/xla/service/shaped_buffer.h @@ -33,12 +33,6 @@ namespace xla { // XLA client running in the same process as the service (LocalClient), class ShapedBuffer { public: - // Creates a ShapedBuffer of arbitrary shape. All buffer pointers - // (DeviceMemoryBase) in the returned ShapedBuffer are initialized to null. - static StatusOr> MakeShapedBuffer( - const Shape& shape, const perftools::gputools::Platform* platform, - int device_ordinal); - // Convenience method which creates a ShapedBuffer of array shape (not a // tuple). Its single buffer pointer is set to the given value "buffer". The // given buffer must be large enough to store the given shape as given by @@ -47,16 +41,9 @@ class ShapedBuffer { const Shape& shape, const perftools::gputools::Platform* platform, int device_ordinal, const perftools::gputools::DeviceMemoryBase& buffer); - // Convenience method which creates a ShapedBuffer of a non-nested tuple. The - // buffer pointers in the return ShapedBuffer are set to the given - // "buffers". The size of buffers must match the number of elements in the - // tuple shape and be large enough to store their respective shape as given by - // ShapeUtil::ByteSizeOf. - static StatusOr> MakeUnnestedTupleShapedBuffer( - const Shape& shape, const perftools::gputools::Platform* platform, - int device_ordinal, - const tensorflow::gtl::ArraySlice - buffers); + ShapedBuffer(const Shape& shape, + const perftools::gputools::Platform* platform, + int device_ordinal); const Shape& shape() const { return shape_; } const perftools::gputools::Platform* platform() const { return platform_; } @@ -85,14 +72,19 @@ class ShapedBuffer { return &shape_index_to_buffer_entry_; } - protected: - ShapedBuffer(const Shape& shape, - const perftools::gputools::Platform* platform, - int device_ordinal); + // Set all device memory pointers in the object to null. + void clear(); + protected: // The shape of the device buffer with layout. const Shape shape_; + // The platform the memory is allocated on. + const perftools::gputools::Platform* platform_; + + // The device the memory is allocated on. + const int device_ordinal_; + // The list of DeviceMemoryBase pointers representing this shape. // Note that there can be a many to one relationship between tuple elements // and buffers. To account for this, shape_index_to_buffer_entry_ allows us @@ -101,12 +93,6 @@ class ShapedBuffer { // The tree of indices into buffers_. ShapeTree shape_index_to_buffer_entry_; - - // The platform the memory is allocated on. - const perftools::gputools::Platform* platform_; - - // The device the memory is allocated on. - const int device_ordinal_; }; // ShapedBuffer derived class which allocates all internal buffers on @@ -114,14 +100,31 @@ class ShapedBuffer { // destructed. class ScopedShapedBuffer : public ShapedBuffer { public: - // Return a new ScopedShapedBuffer of an arbitrary shape. All buffers in the - // ScopedShapedBuffers are automatically allocated to exactly the size of - // their respective array shape. - static StatusOr> MakeScopedShapedBuffer( + // Return a newly allocated ScopedShapedBuffer of an arbitrary shape. Array + // buffers (leaves in the shape) are allocated and uninitialized. Tuple + // buffers (if any) are allocated and initialized to the backend-specific + // representation of an array of pointers to the tuple elements. + static StatusOr> Allocate( const Shape& shape, DeviceMemoryAllocator* allocator, int device_ordinal); + // Takes a ShapedBuffer and returns a ScopedShapedBuffer which manages the + // deallocation of the device memory held in the shaped buffer. All device + // memory pointers in the given ShapedBuffer are set to null. + static StatusOr> MakeScoped( + ShapedBuffer* shaped_buffer, DeviceMemoryAllocator* allocator); + + // Return the allocator used to allocate the device memory held in this + // ScopedShapedBuffer. + DeviceMemoryAllocator* memory_allocator() const { return allocator_; } + + // Release all device memory owned by this ScopedShapedBuffer and return the + // device memory pointers in the form of a ShapedBuffer. Device memory + // pointers in this ScopedShapedBuffer object are set to null. This method is + // analogous to std::unique_ptr::release(). + std::unique_ptr release(); + // All buffers in the shape are deallocated on destruction. - ~ScopedShapedBuffer(); + virtual ~ScopedShapedBuffer(); protected: ScopedShapedBuffer(const Shape& shape, DeviceMemoryAllocator* allocator, diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h index c79ffa9cd73950b1653f72b1c6286346f76c10fb..f63d91604cf40edfae98b56a8bacdbded697ffc3 100644 --- a/tensorflow/compiler/xla/service/transfer_manager.h +++ b/tensorflow/compiler/xla/service/transfer_manager.h @@ -97,6 +97,16 @@ class TransferManager { const perftools::gputools::DeviceMemoryBase& source, const Shape& shape) = 0; + // Writes the given device-memory pointers in 'elements' to the given region + // to construct a tuple in the platform-specific tuple representation. This + // can handle nested tuples as well. In the nested case, the element + // DeviceMemoryBase points to another array of pointers on the device. + virtual Status WriteTuplePointersToDevice( + perftools::gputools::StreamExecutor* executor, + tensorflow::gtl::ArraySlice + elements, + const Shape& shape, perftools::gputools::DeviceMemoryBase* region) = 0; + // Returns all buffer pointers that the tuple `source` refers to. Unlike // ShallowCopyTupleFromDevice, this function gather buffer pointers in nested // tuples as well. Also, the returned DeviceMemoryBase objects are diff --git a/tensorflow/compiler/xla/service/transfer_manager_test.cc b/tensorflow/compiler/xla/service/transfer_manager_test.cc index 29ecef9510cfe6b8764c2e5fe1216255ca1dc983..c25a0861e9b90bc0f2cde43933e14204aa4e3598 100644 --- a/tensorflow/compiler/xla/service/transfer_manager_test.cc +++ b/tensorflow/compiler/xla/service/transfer_manager_test.cc @@ -37,7 +37,9 @@ namespace { class CpuTransferManagerTest : public ::testing::Test { protected: - CpuTransferManagerTest() : transfer_manager_(se::host::kHostPlatformId) { + CpuTransferManagerTest() + : transfer_manager_(se::host::kHostPlatformId, + /*pointer_size=*/sizeof(void*)) { se::Platform* platform = se::MultiPlatformManager::PlatformWithId(se::host::kHostPlatformId) .ValueOrDie(); diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc index d668c812f4f9d119409c8a5147543d329e011df8..8c2640adf52f10c387e7a9c09c0d73a09c054919 100644 --- a/tensorflow/compiler/xla/service/transpose_folding.cc +++ b/tensorflow/compiler/xla/service/transpose_folding.cc @@ -58,14 +58,32 @@ TransposeFolding::OperandIndices CanFoldOperandsIntoConvolution( return {}; } - // We only support folding the RHS. - const int64 kRhsOperandIndex = 1; - auto& operand = *convolution.operand(kRhsOperandIndex); - if (operand.opcode() == HloOpcode::kTranspose && operand.user_count() == 1) { - return transposable_conv_operands(convolution, {kRhsOperandIndex}); + const ConvolutionDimensionNumbers& dnums = + convolution.convolution_dimension_numbers(); + + TransposeFolding::OperandIndices operand_set; + for (int64 i = 0; i < convolution.operand_count(); ++i) { + auto& operand = *convolution.operand(i); + if (operand.opcode() == HloOpcode::kTranspose && + operand.user_count() == 1) { + const auto& transpose_dimensions = operand.dimensions(); + // We can transpose the LHS so long as it doesn't move around spatial + // dimensions because ConvolutionDimensionNumbers doesn't have different + // fields for input and output spatial dimensions. + if (i == 0 && + std::any_of(dnums.spatial_dimensions().begin(), + dnums.spatial_dimensions().end(), + [&](const int64 spatial_dimension) { + return transpose_dimensions[spatial_dimension] != + spatial_dimension; + })) { + continue; + } + operand_set.push_back(i); + } } - return {}; + return transposable_conv_operands(convolution, operand_set); } using InstructionOperandsPair = @@ -98,40 +116,61 @@ bool FoldTransposeIntoDot(InstructionOperandsPair pair) { // Returns whether the module is changed. bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) { auto& convolution = *pair.first; - - // We only support fusing the RHS transpose into convolution. - // - // ConvolutionDimensionNumbers doesn't make enough of a distinction between - // the output and the activations. - // - // TODO(b/37125184): Support transposing the LHS too. - if (pair.second.size() != 1 || pair.second.front() != 1) { - return false; - } + auto& operand_indices = pair.second; const ConvolutionDimensionNumbers& dnums = convolution.convolution_dimension_numbers(); - HloInstruction& transpose = *convolution.mutable_operand(1); - CHECK_EQ(transpose.opcode(), HloOpcode::kTranspose); - const auto& transpose_dimensions = transpose.dimensions(); - HloInstruction& transpose_operand = *transpose.mutable_operand(0); - - // Everything remains the same except for the kernel dimension numbers. We - // need to apply the transpose permutation to the original shape to figure out - // what the new logical dimensions are. ConvolutionDimensionNumbers new_dnums = dnums; - new_dnums.set_kernel_input_feature_dimension( - transpose_dimensions[dnums.kernel_input_feature_dimension()]); - new_dnums.set_kernel_output_feature_dimension( - transpose_dimensions[dnums.kernel_output_feature_dimension()]); - for (auto& kernel_spatial_dimension : - *new_dnums.mutable_kernel_spatial_dimensions()) { - kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension]; + + HloInstruction* new_lhs; + const int64 kLhsIdx = 0; + if (std::find(operand_indices.begin(), operand_indices.end(), kLhsIdx) != + operand_indices.end()) { + HloInstruction& transpose = *convolution.mutable_operand(kLhsIdx); + const auto& transpose_dimensions = transpose.dimensions(); + HloInstruction& transpose_operand = *transpose.mutable_operand(0); + + // Everything remains the same except for the input/output dimension + // numbers. We need to apply the transpose permutation to the original shape + // to figure out what the new logical dimensions are. + new_dnums.set_input_batch_dimension( + transpose_dimensions[dnums.input_batch_dimension()]); + new_dnums.set_input_feature_dimension( + transpose_dimensions[dnums.input_feature_dimension()]); + for (const auto& spatial_dimension : dnums.spatial_dimensions()) { + CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]); + } + new_lhs = &transpose_operand; + } else { + new_lhs = convolution.mutable_operand(kLhsIdx); + } + + HloInstruction* new_rhs; + const int64 kRhsIdx = 1; + if (std::find(operand_indices.begin(), operand_indices.end(), kRhsIdx) != + operand_indices.end()) { + HloInstruction& transpose = *convolution.mutable_operand(kRhsIdx); + const auto& transpose_dimensions = transpose.dimensions(); + HloInstruction& transpose_operand = *transpose.mutable_operand(0); + + // Everything remains the same except for the kernel dimension numbers. We + // need to apply the transpose permutation to the original shape to figure + // out what the new logical dimensions are. + new_dnums.set_kernel_input_feature_dimension( + transpose_dimensions[dnums.kernel_input_feature_dimension()]); + new_dnums.set_kernel_output_feature_dimension( + transpose_dimensions[dnums.kernel_output_feature_dimension()]); + for (auto& kernel_spatial_dimension : + *new_dnums.mutable_kernel_spatial_dimensions()) { + kernel_spatial_dimension = transpose_dimensions[kernel_spatial_dimension]; + } + new_rhs = &transpose_operand; + } else { + new_rhs = convolution.mutable_operand(kRhsIdx); } auto new_conv = HloInstruction::CreateConvolve( - convolution.shape(), convolution.mutable_operand(0), &transpose_operand, - convolution.window(), new_dnums); + convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums); TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction( &convolution, std::move(new_conv))); @@ -171,14 +210,7 @@ StatusOr TransposeFolding::Run(HloModule* module) { return tensorflow::Status::OK(); }; - std::vector computations; - for (auto& computation : module->computations()) { - if (computation->IsFusionComputation()) { - continue; - } - computations.push_back(computation.get()); - } - for (auto& comp : computations) { + for (auto* comp : module->MakeNonfusionComputations()) { TF_RETURN_IF_ERROR(comp->Accept(visit_fn)); } diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc index a5be4ab7ed4e56f27c89bfb270d0c032ded8161f..00462f9be1e9beb2f2694060ebfaa70b0b9dd4a0 100644 --- a/tensorflow/compiler/xla/service/transpose_folding_test.cc +++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc @@ -74,10 +74,9 @@ TEST_F(TransposeFoldingTest, FoldDotTranspose) { FoldTranspose(&module); // Instructions after folding: x, y, and the fusion. - std::unordered_set instruction_set; - for (auto& instruction : entry_computation->instructions()) { - instruction_set.insert(instruction.get()); - } + std::unordered_set instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; CHECK_EQ(1, instruction_set.size()) @@ -87,7 +86,7 @@ TEST_F(TransposeFoldingTest, FoldDotTranspose) { // The fusion instruction should contain two parameters, one transpose and // one dot. - EXPECT_EQ(4, fusion->fused_instructions().size()); + EXPECT_EQ(4, fusion->fused_instruction_count()); } TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { @@ -114,7 +113,7 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { module.AddEntryComputation(builder.Build(dot)); FoldTranspose(&module); - for (auto& instruction : entry_computation->instructions()) { + for (auto* instruction : entry_computation->instructions()) { if (instruction->opcode() == HloOpcode::kFusion) { CHECK_EQ(2, instruction->operand_count()); EXPECT_EQ(const0, instruction->operand(0)); @@ -125,7 +124,7 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) { // The created fusion instruction should contain two parameters, two // transposes (one for each parameter) and one dot. EXPECT_EQ(5, - entry_computation->root_instruction()->fused_instructions().size()); + entry_computation->root_instruction()->fused_instruction_count()); } TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { @@ -156,7 +155,7 @@ TEST_F(TransposeFoldingTest, FuseDotWithConstantOperands) { ::testing::UnorderedElementsAre(const1, const2, const3)); // The callee should contain 3 parameters and 3 binary operators. - EXPECT_EQ(6, callee_computation->instructions().size()); + EXPECT_EQ(6, callee_computation->instruction_count()); } TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { @@ -184,10 +183,9 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { FoldTranspose(&module); // Instructions after folding: x, y, and the fusion. - std::unordered_set instruction_set; - for (auto& instruction : entry_computation->instructions()) { - instruction_set.insert(instruction.get()); - } + std::unordered_set instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; CHECK_EQ(1, instruction_set.erase(call)) @@ -200,7 +198,7 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) { // The fusion instruction should contain two parameters, one transpose and // one dot. - EXPECT_EQ(4, fusion->fused_instructions().size()); + EXPECT_EQ(4, fusion->fused_instruction_count()); } // Test that a two dimension swap of the kernel gets folded into convolution. @@ -239,10 +237,9 @@ TEST_F(TransposeFoldingTest, FoldConvDimSwapTransposeRhs) { FoldTranspose(&module); // Instructions after folding: x, y, and the convolution. - std::unordered_set instruction_set; - for (auto& instruction : entry_computation->instructions()) { - instruction_set.insert(instruction.get()); - } + std::unordered_set instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; CHECK_EQ(1, instruction_set.size()) @@ -293,10 +290,9 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { FoldTranspose(&module); // Instructions after folding: x, y, and the convolution. - std::unordered_set instruction_set; - for (auto& instruction : entry_computation->instructions()) { - instruction_set.insert(instruction.get()); - } + std::unordered_set instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; CHECK_EQ(1, instruction_set.size()) @@ -317,8 +313,7 @@ TEST_F(TransposeFoldingTest, FoldConvComplexTransposeRhs) { new_conv->convolution_dimension_numbers().kernel_spatial_dimensions(1)); } -// Test that a transpose of the activations does not get folded into -// convolution. +// Test that a transpose of the activations gets folded into convolution. TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { auto builder = HloComputation::Builder("entry_computation"); HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter( @@ -352,19 +347,25 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) { module.AddEntryComputation(builder.Build(conv)); FoldTranspose(&module); - // Instructions after folding: transpose_x, y, and the convolution. - std::unordered_set instruction_set; - for (auto& instruction : entry_computation->instructions()) { - instruction_set.insert(instruction.get()); - } - CHECK_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(transpose_x)) - << "transpose_x is not in entry_computation."; - CHECK_EQ(1, instruction_set.erase(conv)) - << "transpose_x is not in entry_computation."; - CHECK_EQ(0, instruction_set.size()) - << "entry_computation should contain exactly 4 instructions."; + // Instructions after folding: x, y, and the convolution. + std::unordered_set instruction_set( + entry_computation->instructions().begin(), + entry_computation->instructions().end()); + EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation."; + EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation."; + EXPECT_EQ(1, instruction_set.size()) + << "entry_computation should contain exactly 3 instructions."; + HloInstruction* new_conv = *instruction_set.begin(); + EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode()); + EXPECT_EQ(dnums.input_feature_dimension(), + new_conv->convolution_dimension_numbers().input_batch_dimension()); + EXPECT_EQ( + dnums.input_batch_dimension(), + new_conv->convolution_dimension_numbers().input_feature_dimension()); + EXPECT_EQ(dnums.spatial_dimensions(0), + new_conv->convolution_dimension_numbers().spatial_dimensions(0)); + EXPECT_EQ(dnums.spatial_dimensions(1), + new_conv->convolution_dimension_numbers().spatial_dimensions(1)); } } // namespace diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index 9fc288d3017137c8a0741a9a69c7a20396ce4af1..df537bd7c15a1f15ed77ca9be6ce70fbfd2e63be 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -137,15 +137,12 @@ Status TuplePointsToAnalysis::Analyze() { logical_buffer_aliases_.resize( logical_buffer_analysis_->num_logical_buffers()); - for (auto& computation : module_->computations()) { - if (computation->IsFusionComputation()) { - continue; - } + for (auto* computation : module_->MakeNonfusionComputations()) { TF_RETURN_IF_ERROR(computation->Accept(this)); TF_RETURN_IF_ERROR( PopulateDefinedBuffersAndAliases(computation->instructions())); // Run points-to analysis on fusion instructions in 'computation'. - for (auto& instruction : computation->instructions()) { + for (auto* instruction : computation->instructions()) { if (instruction->opcode() != HloOpcode::kFusion) { continue; } @@ -160,21 +157,21 @@ Status TuplePointsToAnalysis::Analyze() { return Status::OK(); } -Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases( - const std::list>& instructions) { - for (auto& instruction : instructions) { - PerInstruction* pi = PerInst(instruction.get()); +Status TuplePointsToAnalysis::PopulateDefinedBuffersAndAliases(const decltype( + std::declval().instructions())& instructions) { + for (auto* instruction : instructions) { + PerInstruction* pi = PerInst(instruction); TF_RETURN_IF_ERROR(GatherBuffersDefinedByInstruction( - instruction.get(), &pi->instruction_defined_buffers)); + instruction, &pi->instruction_defined_buffers)); - const PointsToSet& points_to_set = GetPointsToSet(instruction.get()); + const PointsToSet& points_to_set = GetPointsToSet(instruction); points_to_set.ForEachElement( [this, &instruction]( const ShapeIndex& index, const PointsToSet::BufferList& pointed_to_buffers) { for (const LogicalBuffer* buffer : pointed_to_buffers) { - logical_buffer_aliases_[buffer->id()].emplace_back( - instruction.get(), index); + logical_buffer_aliases_[buffer->id()].emplace_back(instruction, + index); } }); } @@ -203,13 +200,14 @@ Status TuplePointsToAnalysis::DefaultAction(HloInstruction* hlo_instruction) { } Status TuplePointsToAnalysis::HandleGetTupleElement( - HloInstruction* get_tuple_element, HloInstruction* operand) { + HloInstruction* get_tuple_element) { // GetTupleElement forwards a pointer to a particular element of the tuple // operand. int64 element_index = get_tuple_element->tuple_index(); PointsToSet& points_to_set = CreateEmptyPointsToSet(get_tuple_element); - const PointsToSet& operand_points_to_set = *PerInst(operand)->points_to_set; + const PointsToSet& operand_points_to_set = + *PerInst(get_tuple_element->operand(0))->points_to_set; // Copy the points-to set (and tuple sources) at index {element_index} of the // operand to the points-to set for this GetTupleElement instruction. @@ -255,9 +253,8 @@ Status TuplePointsToAnalysis::HandleBitcast(HloInstruction* bitcast) { return Status::OK(); } -Status TuplePointsToAnalysis::HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) { +Status TuplePointsToAnalysis::HandleTuple(HloInstruction* tuple) { + tensorflow::gtl::ArraySlice operands(tuple->operands()); PointsToSet& points_to_set = CreateEmptyPointsToSet(tuple); points_to_set.AddPointedToBuffer( logical_buffer_analysis_->GetBuffer(tuple, /*index=*/{}), @@ -295,10 +292,7 @@ Status TuplePointsToAnalysis::HandleTuple( return Status::OK(); } -Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select, - HloInstruction* /*pred*/, - HloInstruction* on_true, - HloInstruction* on_false) { +Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select) { // Select allocates a new buffer and then shallow copies the on_true or // on_false buffer into this new buffer. Which side is chosen cannot be // determined statically so conservatively set the points-to set to the union @@ -306,6 +300,8 @@ Status TuplePointsToAnalysis::HandleSelect(HloInstruction* select, // // First create a copy of the on_true points-to set (and tuple sources), then // add in elements of the on_false points-to set (tuple sources). + auto on_true = select->operand(1); + auto on_false = select->operand(2); PointsToSet& points_to_set = CreateCopiedPointsToSet(select, on_true); const PointsToSet& false_points_to_set = *PerInst(on_false)->points_to_set; points_to_set.ForEachMutableElement( @@ -452,20 +448,17 @@ PointsToSet& TuplePointsToAnalysis::CreateCopiedPointsToSet( string TuplePointsToAnalysis::ToString() const { string output = tensorflow::strings::Printf( "TuplePointsToSet for module %s:\n", module_->name().c_str()); - for (const auto& computation : module_->computations()) { - if (computation->IsFusionComputation()) { - continue; - } + for (const auto* computation : module_->MakeNonfusionComputations()) { const char* entry = - computation.get() == module_->entry_computation() ? "entry " : ""; + computation == module_->entry_computation() ? "entry " : ""; tensorflow::strings::StrAppend(&output, entry, "computation ", computation->name(), ":\n"); for (const HloInstruction* instruction : computation->MakeInstructionPostOrder()) { InstructionToString(instruction, &output); if (instruction->opcode() == HloOpcode::kFusion) { - for (auto& fused : instruction->fused_instructions()) { - InstructionToString(fused.get(), &output); + for (auto* fused : instruction->fused_instructions()) { + InstructionToString(fused, &output); } } } diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h index 3b3a046e498f0e5fdd6c0a18caadab856f5db676..e6157a1ed11b5df24458fe820a4e0e329eb86ae4 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h @@ -44,7 +44,7 @@ namespace xla { // A class describing the source(s) of the Buffer(s) contained in the output of // a particular HLO instruction. The structure of PointsToSet mirrors the -// structure of the instruction's shape which may be an arbitrary tree (eg, a +// structure of the instruction's shape, which may be an arbitrary tree (eg, a // nested tuple). Each node in this tree corresponds to a single buffer in the // instruction's output and contains the set of Buffers which might define // the corresponding buffer. @@ -148,7 +148,7 @@ class PointsToSet { ShapeTree tree_; // PointsToSet contains references (const LogicalBuffer*) to elements within - // TuplePointsToAnalysis so disable copying. + // TuplePointsToAnalysis, so disable copying. TF_DISALLOW_COPY_AND_ASSIGN(PointsToSet); }; @@ -247,16 +247,11 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status VerifyBuffer(const LogicalBuffer& buffer) const; Status DefaultAction(HloInstruction* hlo_instruction) override; - Status HandleTuple( - HloInstruction* tuple, - tensorflow::gtl::ArraySlice operands) override; - Status HandleGetTupleElement(HloInstruction* get_tuple_element, - HloInstruction* operand) override; + Status HandleTuple(HloInstruction* tuple) override; + Status HandleGetTupleElement(HloInstruction* get_tuple_element) override; Status HandleBitcast(HloInstruction* bitcast) override; Status HandleCopy(HloInstruction* copy) override; - Status HandleSelect(HloInstruction* select, HloInstruction* pred, - HloInstruction* on_true, - HloInstruction* on_false) override; + Status HandleSelect(HloInstruction* select) override; string ToString() const; @@ -272,11 +267,9 @@ class TuplePointsToAnalysis : public DfsHloVisitorWithDefault { Status Analyze(); // Populates instruction-defined buffers and aliases for each instruction - // in 'instructions'. The parameter 'instructions' is passed in a form - // common to how both HloComputation, and fusion instructions maintain a - // list of instructions. - Status PopulateDefinedBuffersAndAliases( - const std::list>& instructions); + // in 'instructions'. + Status PopulateDefinedBuffersAndAliases(const decltype( + std::declval().instructions())& instructions); // Creates an empty PointsToSet in the points_to_ map for the given // instruction. diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc index dfa94db5dbb2fbbc9f2930c38f8f7cd18df23abb..694ed57fa24d59bd0a28c7bb9b67af8165e90363 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis_test.cc @@ -661,13 +661,12 @@ class FusionPointsToAnalysisTest : public TuplePointsToAnalysisTest { HloInstruction* operand) { auto it = std::find_if( fusion->fused_instructions().begin(), - fusion->fused_instructions().end(), - [=](const std::unique_ptr& fused) { + fusion->fused_instructions().end(), [=](const HloInstruction* fused) { return fused->opcode() == HloOpcode::kParameter && fusion->operand(fused->parameter_number()) == operand; }); CHECK(it != fusion->fused_instructions().end()); - return (*it).get(); + return *it; } // Returns all users of 'fusion_paran' at 'tuple_index'. diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc new file mode 100644 index 0000000000000000000000000000000000000000..113c2e2bd9f73a2b0c783103d7f2da9534bc97c3 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc @@ -0,0 +1,129 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" + +#include + +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { + +StatusOr TupleSimplifier::Run(HloModule* module) { + // Initially add all GTE and Tuple instructions to the worklist. + std::queue worklist; + for (auto* computation : module->computations()) { + for (auto* instruction : computation->instructions()) { + if (instruction->opcode() == HloOpcode::kTuple || + instruction->opcode() == HloOpcode::kGetTupleElement) { + worklist.push(instruction); + } + } + } + + bool changed = false; + while (!worklist.empty()) { + HloInstruction* instruction = worklist.front(); + worklist.pop(); + + if (instruction->user_count() == 0 && + instruction != instruction->parent()->root_instruction()) { + // Tuple simplification works by replacing users of optimized away + // instructions with a simpler form. If there is no user of the + // instruction (including being the root), then there is nothing to do. + continue; + } + + if (instruction->opcode() == HloOpcode::kTuple) { + // Collapse the following structure into just 'Tuple-shaped Op': + // + // Tuple-shaped Op + // | + // +-----+-----+ + // | | | + // GTE GTE GTE + // | | | + // +-----+-----+ + // | + // Tuple + // + HloInstruction* top_tuple = nullptr; + bool can_simplify = true; + for (int64 operand_number = 0; + operand_number < instruction->operand_count(); ++operand_number) { + HloInstruction* operand = instruction->mutable_operand(operand_number); + if (operand->opcode() != HloOpcode::kGetTupleElement || + operand->tuple_index() != operand_number) { + can_simplify = false; + break; + } + + if (top_tuple == nullptr) { + top_tuple = operand->mutable_operand(0); + if (!ShapeUtil::Compatible(top_tuple->shape(), + instruction->shape())) { + can_simplify = false; + break; + } + } else if (top_tuple != operand->operand(0)) { + can_simplify = false; + break; + } + } + if (can_simplify && top_tuple != nullptr) { + changed = true; + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(top_tuple)); + // No need to add anything to the worklist. + } + } else { + CHECK_EQ(instruction->opcode(), HloOpcode::kGetTupleElement); + // If possible replace a GTE with the operation which produces the + // element. For example, replace uses of GTE with below with just 'Op' + // (assuming 'Op' is at the index of the GTE instruction): + // + // ... Op ... + // \ | / + // Tuple + // | + // GTE + if (instruction->operand(0)->opcode() == HloOpcode::kTuple) { + changed = true; + HloInstruction* element_source = + instruction->mutable_operand(0)->mutable_operand( + instruction->tuple_index()); + TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source)); + for (HloInstruction* user : element_source->users()) { + if (user->opcode() == HloOpcode::kTuple || + user->opcode() == HloOpcode::kGetTupleElement) { + worklist.push(user); + } + } + } + } + } + + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.h b/tensorflow/compiler/xla/service/tuple_simplifier.h new file mode 100644 index 0000000000000000000000000000000000000000..e5e9b10b5bf3f452d1bfec476b8d5c7d74c4f4e8 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_simplifier.h @@ -0,0 +1,41 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_SIMPLIFIER_H_ +#define TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_SIMPLIFIER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" + +namespace xla { + +// A pass which simplifies patterns of Tuple and GetTupleElement instructions in +// the module. +class TupleSimplifier : public HloPassInterface { + public: + TupleSimplifier() {} + ~TupleSimplifier() override {} + tensorflow::StringPiece name() const override { return "tuple-simplifier"; } + + // Run tuple simplification on the given computation. Returns whether the + // computation was changed. + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_SERVICE_TUPLE_SIMPLIFIER_H_ diff --git a/tensorflow/compiler/xla/service/tuple_simplifier_test.cc b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..ca9ae91281fce5ee061d066fc3e538dbbc09f6b3 --- /dev/null +++ b/tensorflow/compiler/xla/service/tuple_simplifier_test.cc @@ -0,0 +1,215 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/tuple_simplifier.h" + +#include +#include + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace op = xla::testing::opcode_matchers; + +namespace xla { +namespace { + +class TupleSimplifierTest : public HloTestBase { + protected: + void Run(HloModule* module, bool change_expected) { + TupleSimplifier simplifier; + auto changed_status = simplifier.Run(module); + TF_ASSERT_OK(changed_status.status()); + EXPECT_EQ(change_expected, changed_status.ValueOrDie()); + } + + const Shape scalar_shape_ = ShapeUtil::MakeShape(F32, {}); + const Shape tuple_shape_ = ShapeUtil::MakeTupleShape( + {ShapeUtil::MakeShape(F32, {}), ShapeUtil::MakeShape(F32, {}), + ShapeUtil::MakeShape(F32, {})}); +}; + +TEST_F(TupleSimplifierTest, TupleOfParameters) { + // A Tuple constructed of a bunch of parameters should not be changed. + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "param1")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, scalar_shape_, "param2")); + builder.AddInstruction(HloInstruction::CreateTuple({param0, param1, param2})); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + Run(module.get(), /*change_expected=*/false); +} + +TEST_F(TupleSimplifierTest, GteOfTupleOfParameter) { + // A GTE of a tuple parameter should not be changed. + HloComputation::Builder builder(TestName()); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, param, 1)); + auto module = CreateNewModule(); + module->AddEntryComputation(builder.Build()); + + Run(module.get(), /*change_expected=*/false); +} + +TEST_F(TupleSimplifierTest, GteOfTuple) { + // A GTE of a Tuple should be short-circuited. + HloComputation::Builder builder(TestName()); + HloInstruction* param0 = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param0")); + HloInstruction* param1 = builder.AddInstruction( + HloInstruction::CreateParameter(1, scalar_shape_, "param1")); + HloInstruction* param2 = builder.AddInstruction( + HloInstruction::CreateParameter(2, scalar_shape_, "param2")); + HloInstruction* tuple = builder.AddInstruction( + HloInstruction::CreateTuple({param0, param1, param2})); + HloInstruction* gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), gte); + + Run(module.get(), /*change_expected=*/true); + + EXPECT_THAT(computation->root_instruction(), param1); +} + +TEST_F(TupleSimplifierTest, GteOfTupleChain) { + // Verify a chain of GTE/Tuple instructions is collapsed. + HloComputation::Builder builder(TestName()); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + + const int kChainLength = 10; + HloInstruction* element = param; + for (int i = 0; i < kChainLength; ++i) { + HloInstruction* tuple = builder.AddInstruction( + HloInstruction::CreateTuple({element, element, element})); + element = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple, 1)); + } + builder.AddInstruction( + HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kNegate, element)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), + op::Negate(op::GetTupleElement(op::Tuple()))); + + Run(module.get(), /*change_expected=*/true); + + EXPECT_THAT(computation->root_instruction(), op::Negate(op::Parameter())); +} + +TEST_F(TupleSimplifierTest, NestedGteOfTuples) { + // Verify a nesting of GTE/Tuple instructions is collapsed. Tuples are nested + // to some depth with a chain of Tuple instructions, then extracted with a + // chain of GTE instructions. + HloComputation::Builder builder(TestName()); + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, scalar_shape_, "param")); + + const int kNestingDepth = 5; + HloInstruction* nested_tuple = param; + for (int i = 0; i < kNestingDepth; ++i) { + nested_tuple = builder.AddInstruction( + HloInstruction::CreateTuple({nested_tuple, nested_tuple})); + } + + HloInstruction* element = nested_tuple; + for (int i = 0; i < kNestingDepth; ++i) { + element = builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::GetTupleElementShape(element->shape(), 0), element, 0)); + } + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), element); + + Run(module.get(), /*change_expected=*/true); + + EXPECT_THAT(computation->root_instruction(), param); +} + +TEST_F(TupleSimplifierTest, TupleOfGteInstructions) { + // Verify that a tuple constructed of GTE instructions operating on the same + // tuple are collapsed. + HloComputation::Builder builder(TestName()); + HloInstruction* tuple_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1)); + HloInstruction* gte2 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 2)); + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1, gte2})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), tuple); + + Run(module.get(), /*change_expected=*/true); + + EXPECT_THAT(computation->root_instruction(), tuple_param); +} + +TEST_F(TupleSimplifierTest, IncompatibleTuples) { + // Verify that a tuple->GTE->tuple construct is not simplified if the input + // and output tuple are not compatible shapes. + HloComputation::Builder builder(TestName()); + HloInstruction* tuple_param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape_, "param")); + HloInstruction* gte0 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 0)); + HloInstruction* gte1 = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_shape_, tuple_param, 1)); + // Output tuple has only two elements. Parameter tuple has three elements so + // simplification is not possible. + HloInstruction* tuple = + builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte1})); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_THAT(computation->root_instruction(), tuple); + + Run(module.get(), /*change_expected=*/false); + + EXPECT_THAT(computation->root_instruction(), tuple); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index ac7c31bf68879c6fe3e96916e63105019340abaa..e9d182509b5356d32b667b7921e2843d30faeb9b 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -20,6 +20,7 @@ limitations under the License. #include #include #include +#include #include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/literal_util.h" @@ -54,14 +55,18 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { return HloOpcode::kExp; case UNOP_FLOOR: return HloOpcode::kFloor; + case UNOP_IMAG: + return HloOpcode::kImag; case UNOP_IS_FINITE: return HloOpcode::kIsFinite; case UNOP_LOG: return HloOpcode::kLog; - case UNOP_LOGICAL_NOT: - return HloOpcode::kLogicalNot; + case UNOP_NOT: + return HloOpcode::kNot; case UNOP_NEGATE: return HloOpcode::kNegate; + case UNOP_REAL: + return HloOpcode::kReal; case UNOP_ROUND_NEAREST_AFZ: return HloOpcode::kRoundNearestAfz; case UNOP_SIGN: @@ -79,6 +84,10 @@ HloOpcode UnaryOperationToHloOpcode(UnaryOperation unop) { HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { switch (binop) { + case BINOP_ATAN2: + return HloOpcode::kAtan2; + case BINOP_COMPLEX: + return HloOpcode::kComplex; case BINOP_DOT: return HloOpcode::kDot; case BINOP_MUL: @@ -87,8 +96,6 @@ HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { return HloOpcode::kAdd; case BINOP_SUB: return HloOpcode::kSubtract; - case BINOP_INDEX: - return HloOpcode::kIndex; case BINOP_DIV: return HloOpcode::kDivide; case BINOP_EQ: @@ -111,10 +118,16 @@ HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) { return HloOpcode::kPower; case BINOP_REM: return HloOpcode::kRemainder; - case BINOP_LOGICAL_OR: - return HloOpcode::kLogicalOr; - case BINOP_LOGICAL_AND: - return HloOpcode::kLogicalAnd; + case BINOP_OR: + return HloOpcode::kOr; + case BINOP_AND: + return HloOpcode::kAnd; + case BINOP_SHIFT_LEFT: + return HloOpcode::kShiftLeft; + case BINOP_SHIFT_RIGHT_ARITHMETIC: + return HloOpcode::kShiftRightArithmetic; + case BINOP_SHIFT_RIGHT_LOGICAL: + return HloOpcode::kShiftRightLogical; default: LOG(FATAL) << "unhandled operation " << binop; } @@ -126,8 +139,6 @@ HloOpcode TernaryOperationToHloOpcode(TernaryOperation triop) { return HloOpcode::kClamp; case TRIOP_SELECT: return HloOpcode::kSelect; - case TRIOP_UPDATE: - return HloOpcode::kUpdate; default: LOG(FATAL) << "unhandled operation " << triop; } @@ -421,7 +432,8 @@ StatusOr UserComputation::AddMapInstruction( to_apply_computation.ComputeProgramShape(to_apply_version)); TF_ASSIGN_OR_RETURN( Shape inferred_shape, - ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape)); + ShapeInference::InferMapShape(operand_shapes, *to_apply_program_shape, + AsInt64Slice(map_request.dimensions()))); ComputationDataHandle handle = CreateComputationDataHandle(); @@ -1301,20 +1313,19 @@ Status UserComputation::SetOpMetadata(const ComputationDataHandle& handle, return Status::OK(); } -Status UserComputation::SetOpDeviceAssignment( - const ComputationDataHandle& handle, - const OpDeviceAssignment& device_assignment) { +Status UserComputation::SetOpSharding(const ComputationDataHandle& handle, + const OpSharding& sharding) { tensorflow::mutex_lock lock(mutex_); int64 handle_value = handle.handle(); if (session_computation_.requests().count(handle_value) == 0) { - return InvalidArgument("Invalid handle in SetOpDeviceAssignment (%lld)", + return InvalidArgument("Invalid handle in SetOpSharding (%lld)", handle_value); } *session_computation_.mutable_requests() ->at(handle_value) .mutable_request() - ->mutable_device_assignment() = device_assignment; + ->mutable_sharding() = sharding; return Status::OK(); } @@ -1471,14 +1482,15 @@ UserComputation::ComputeProgramShape( namespace { -// A visitor which checks whether an operation is a compile-time constant. That -// is, the operation does not depend on any parameter instructions. The visitor -// walks the computation starting at a given operation and sets is_constant to -// false iff a parameter or RNG operation is encountered. -void ConstantVisitor(const SessionComputation& session_computation, - const ComputationDataHandle& handle, - std::set* visited, bool* is_constant) { - if (visited->count(handle.handle()) != 0 || !*is_constant) { +// A visitor which checks whether an operation is pure functional meaning that +// it doesn't depend on any parameter with an index higher then num_parameters. +// The visitor walks the computation starting at a given operation and sets +// is_functional to false iff a parameter or RNG operation is encountered. +void PureFunctionalVisitor(const SessionComputation& session_computation, + const ComputationDataHandle& handle, + int64 num_parameters, std::set* visited, + bool* is_functional) { + if (visited->count(handle.handle()) != 0 || !*is_functional) { return; } @@ -1486,7 +1498,7 @@ void ConstantVisitor(const SessionComputation& session_computation, session_computation.requests().at(handle.handle()); switch (request.request().op_case()) { case OpRequest::kRngRequest: - *is_constant = false; + *is_functional = false; break; case OpRequest::kConstantRequest: @@ -1495,41 +1507,43 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kGetTupleElementRequest: { const GetTupleElementRequest& get_tuple_element_request = request.request().get_tuple_element_request(); - ConstantVisitor(session_computation, get_tuple_element_request.operand(), - visited, is_constant); + PureFunctionalVisitor(session_computation, + get_tuple_element_request.operand(), num_parameters, + visited, is_functional); break; } case OpRequest::kSliceRequest: { const SliceRequest& slice_request = request.request().slice_request(); - ConstantVisitor(session_computation, slice_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, slice_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kDynamicSliceRequest: { const DynamicSliceRequest& dynamic_slice_request = request.request().dynamic_slice_request(); - ConstantVisitor(session_computation, dynamic_slice_request.operand(), - visited, is_constant); - ConstantVisitor(session_computation, - dynamic_slice_request.start_indices(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + dynamic_slice_request.operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + dynamic_slice_request.start_indices(), + num_parameters, visited, is_functional); break; } case OpRequest::kDynamicUpdateSliceRequest: { const DynamicUpdateSliceRequest& dynamic_update_slice_request = request.request().dynamic_update_slice_request(); - ConstantVisitor(session_computation, - dynamic_update_slice_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, - dynamic_update_slice_request.update(), visited, - is_constant); - ConstantVisitor(session_computation, - dynamic_update_slice_request.start_indices(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + dynamic_update_slice_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + dynamic_update_slice_request.update(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + dynamic_update_slice_request.start_indices(), + num_parameters, visited, is_functional); break; } @@ -1538,7 +1552,8 @@ void ConstantVisitor(const SessionComputation& session_computation, request.request().concatenate_request(); for (const ComputationDataHandle& handle : concatenate_request.operands()) { - ConstantVisitor(session_computation, handle, visited, is_constant); + PureFunctionalVisitor(session_computation, handle, num_parameters, + visited, is_functional); } break; } @@ -1546,61 +1561,63 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kConvolveRequest: { const ConvolveRequest& convolve_request = request.request().convolve_request(); - ConstantVisitor(session_computation, convolve_request.lhs(), visited, - is_constant); - ConstantVisitor(session_computation, convolve_request.rhs(), visited, - is_constant); + PureFunctionalVisitor(session_computation, convolve_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, convolve_request.rhs(), + num_parameters, visited, is_functional); break; } case OpRequest::kCrossReplicaSumRequest: { // TODO(b/33009255): Implmement constant folding for cross replica sum. - *is_constant = false; + *is_functional = false; break; } case OpRequest::kInfeedRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kOutfeedRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kCallRequest: { const CallRequest& call_request = request.request().call_request(); for (const ComputationDataHandle& handle : call_request.operands()) { - ConstantVisitor(session_computation, handle, visited, is_constant); + PureFunctionalVisitor(session_computation, handle, num_parameters, + visited, is_functional); } // TODO(b/32495713): We aren't checking the to_apply computation itself, // so we conservatively say that computations containing the Call op - // cannot be constant. We cannot set is_constant=false in other similar + // cannot be constant. We cannot set is_functional=false in other similar // cases since we're already relying on IsConstant to return true. - *is_constant = false; + *is_functional = false; break; } case OpRequest::kCustomCallRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kSendRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kRecvRequest: { - *is_constant = false; + *is_functional = false; break; } case OpRequest::kMapRequest: { const MapRequest& map_request = request.request().map_request(); for (const ComputationDataHandle& handle : map_request.operands()) { - ConstantVisitor(session_computation, handle, visited, is_constant); + PureFunctionalVisitor(session_computation, handle, num_parameters, + visited, is_functional); } // TODO(b/32495713): We aren't checking the to_apply computation itself. break; @@ -1608,10 +1625,10 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kReduceRequest: { const ReduceRequest& reduce_request = request.request().reduce_request(); - ConstantVisitor(session_computation, reduce_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, reduce_request.init_value(), visited, - is_constant); + PureFunctionalVisitor(session_computation, reduce_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, reduce_request.init_value(), + num_parameters, visited, is_functional); // TODO(b/32495713): We aren't checking the to_apply computation itself. break; } @@ -1619,10 +1636,12 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kReduceWindowRequest: { const ReduceWindowRequest& reduce_window_request = request.request().reduce_window_request(); - ConstantVisitor(session_computation, reduce_window_request.operand(), - visited, is_constant); - ConstantVisitor(session_computation, reduce_window_request.init_value(), - visited, is_constant); + PureFunctionalVisitor(session_computation, + reduce_window_request.operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + reduce_window_request.init_value(), num_parameters, + visited, is_functional); // TODO(b/32495713): We aren't checking the to_apply computation itself. break; } @@ -1630,13 +1649,15 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kSelectAndScatterRequest: { const SelectAndScatterRequest& select_and_scatter_request = request.request().select_and_scatter_request(); - ConstantVisitor(session_computation, select_and_scatter_request.operand(), - visited, is_constant); - ConstantVisitor(session_computation, select_and_scatter_request.source(), - visited, is_constant); - ConstantVisitor(session_computation, - select_and_scatter_request.init_value(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + select_and_scatter_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + select_and_scatter_request.source(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + select_and_scatter_request.init_value(), + num_parameters, visited, is_functional); // TODO(b/32495713): We aren't checking the select and scatter // computations themselves. break; @@ -1645,76 +1666,80 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kBroadcastRequest: { const BroadcastRequest& broadcast_request = request.request().broadcast_request(); - ConstantVisitor(session_computation, broadcast_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, broadcast_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kReshapeRequest: { const ReshapeRequest& reshape_request = request.request().reshape_request(); - ConstantVisitor(session_computation, reshape_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, reshape_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kReverseRequest: { const ReverseRequest& reverse_request = request.request().reverse_request(); - ConstantVisitor(session_computation, reverse_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, reverse_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kPadRequest: { const PadRequest& pad_request = request.request().pad_request(); - ConstantVisitor(session_computation, pad_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, pad_request.padding_value(), visited, - is_constant); + PureFunctionalVisitor(session_computation, pad_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, pad_request.padding_value(), + num_parameters, visited, is_functional); break; } case OpRequest::kParameterRequest: { - *is_constant = false; + const ParameterRequest& parameter_request = + request.request().parameter_request(); + if (parameter_request.parameter() >= num_parameters) { + *is_functional = false; + } break; } case OpRequest::kConvertRequest: { const ConvertRequest& convert_request = request.request().convert_request(); - ConstantVisitor(session_computation, convert_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, convert_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kWhileRequest: { const WhileRequest& while_request = request.request().while_request(); - ConstantVisitor(session_computation, while_request.init(), visited, - is_constant); + PureFunctionalVisitor(session_computation, while_request.init(), + num_parameters, visited, is_functional); // TODO(b/32495713): We aren't checking the condition and body // computations themselves. - *is_constant = false; + *is_functional = false; break; } case OpRequest::kTernaryOpRequest: { const TernaryOpRequest& ternary_op_request = request.request().ternary_op_request(); - ConstantVisitor(session_computation, ternary_op_request.lhs(), visited, - is_constant); - ConstantVisitor(session_computation, ternary_op_request.rhs(), visited, - is_constant); - ConstantVisitor(session_computation, ternary_op_request.ehs(), visited, - is_constant); + PureFunctionalVisitor(session_computation, ternary_op_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, ternary_op_request.rhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, ternary_op_request.ehs(), + num_parameters, visited, is_functional); break; } case OpRequest::kTransposeRequest: { const TransposeRequest& transpose_request = request.request().transpose_request(); - ConstantVisitor(session_computation, transpose_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, transpose_request.operand(), + num_parameters, visited, is_functional); break; } @@ -1723,7 +1748,8 @@ void ConstantVisitor(const SessionComputation& session_computation, request.request().variadic_op_request(); for (const ComputationDataHandle& handle : variadic_op_request.operands()) { - ConstantVisitor(session_computation, handle, visited, is_constant); + PureFunctionalVisitor(session_computation, handle, num_parameters, + visited, is_functional); } break; } @@ -1731,67 +1757,74 @@ void ConstantVisitor(const SessionComputation& session_computation, case OpRequest::kUnaryOpRequest: { const UnaryOpRequest& unary_op_request = request.request().unary_op_request(); - ConstantVisitor(session_computation, unary_op_request.operand(), visited, - is_constant); + PureFunctionalVisitor(session_computation, unary_op_request.operand(), + num_parameters, visited, is_functional); break; } case OpRequest::kBatchNormTrainingRequest: { const BatchNormTrainingRequest& batch_norm_training_request = request.request().batch_norm_training_request(); - ConstantVisitor(session_computation, - batch_norm_training_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, batch_norm_training_request.scale(), - visited, is_constant); - ConstantVisitor(session_computation, batch_norm_training_request.offset(), - visited, is_constant); + PureFunctionalVisitor(session_computation, + batch_norm_training_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_training_request.scale(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_training_request.offset(), + num_parameters, visited, is_functional); break; } case OpRequest::kBatchNormInferenceRequest: { const BatchNormInferenceRequest& batch_norm_inference_request = request.request().batch_norm_inference_request(); - ConstantVisitor(session_computation, - batch_norm_inference_request.operand(), visited, - is_constant); - ConstantVisitor(session_computation, batch_norm_inference_request.scale(), - visited, is_constant); - ConstantVisitor(session_computation, - batch_norm_inference_request.offset(), visited, - is_constant); - ConstantVisitor(session_computation, batch_norm_inference_request.mean(), - visited, is_constant); - ConstantVisitor(session_computation, - batch_norm_inference_request.variance(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.operand(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.scale(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.offset(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.mean(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_inference_request.variance(), + num_parameters, visited, is_functional); break; } case OpRequest::kBatchNormGradRequest: { const BatchNormGradRequest& batch_norm_grad_request = request.request().batch_norm_grad_request(); - ConstantVisitor(session_computation, batch_norm_grad_request.operand(), - visited, is_constant); - ConstantVisitor(session_computation, batch_norm_grad_request.scale(), - visited, is_constant); - ConstantVisitor(session_computation, batch_norm_grad_request.mean(), - visited, is_constant); - ConstantVisitor(session_computation, batch_norm_grad_request.variance(), - visited, is_constant); - ConstantVisitor(session_computation, - batch_norm_grad_request.grad_output(), visited, - is_constant); + PureFunctionalVisitor(session_computation, + batch_norm_grad_request.operand(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_grad_request.scale(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, batch_norm_grad_request.mean(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_grad_request.variance(), num_parameters, + visited, is_functional); + PureFunctionalVisitor(session_computation, + batch_norm_grad_request.grad_output(), + num_parameters, visited, is_functional); break; } case OpRequest::kBinaryOpRequest: { const BinaryOpRequest& binary_op_request = request.request().binary_op_request(); - ConstantVisitor(session_computation, binary_op_request.lhs(), visited, - is_constant); - ConstantVisitor(session_computation, binary_op_request.rhs(), visited, - is_constant); + PureFunctionalVisitor(session_computation, binary_op_request.lhs(), + num_parameters, visited, is_functional); + PureFunctionalVisitor(session_computation, binary_op_request.rhs(), + num_parameters, visited, is_functional); break; } @@ -1806,8 +1839,8 @@ void ConstantVisitor(const SessionComputation& session_computation, } // namespace -StatusOr UserComputation::IsConstant( - const ComputationDataHandle& handle) { +StatusOr UserComputation::IsConstant(const ComputationDataHandle& handle, + int64 num_parameters) { tensorflow::mutex_lock lock(mutex_); // Verify that the handle is valid. @@ -1818,7 +1851,8 @@ StatusOr UserComputation::IsConstant( bool is_constant = true; std::set visited; - ConstantVisitor(session_computation_, handle, &visited, &is_constant); + PureFunctionalVisitor(session_computation_, handle, num_parameters, &visited, + &is_constant); return is_constant; } @@ -1836,10 +1870,17 @@ UserComputation::GetEmbeddedComputations( XLA_VLOG_LINES(3, session_computation_.DebugString()); std::vector computations; + std::vector sorted_handles; for (const auto& handle_request : session_computation_.requests()) { - int64 handle_value = handle_request.first; + sorted_handles.push_back(handle_request.first); + } + std::sort(sorted_handles.begin(), sorted_handles.end()); + for (int64 handle : sorted_handles) { + const auto& handle_request = session_computation_.requests().find(handle); + CHECK(handle_request != session_computation_.requests().end()); + int64 handle_value = handle_request->first; if (handle_value <= version) { - const OperationRequest& request = handle_request.second; + const OperationRequest& request = handle_request->second; switch (request.request().op_case()) { case OpRequest::kCallRequest: { CHECK_EQ(1, request.embedded_computation_versions_size()); @@ -2495,8 +2536,12 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( operand->shape().element_type(), AsInt64Slice(output_shape.dimensions())); // Do explicit broadcast for scalar. if (ShapeUtil::IsScalar(operand->shape())) { - return hlo_builder_.AddInstruction( + HloInstruction* broadcast = hlo_builder_.AddInstruction( HloInstruction::CreateBroadcast(broadcast_shape, operand, {})); + if (operand->has_sharding()) { + broadcast->set_sharding(operand->sharding()); + } + return broadcast; } // Do explicit broadcast for degenerate broadcast. std::vector broadcast_dimensions; @@ -2513,9 +2558,17 @@ HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast( ShapeUtil::MakeShape(operand->shape().element_type(), reshaped_dimensions), operand)); + if (operand->has_sharding()) { + reshaped_operand->set_sharding(operand->sharding()); + } // Broadcast 'reshape' up to the larger size. - return hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( - broadcast_shape, reshaped_operand, broadcast_dimensions)); + HloInstruction* broadcast = + hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast( + broadcast_shape, reshaped_operand, broadcast_dimensions)); + if (operand->has_sharding()) { + broadcast->set_sharding(operand->sharding()); + } + return broadcast; } void ComputationLowerer::Visit( @@ -2529,8 +2582,11 @@ void ComputationLowerer::Visit( HloInstruction* hlo_instruction = hlo_builder_.AddInstruction(std::move(instruction)); hlo_instruction->set_metadata(request.request().metadata()); - hlo_instruction->set_device_assignment( - request.request().device_assignment()); + if (request.request().has_sharding()) { + OpSharding op_sharding = request.request().sharding(); + hlo_instruction->set_sharding( + HloSharding::FromProto(op_sharding).ValueOrDie()); + } return hlo_instruction; }; auto lookup_instruction = [&](const ComputationDataHandle& handle) { @@ -2983,10 +3039,10 @@ void ComputationLowerer::Visit( HloInstruction* lhs = lookup_instruction(binary_op_request.lhs()); HloInstruction* rhs = lookup_instruction(binary_op_request.rhs()); auto hlo_opcode = BinaryOperationToHloOpcode(binary_op_request.binop()); - if (binary_op_request.broadcast_dimensions_size() > 0) { + if (binary_op_request.broadcast_dimensions_size() > 0 && + ShapeUtil::Rank(lhs->shape()) != ShapeUtil::Rank(rhs->shape())) { // Emit a broadcast instruction to perform the "broadcast in dimension" // operation. - CHECK_NE(ShapeUtil::Rank(lhs->shape()), ShapeUtil::Rank(rhs->shape())); HloInstruction* operand_to_broadcast = ShapeUtil::Rank(lhs->shape()) < ShapeUtil::Rank(rhs->shape()) ? lhs : rhs; diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index 6f3bf430fc948732bd771ac3efb60ac9791076d2..ac879ce55a75f6241a39f935b79017be46c1816b 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -250,9 +250,11 @@ class UserComputation { StatusOr> ComputeProgramShape( VersionedComputationHandle::Version version) const; - // Returns true if the given data handle does not depend on any - // parameters. That is, the value can be computed at compile time. - StatusOr IsConstant(const ComputationDataHandle& handle); + // Returns true if the given data handle does not depend on any parameter with + // index higher then num_parameters. That is, the value can be computed at + // compile time if we know the first num_parameters arguments. + StatusOr IsConstant(const ComputationDataHandle& handle, + int64 num_parameters); // Returns the output shape of the operation indicated by the given handle. StatusOr GetShape(const ComputationDataHandle& handle); @@ -262,8 +264,8 @@ class UserComputation { const OpMetadata& metadata); // Sets the device assignment on the Hlo instruction referenced by 'handle'. - Status SetOpDeviceAssignment(const ComputationDataHandle& handle, - const OpDeviceAssignment& device_assignment); + Status SetOpSharding(const ComputationDataHandle& handle, + const OpSharding& sharding); // Builds a HLO computation from the UserComputation. The parameter "resolver" // is a function which returns a pointer to the HloComputation corresponding diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc index 6b0d6b9e11cd638b8f8a2d6f6be7e5a96b351382..5afaf226ae0cce7e9afc966c6b4adf838aeebc91 100644 --- a/tensorflow/compiler/xla/service/user_computation_test.cc +++ b/tensorflow/compiler/xla/service/user_computation_test.cc @@ -224,6 +224,14 @@ TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) { TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle, computation.AddParameterInstruction(b_request)); + const int64 kDevice = 7; + OpSharding sharding; + sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); + sharding.add_tile_assignment_dimensions(1); + sharding.add_tile_assignment_devices(kDevice); + + TF_EXPECT_OK(computation.SetOpSharding(b_handle, sharding)); + BinaryOpRequest add; add.set_binop(BINOP_ADD); *add.mutable_lhs() = a_handle; @@ -249,11 +257,16 @@ TEST_F(UserComputationTest, CheckImplicitBroadcastToExplicitBroadcast) { // \ / // add EXPECT_EQ(5, hlo_computation->instruction_count()); - EXPECT_THAT(hlo_computation->root_instruction(), op::Add()); - const auto& operands = hlo_computation->root_instruction()->operands(); - ASSERT_EQ(2, operands.size()); - EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kParameter && - operands[1]->opcode() == HloOpcode::kBroadcast); + ASSERT_THAT( + hlo_computation->root_instruction(), + op::Add(op::Parameter(), op::Broadcast(op::Reshape(op::Parameter())))); + + const HloInstruction* broadcast = + hlo_computation->root_instruction()->operand(1); + EXPECT_TRUE(broadcast->has_sharding()); + + const HloInstruction* reshape = broadcast->operand(0); + EXPECT_TRUE(reshape->has_sharding()); } TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) { diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc new file mode 100644 index 0000000000000000000000000000000000000000..65734f91bc6ce5d9fa00dae22544dd1f169d861c --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc @@ -0,0 +1,638 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_simplifier.h" +#include "tensorflow/compiler/xla/service/call_inliner.h" +#include "tensorflow/compiler/xla/service/hlo_evaluator.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/lib/strings/str_util.h" +#include "tensorflow/core/lib/strings/strcat.h" + +namespace xla { + +using tensorflow::gtl::nullopt; +using tensorflow::gtl::optional; + +// Finds and returns the non-constant operand in instr. +// +// CHECK-fails if instr doesn't have exactly one unique non-constant operand. +static const HloInstruction* NonConstantOperand(const HloInstruction* instr) { + const HloInstruction* result = nullptr; + for (const HloInstruction* operand : instr->operands()) { + if (!operand->IsConstant()) { + if (result != nullptr) { + CHECK_EQ(result, operand); + } + result = operand; + } + } + CHECK_NE(result, nullptr); + return result; +} + +// Determines whether the given instruction is a send/recv node, or has a +// subcomputation which contains a send/recv node. +static bool IsOrContainsSendOrRecv(const HloInstruction* instr); + +// Determines whether the given computation contains a send or recv node. +static bool ContainsSendOrRecv(const HloComputation* comp) { + for (const auto* instr : comp->instructions()) { + if (IsOrContainsSendOrRecv(instr)) { + return true; + } + } + return false; +} + +static bool IsOrContainsSendOrRecv(const HloInstruction* instr) { + if (instr->opcode() == HloOpcode::kSend || + instr->opcode() == HloOpcode::kRecv) { + return true; + } + for (const auto& subcomp : instr->called_computations()) { + if (ContainsSendOrRecv(subcomp)) { + return true; + } + } + return false; +} + +// If all of instr's operands are either constants or have the form +// get-tuple-element(gte_operand, N) +// for the same value N, returns N. Otherwise, returns nullopt. +static optional GetGTEOperandIndex(const HloInstruction* instr, + const HloInstruction* gte_operand) { + VLOG(2) << "GetGTEOperandIndex(" << instr->ToString() << ", " + << gte_operand->ToString() << ")"; + optional tuple_idx; + for (const HloInstruction* operand : instr->operands()) { + if (operand->IsConstant()) { + continue; + } + if (operand->opcode() != HloOpcode::kGetTupleElement) { + VLOG(2) << "instr uses something other than gte(gte_operand): " + << operand->ToString(); + return nullopt; + } + if (operand->operand(0) != gte_operand) { + VLOG(2) << "instr has gte whose operand is not gte_operand: " + << operand->ToString(); + return nullopt; + } + if (tuple_idx && tuple_idx != operand->tuple_index()) { + VLOG(2) << "instr has operands with conflicting gte indices, " + << *tuple_idx << " vs " << operand->tuple_index(); + return nullopt; + } + + tuple_idx = operand->tuple_index(); + } + return tuple_idx; +} + +// Tries to get the tuple index of the induction variable of a while loop. +// +// Checks that the loop condition and root both plumb the induction variable +// through the same tuple index, and that they both apply exactly one op to the +// induction variable before deciding whether to do another loop iteration (in +// the loop condition's case) or packing the induction variable into the result +// tuple (in the loop body's case). +// +// Specifically, checks that the loop condition has structure +// +// root = op(constants, get-tuple-elem(param0, N), constants) +// +// and the loop body has the structure +// +// inc = op(constants, get-tuple-elem(param0, N), constants) +// root = tuple(..., inc, ...) // inc is N'th operand of tuple(). +// +// If so, returns N. Otherwise, returns nullopt. +static optional GetLoopInductionVarTupleIdx( + const HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + VLOG(2) << "Finding induction variable for loop " + << while_op->ToShortString(); + + // The while_cond computation should have the form + // + // while_cond_root = + // op(constants, get-tuple-elem(while_cond_param, N), constants). + // + // If it does, set indvar_tuple_idx to N. + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_param = while_cond->parameter_instruction(0); + optional indvar_tuple_idx = + GetGTEOperandIndex(while_cond_root, while_cond_param); + if (!indvar_tuple_idx) { + VLOG(2) << "Induction variable not found in loop condition: " + << while_cond->root_instruction()->ToString(); + return nullopt; + } + + // The while_body computation should have the form + // + // while_body_inc = + // op(constants, get-tuple-elem(while_body_param, N), constants) + // while_body_root = tuple(..., while_body_inc, ...) + // + // where while_body_inc is operand N of while_body_root. + auto* while_body = while_op->while_body(); + auto* while_body_root = while_body->root_instruction(); + if (while_body_root->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While body's root is not a tuple instruction: " + << while_body_root->ToString(); + return nullopt; + } + + auto* while_body_inc = while_body_root->operand(*indvar_tuple_idx); + auto* while_body_param = while_body->parameter_instruction(0); + optional while_body_indvar_tuple_idx = + GetGTEOperandIndex(while_body_inc, while_body_param); + if (!while_body_indvar_tuple_idx) { + VLOG(2) + << "Induction variable not found in while body increment instruction: " + << while_body_inc->ToString(); + return nullopt; + } + if (while_body_indvar_tuple_idx != indvar_tuple_idx) { + VLOG(2) << "Tuple index of induction variable does not match between loop " + "condition (" + << *indvar_tuple_idx << ") and while body (" + << *while_body_indvar_tuple_idx << ")"; + return nullopt; + } + + // Finally, check that the while loop's initial value is a tuple with enough + // elements. + auto* while_init = while_op->operand(0); + if (while_init->opcode() != HloOpcode::kTuple) { + VLOG(2) << "While init expected to be a tuple: " << while_init->ToString(); + return nullopt; + } + + VLOG(2) << "Induction variable's tuple index: " << *indvar_tuple_idx; + return indvar_tuple_idx; +} + +// Tries to determine the number of times the given loop executes. Currently +// simply returns 0, 1, or "can't tell" (nullopt). +static optional GetLoopTripCount(HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + VLOG(2) << "Getting trip count for loop " << while_op->ToString(); + + // The loop's induction variable is found at + // + // get-tuple-elem(comp->parameter_instruction(0), *indvar_tuple_idx), + // + // where comp is while_op->while_body() or while_op->while_condition(). + optional indvar_tuple_idx = GetLoopInductionVarTupleIdx(while_op); + if (!indvar_tuple_idx) { + return nullopt; + } + + VLOG(2) << "Induction variable is at index " << *indvar_tuple_idx + << " in input tuple."; + + // Now that we know the index of the induction variable, we can we can try to + // compute how many times the loop executes. Start by computing the induction + // variable's initial value. + HloEvaluator evaluator; + auto* while_init = while_op->mutable_operand(0); + auto* indvar_init = while_init->mutable_operand(*indvar_tuple_idx); + StatusOr> indvar_init_result = + evaluator.Evaluate(indvar_init); + if (!indvar_init_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable init: " + << indvar_init_result.status(); + return nullopt; + } + + // Evaluates the while loop's condition, returning either "true" (continue + // looping), "false" (stop looping), or nullopt (can't evaluate). + auto evaluate_while_cond = [&](const Literal& indvar) -> optional { + auto* while_cond = while_op->while_condition(); + auto* while_cond_root = while_cond->root_instruction(); + auto* while_cond_indvar = NonConstantOperand(while_cond_root); + StatusOr> result = + evaluator.EvaluateWithSubstitutions(while_cond_root, + {{while_cond_indvar, &indvar}}); + if (!result.ok()) { + VLOG(2) << "Couldn't evaluate while cond: " << result.status(); + return nullopt; + } + return result.ValueOrDie()->GetArraySlice() == + tensorflow::gtl::ArraySlice{true}; + }; + + // The initial value of the induction variable. + const Literal& indvar_iter0_val = *indvar_init_result.ValueOrDie(); + + // Evaluate whether the while condition is true when seeded with + // indvar_iter0_val. + optional while_cond_iter0_val = evaluate_while_cond(indvar_iter0_val); + if (while_cond_iter0_val == false) { + VLOG(2) << "Loop has static trip count of 0."; + return 0; + } + + // Calculate the value of the induction variable after one iteration of the + // loop, and check whether the while condition is true with this new value. + auto* while_body = while_op->while_body(); + auto* while_body_indvar_update = + while_body->root_instruction()->operand(*indvar_tuple_idx); + auto* while_body_indvar = NonConstantOperand(while_body_indvar_update); + StatusOr> indvar_iter1_result = + evaluator.EvaluateWithSubstitutions( + while_body_indvar_update, {{while_body_indvar, &indvar_iter0_val}}); + if (!indvar_iter1_result.ok()) { + VLOG(2) << "Couldn't evaluate induction variable update: " + << indvar_iter1_result.status(); + return nullopt; + } + const Literal& indvar_iter1_val = *indvar_iter1_result.ValueOrDie(); + optional while_cond_iter1_val = evaluate_while_cond(indvar_iter1_val); + if (while_cond_iter1_val == false) { + VLOG(2) << "Determined that loop has static trip count of 1."; + return 1; + } + + VLOG(2) << "Loop has unknown trip count >= 1."; + return nullopt; +} + +// Tries to remove elements in a while loop's tuple that aren't used within the +// loop. +// +// Specifically, if a loop is tuple-shaped, and there exists some element of +// that tuple that is not used by the loop condition and is not used by the loop +// body except to pass it to the next iteration of the loop, then we can remove +// that element from the loop's tuples. +static StatusOr TryRemoveDeadWhileParams(HloInstruction* while_op) { + CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); + + // Don't try this transformation if the while loop isn't removable, since if + // it succeeds ultimately we're going to have to replace the old while loop + // with a new one. + if (!while_op->parent()->IsRemovable(while_op)) { + VLOG(2) << "Can't remove dead parameters from non-removable while op."; + return false; + } + + HloModule* module = while_op->GetModule(); + HloComputation* computation = while_op->parent(); + HloInstruction* while_init = while_op->mutable_operand(0); + HloComputation* while_cond = while_op->while_condition(); + HloComputation* while_body = while_op->while_body(); + HloInstruction* while_body_root = while_body->root_instruction(); + + if (!ShapeUtil::IsTuple(while_init->shape())) { + VLOG(2) << "While op's carried value isn't tuple shaped."; + return false; + } + + // Bail if param0 of while_cond or while_body has users which aren't of type + // get-tuple-element. + for (const HloInstruction* instr : {while_body->parameter_instruction(0), + while_cond->parameter_instruction(0)}) { + for (const HloInstruction* user : instr->users()) { + if (user->opcode() != HloOpcode::kGetTupleElement) { + VLOG(2) << "Cowardly refusing to analyze while loop with " + << instr->ToStringNoMetadata() + << " used by non-GTE instruction " << user->ToStringNoMetadata() + << " in computation " << instr->parent()->name(); + return false; + } + } + } + + const int64 tuple_size = ShapeUtil::TupleElementCount(while_init->shape()); + if (tuple_size == 0) { + VLOG(2) << "Can't remove elements from while loop's tuple -- it's already " + "empty."; + return false; + } + + tensorflow::gtl::FlatSet used_tuple_indices; + for (HloComputation* comp : {while_body, while_cond}) { + // The HLO verifier ensures that while_input's shape matches while_init's + // shape, which we verified above is a tuple. + HloInstruction* while_input = comp->parameter_instruction(0); + + for (const HloInstruction* user : while_input->users()) { + // This user doesn't count if it's only used by the while body's root, and + // the root places the tuple element into the same index of the tuple as + // it came from. That just amounts to us carrying the variable through + // the loop. + // + // Careful: HloInstruction::operand_index returns the first index the + // operand appears in, but it may appear more than once! + if (user->user_count() == 1 && user->users()[0] == while_body_root && + while_body_root->operand_index(user) == user->tuple_index() && + std::count(while_body_root->operands().begin(), + while_body_root->operands().end(), user) == 1) { + continue; + } + + used_tuple_indices.insert(user->tuple_index()); + if (used_tuple_indices.size() == tuple_size) { + VLOG(2) << "Loop " << while_op->ToStringNoMetadata() + << " uses all of its inputs; no simplification possible."; + return false; + } + } + } + + // If a tuple element is not passed unmodified from the while body's param0 + // through to the while body's root, count that element as "used", since + // removing that element would be observable. + for (int64 i = 0; i < while_body_root->operand_count(); ++i) { + if (used_tuple_indices.count(i)) { + continue; + } + + auto* operand = while_body_root->operand(i); + if (operand->opcode() != HloOpcode::kGetTupleElement || + operand->operand(0) != while_body->parameter_instruction(0) || + operand->tuple_index() != i) { + VLOG(2) << "Tuple index " << i + << " is not passed through loop body unmodified."; + used_tuple_indices.insert(i); + + if (used_tuple_indices.size() == tuple_size) { + VLOG(2) << "Loop " << while_op->ToStringNoMetadata() + << " uses all of its inputs; no simplification possible."; + return false; + } + } + } + + // If we got here, used_tuple_indices.size() < tuple_size, meaning some + // elements of the loop's tuple aren't used by while_body or while_cond. + CHECK_LT(used_tuple_indices.size(), tuple_size); + + VLOG(1) << "Eliminating " << tuple_size - used_tuple_indices.size() + << " elements from tuple of " << while_op->ToStringNoMetadata(); + + // Build up maps from the old/new to the new/old tuple indices. + std::vector new_to_old_tuple_idx(used_tuple_indices.begin(), + used_tuple_indices.end()); + std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end()); + + tensorflow::gtl::FlatMap old_to_new_tuple_idx; + for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) { + int64 old_idx = new_to_old_tuple_idx[new_idx]; + old_to_new_tuple_idx[old_idx] = new_idx; + VLOG(2) << "Remapping tuple index " << old_idx << " to " << new_idx; + } + + // Compute the shape of the while op after we remove the dead indices. + std::vector new_while_tuple_elem_shapes; + for (int64 old_idx : new_to_old_tuple_idx) { + new_while_tuple_elem_shapes.push_back( + while_init->shape().tuple_shapes(old_idx)); + } + Shape new_while_shape = + ShapeUtil::MakeTupleShape(new_while_tuple_elem_shapes); + + // Returns a map from elements in the computation to new instructions which + // replace the old instructions after we remove unused elements from the while + // tuple. + auto make_while_computation_replacements = [&](const HloComputation* comp) { + std::unordered_map> + replacements; + + auto* param = comp->parameter_instruction(0); + replacements.emplace(param, HloInstruction::CreateParameter( + 0, new_while_shape, param->name())); + + // Materialize param's users, since we're about to add new ones below. + std::vector materialized_users(param->users().begin(), + param->users().end()); + for (const auto* user : materialized_users) { + // The while body root is handled separately. + if (user == while_body_root) { + continue; + } + CHECK_EQ(user->opcode(), HloOpcode::kGetTupleElement) + << user->ToStringNoMetadata(); + + int64 old_idx = user->tuple_index(); + auto new_idx_iter = old_to_new_tuple_idx.find(old_idx); + if (new_idx_iter != old_to_new_tuple_idx.end()) { + // This is a GTE of an index that survives. Replace it. + replacements.emplace( + user, HloInstruction::CreateGetTupleElement(user->shape(), param, + new_idx_iter->second)); + } else { + // This is a GTE of an index that we've removed. Remove it from the + // cloned computation. + CHECK(user->user_count() == 0 || + user->user_count() == 1 && user->users()[0] == while_body_root) + << "Instruction " << user->ToStringNoMetadata() + << " should be unused (except by root of while body), but has " + "users: {" + << tensorflow::str_util::Join( + user->users(), ", ", + [](string* out, const HloInstruction* instr) { + tensorflow::strings::StrAppend( + out, instr->ToStringNoMetadata()); + }) + << "}"; + + replacements.emplace(user, nullptr); + } + } + return replacements; + }; + + // Create the new while condition, body, and init value. + std::unique_ptr new_while_cond = + while_cond->CloneWithReplacements( + make_while_computation_replacements(while_cond)); + + std::unordered_map> + while_body_replacements = make_while_computation_replacements(while_body); + std::vector new_while_body_root_elems; + for (int64 old_idx : new_to_old_tuple_idx) { + new_while_body_root_elems.push_back( + while_body_root->mutable_operand(old_idx)); + } + while_body_replacements.emplace( + while_body_root, HloInstruction::CreateTuple(new_while_body_root_elems)); + std::unique_ptr new_while_body = + while_body->CloneWithReplacements(std::move(while_body_replacements)); + + // Add a new while_init instruction that repackages the old while_init + // instruction's elements. We rely on the AlgebraicSimplifier and DCE to + // clean this up in the common case where while_init is a tuple op. (It's + // definitely tuple-shaped, but it's not necessarily a tuple op.) + std::vector new_while_init_elems; + for (int64 old_idx : new_to_old_tuple_idx) { + new_while_init_elems.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + while_init->shape().tuple_shapes(old_idx), while_init, old_idx))); + } + auto* new_while_init = computation->AddInstruction( + HloInstruction::CreateTuple(new_while_init_elems)); + + // Create the new while op. + auto* new_while_op = computation->AddInstruction(HloInstruction::CreateWhile( + new_while_shape, + module->AddEmbeddedComputation(std::move(new_while_cond)), + module->AddEmbeddedComputation(std::move(new_while_body)), + new_while_init)); + + // Create a tuple op that recreates the output of the old while op. That is, + // we transform to + // + // new_while_init while_init + // | | + // V | + // new_while | + // | | + // -------| |---- + // V V + // new_tuple + // | + // V + // (orig. users of while op) + // + // The tuple simplifier will then simplify this if possible, removing + // new_tuple and while_init. + std::vector new_tuple_elems; + for (int64 old_idx = 0; old_idx < tuple_size; ++old_idx) { + auto new_tuple_idx_it = old_to_new_tuple_idx.find(old_idx); + if (new_tuple_idx_it != old_to_new_tuple_idx.end()) { + int64 gte_idx = new_tuple_idx_it->second; + new_tuple_elems.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + new_while_op->shape().tuple_shapes(gte_idx), new_while_op, + gte_idx))); + } else { + new_tuple_elems.push_back( + computation->AddInstruction(HloInstruction::CreateGetTupleElement( + while_init->shape().tuple_shapes(old_idx), while_init, old_idx))); + } + } + HloInstruction* new_tuple = + computation->AddInstruction(HloInstruction::CreateTuple(new_tuple_elems)); + TF_RETURN_IF_ERROR(while_op->ReplaceAllUsesWith(new_tuple)); + + return true; +} + +// Tries to remove a while loop from the graph. +// +// - Loops with trip count of 0 can be replaced by the loop's "init" value. +// - Loops with trip count of 1 can be replaced by the loop's body, with the +// loop itself removed. +// +// Returns true if it made a change to the graph. +static StatusOr TryRemoveWhileLoop(HloInstruction* while_op) { + // Cowardly refuse to remove loops that are not removable. In practice, + // this means that we can't remove loops that contain side-effecting + // instructions or have control predecessors/successors. + // + // This is not a fundamental limitation. The control operands can be moved + // onto the new HLOs after simplification, and any side-effecting ops inside + // the loop aren't removed, just cloned and added back to the loop. + // Nevertheless our infrastructure sees loop simplification as removal of + // these nodes and currently doesn't allow it. + if (!while_op->parent()->IsRemovable(while_op)) { + VLOG(2) << "Not attempting to remove while loop it is not removable: " + << while_op->ToShortString(); + return false; + } + + // Remove while loops with static trip count of 0. + optional trip_count = GetLoopTripCount(while_op); + if (trip_count && *trip_count == 0) { + // The loop never executes, so the value of the loop is the value of its + // "init" operand. + auto computation = while_op->parent(); + + // Remove while_op (i.e., call ReplaceInstruction rather than + // ReplaceUsesWithInstruction) so that if the algebraic simplifier is run in + // a loop without an intervening DCE, we don't try to re-remove the loop. + TF_RETURN_IF_ERROR(computation->ReplaceInstruction( + while_op, while_op->mutable_operand(0))); + return true; + } + + // Transform while loops with static trip count of 1 into a call op, then + // inline the call. + if (trip_count && *trip_count == 1) { + auto computation = while_op->parent(); + auto call_op = computation->AddInstruction(HloInstruction::CreateCall( + while_op->shape(), while_op->operands(), while_op->while_body())); + TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, call_op)); + TF_RETURN_IF_ERROR(CallInliner::Inline(call_op)); + return true; + } + return false; +} + +StatusOr WhileLoopSimplifier::Run(HloModule* module) { + XLA_VLOG_LINES(3, + "WhileLoopSimplifier::Run(), before:\n" + module->ToString()); + bool changed = false; + + // Gather all the while ops in our module. We do this ahead of time so we + // don't have to worry about mutating the lists of computations or + // instructions while we iterate. + std::vector while_ops; + for (auto* comp : module->computations()) { + for (auto* instr : comp->instructions()) { + if (instr->opcode() == HloOpcode::kWhile) { + while_ops.push_back(instr); + } + } + } + + for (HloInstruction* while_op : while_ops) { + // We can't remove while loops that contain send/recv nodes, because we rely + // on the particular loop structure around the node matching on the send and + // recv sides. Removing dead while params requires us to remove the loop + // and replace it with a new one, so we can't do that either. + if (ContainsSendOrRecv(while_op->while_body()) || + ContainsSendOrRecv(while_op->while_condition())) { + VLOG(2) << "Not attempting to simplify while loop because it contains a " + "send/recv node: " + << while_op->ToShortString(); + continue; + } + + StatusOr result = TryRemoveWhileLoop(while_op); + TF_RETURN_IF_ERROR(result.status()); + if (result.ValueOrDie()) { + changed = true; + // Don't try to remove dead while params after successfully removing the + // while loop -- that would result in use-after-free nastiness. + continue; + } + + result = TryRemoveDeadWhileParams(while_op); + TF_RETURN_IF_ERROR(result.status()); + changed |= result.ValueOrDie(); + } + + XLA_VLOG_LINES(3, + "WhileLoopSimplifier::Run(), after:\n" + module->ToString()); + return changed; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.h b/tensorflow/compiler/xla/service/while_loop_simplifier.h new file mode 100644 index 0000000000000000000000000000000000000000..50dac32a4ab0a5de756c1ddf5e62c3560e54a079 --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_simplifier.h @@ -0,0 +1,44 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ +#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_pass_interface.h" +#include "tensorflow/compiler/xla/statusor.h" + +namespace xla { + +// HLO pass that makes the following transformations on while loops: +// +// - A while loop with static trip count of 0 is deleted. +// - A while loops with static trip count of 1 is replaced by its body (sans +// loop). +// - Elements of a while loop's tuple that the loop doesn't use are removed +// from the tuple. +// +class WhileLoopSimplifier : public HloPassInterface { + public: + ~WhileLoopSimplifier() override {} + tensorflow::StringPiece name() const override { + return "simplify-while-loops"; + } + StatusOr Run(HloModule* module) override; +}; + +} // namespace xla + +#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_WHILE_LOOP_SIMPLIFIER_H_ diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..8e1a2dcde129e9a022789eb7b192319901b9db4a --- /dev/null +++ b/tensorflow/compiler/xla/service/while_loop_simplifier_test.cc @@ -0,0 +1,420 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/service/while_loop_simplifier.h" + +#include "tensorflow/compiler/xla/service/hlo_matchers.h" +#include "tensorflow/compiler/xla/test.h" +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" +#include "tensorflow/core/lib/core/status_test_util.h" + +namespace xla { +namespace { + +namespace op = xla::testing::opcode_matchers; + +class WhileLoopSimplifierTest : public HloVerifiedTestBase { + public: + // Makes a computation that contains a loop that runs num_iters times. + HloComputation* MakeSimpleLoop(int num_iters, HloModule* module); + + // Makes a computation which has one parameter, of the given shape, and always + // returns PRED[]{true}. This is useful as a dummy loop condition. + HloComputation* MakeAlwaysTrueComputation(const Shape& param_shape, + HloModule* module); +}; + +HloComputation* WhileLoopSimplifierTest::MakeSimpleLoop(int num_iters, + HloModule* module) { + HloComputation::Builder builder(TestName()); + + auto loop_iter_init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + auto loop_data_init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR1({0, 1, 2}))); + auto loop_init = builder.AddInstruction( + HloInstruction::CreateTuple({loop_iter_init, loop_data_init})); + + HloComputation* condition; + { + HloComputation::Builder cond_builder(TestName() + ".condition"); + auto loop_var = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); + auto loop_induction_var = + cond_builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::MakeShape(S32, {}), loop_var, 0)); + auto limit = cond_builder.AddInstruction(HloInstruction::CreateConstant( + Literal::CreateR0(42 + num_iters))); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, loop_induction_var, + limit)); + condition = module->AddEmbeddedComputation(cond_builder.Build()); + } + + HloComputation* body; + { + HloComputation::Builder body_builder(TestName() + ".body"); + auto loop_var = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); + auto loop_induction_var = + body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( + ShapeUtil::MakeShape(S32, {}), loop_var, 0)); + auto new_loop_induction_var = + body_builder.AddInstruction(HloInstruction::CreateBinary( + loop_induction_var->shape(), HloOpcode::kAdd, loop_induction_var, + body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))))); + auto loop_data = + body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( + loop_data_init->shape(), loop_var, 1)); + auto new_loop_data = + body_builder.AddInstruction(HloInstruction::CreateBinary( + loop_data_init->shape(), HloOpcode::kMultiply, loop_data, + loop_data)); + body_builder.AddInstruction( + HloInstruction::CreateTuple({new_loop_induction_var, new_loop_data})); + body = module->AddEmbeddedComputation(body_builder.Build()); + } + + builder.AddInstruction(HloInstruction::CreateWhile( + loop_init->shape(), condition, body, loop_init)); + + return module->AddEntryComputation(builder.Build()); +} + +HloComputation* WhileLoopSimplifierTest::MakeAlwaysTrueComputation( + const Shape& param_shape, HloModule* module) { + HloComputation::Builder builder(TestName() + ".always_true"); + builder.AddInstruction( + HloInstruction::CreateParameter(0, param_shape, "param")); + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + return module->AddEmbeddedComputation(builder.Build()); +} + +TEST_F(WhileLoopSimplifierTest, WhileLoopWithZeroIterations) { + HloComputation* computation = MakeSimpleLoop(/*num_iters=*/0, &module()); + ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + op::Tuple(op::Constant(), op::Constant())); +} + +TEST_F(WhileLoopSimplifierTest, WhileLoopWithOneIteration) { + HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); + ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction(), + op::Tuple(op::Add(), op::Multiply())); +} + +TEST_F(WhileLoopSimplifierTest, WhileLoopWithTwoIterations) { + MakeSimpleLoop(/*num_iters=*/2, &module()); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(WhileLoopSimplifierTest, WhileLoopWithControlDependency) { + HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); + auto* while_op = computation->root_instruction(); + ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); + auto* true_op = while_op->while_body()->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))); + TF_ASSERT_OK(true_op->AddControlDependencyTo( + while_op->while_body()->root_instruction())); + ASSERT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + EXPECT_THAT(computation->root_instruction()->control_predecessors(), + ElementsAre(op::Constant())) + << computation->ToString(); +} + +// Loops that contain send/recv nodes can't be simplified; the loop structure +// around send/recv nodes must be preserved. +TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsSend) { + HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); + auto* while_op = computation->root_instruction(); + ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); + auto* while_body = while_op->while_body(); + while_body->AddInstruction(HloInstruction::CreateSend( + while_body->AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(true))), + /*channel_id=*/0)); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsRecv) { + HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); + auto* while_op = computation->root_instruction(); + ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); + auto* while_body = while_op->while_body(); + while_body->AddInstruction( + HloInstruction::CreateRecv(ShapeUtil::MakeShape(F32, {1}), + /*channel_id=*/0)); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +// The limitation on not being able to simplify loops that contain infeeds (and +// other non-removable instructions) isn't fundamental -- it just stems from the +// fact that our infrastructure sees simplifying such a loop as tantamount to +// removing the non-removable instruction. +TEST_F(WhileLoopSimplifierTest, NotRemovedIfContainsNonRemovableInstruction) { + HloComputation* computation = MakeSimpleLoop(/*num_iters=*/1, &module()); + auto* while_op = computation->root_instruction(); + ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); + auto* while_body = while_op->while_body(); + while_body->AddInstruction( + HloInstruction::CreateInfeed(ShapeUtil::MakeShape(F32, {1}), "config")); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +// Check that we don't crash when given a loop whose shape is not a tuple. +TEST_F(WhileLoopSimplifierTest, IgnoreNonTupleShapedLoop) { + HloComputation::Builder builder(TestName()); + auto loop_init = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(42))); + + HloComputation* condition; + { + HloComputation::Builder cond_builder(TestName() + ".condition"); + auto param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kLt, param, + cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(100))))); + condition = module().AddEmbeddedComputation(cond_builder.Build()); + } + + HloComputation* body; + { + HloComputation::Builder body_builder(TestName() + ".body"); + auto param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); + body_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(S32, {}), HloOpcode::kAdd, param, + body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(-1))))); + body = module().AddEmbeddedComputation(body_builder.Build()); + } + + builder.AddInstruction(HloInstruction::CreateWhile( + loop_init->shape(), condition, body, loop_init)); + + module().AddEntryComputation(builder.Build()); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +// Construct a loop where we swap the tuple elements in each iteration. +// Although the tuple elements aren't used in the loop, we don't eliminate them, +// because the swapping side-effect is visible to users of the loop. +TEST_F(WhileLoopSimplifierTest, SwapTupleIndices) { + HloComputation::Builder builder(TestName()); + auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))), + })); + + HloComputation* condition = + MakeAlwaysTrueComputation(loop_init->shape(), &module()); + HloComputation* body; + { + HloComputation::Builder body_builder(TestName() + ".body"); + auto param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + body_builder.AddInstruction(HloInstruction::CreateTuple({ + body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 1)), + body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, 0)), + })); + body = module().AddEmbeddedComputation(body_builder.Build()); + } + + builder.AddInstruction(HloInstruction::CreateWhile( + loop_init->shape(), condition, body, loop_init)); + + module().AddEntryComputation(builder.Build()); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +// Construct a loop where we assign a constant to tuple element 0 in each +// iteration. We can't eliminate tuple element 0, even though we never use its +// value. +TEST_F(WhileLoopSimplifierTest, UnusedButModifiedTupleElement) { + HloComputation::Builder builder(TestName()); + auto loop_init = builder.AddInstruction( + HloInstruction::CreateTuple({builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0)))})); + + HloComputation* condition = + MakeAlwaysTrueComputation(loop_init->shape(), &module()); + HloComputation* body; + { + HloComputation::Builder body_builder(TestName() + ".body"); + body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); + body_builder.AddInstruction(HloInstruction::CreateTuple({ + body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))), + })); + body = module().AddEmbeddedComputation(body_builder.Build()); + } + + builder.AddInstruction(HloInstruction::CreateWhile( + loop_init->shape(), condition, body, loop_init)); + + module().AddEntryComputation(builder.Build()); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +// Nothing to simplify in a while loop whose tuple has 0 elements. +TEST_F(WhileLoopSimplifierTest, EmptyTuple) { + HloComputation::Builder builder(TestName()); + auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({})); + + HloComputation* condition = + MakeAlwaysTrueComputation(loop_init->shape(), &module()); + HloComputation* body; + { + HloComputation::Builder body_builder(TestName() + ".body"); + body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_init->shape(), "loop_var")); + body_builder.AddInstruction(HloInstruction::CreateTuple({})); + body = module().AddEmbeddedComputation(body_builder.Build()); + } + + builder.AddInstruction(HloInstruction::CreateWhile( + loop_init->shape(), condition, body, loop_init)); + module().AddEntryComputation(builder.Build()); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +// While loop where one tuple element is used twice in the body, and thus can't +// be simplified away. +TEST_F(WhileLoopSimplifierTest, ElemUsedTwice) { + HloComputation::Builder builder(TestName()); + auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))), + })); + + HloComputation* condition = + MakeAlwaysTrueComputation(loop_init->shape(), &module()); + + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + HloComputation* body; + { + HloComputation::Builder body_builder(TestName() + ".body"); + auto* param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_init->shape(), "param0")); + auto* gte0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/0)); + // get0 is used twice in the loop body's tuple. + body_builder.AddInstruction(HloInstruction::CreateTuple({gte0, gte0})); + body = module().AddEmbeddedComputation(body_builder.Build()); + } + + builder.AddInstruction(HloInstruction::CreateWhile( + loop_init->shape(), condition, body, loop_init)); + module().AddEntryComputation(builder.Build()); + EXPECT_FALSE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); +} + +// This while loop has three tuple elements. Element 0 is unused and should be +// removed. Element 1 is used by the loop body, and element 2 is used by the +// loop condition; these two should stay. +TEST_F(WhileLoopSimplifierTest, RemoveUnusedOperand) { + HloComputation::Builder builder(TestName()); + auto loop_init = builder.AddInstruction(HloInstruction::CreateTuple({ + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + })); + auto loop_shape = loop_init->shape(); + auto scalar_s32 = ShapeUtil::MakeShape(S32, {}); + + HloComputation* condition; + { + HloComputation::Builder cond_builder(TestName() + ".loop_condition"); + auto param = cond_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_shape, "param0")); + cond_builder.AddInstruction(HloInstruction::CreateBinary( + ShapeUtil::MakeShape(PRED, {}), HloOpcode::kEq, + cond_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(0))), + cond_builder.AddInstruction(HloInstruction::CreateGetTupleElement( + scalar_s32, param, /*index=*/2)))); + condition = module().AddEmbeddedComputation(cond_builder.Build()); + } + + HloComputation* body; + { + HloComputation::Builder body_builder(TestName() + ".body"); + auto* param = body_builder.AddInstruction( + HloInstruction::CreateParameter(0, loop_shape, "loop_var")); + + auto* tuple0 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/0)); + auto* tuple1 = body_builder.AddInstruction(HloInstruction::CreateBinary( + scalar_s32, HloOpcode::kAdd, + body_builder.AddInstruction(HloInstruction::CreateGetTupleElement( + scalar_s32, param, /*index=*/1)), + body_builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(1))))); + auto* tuple2 = body_builder.AddInstruction( + HloInstruction::CreateGetTupleElement(scalar_s32, param, /*index=*/2)); + body_builder.AddInstruction( + HloInstruction::CreateTuple({tuple0, tuple1, tuple2})); + + body = module().AddEmbeddedComputation(body_builder.Build()); + } + + auto* while_op = builder.AddInstruction(HloInstruction::CreateWhile( + loop_init->shape(), condition, body, loop_init)); + + module().AddEntryComputation(builder.Build()); + EXPECT_TRUE(WhileLoopSimplifier().Run(&module()).ValueOrDie()); + + // We leave most of the checking to HloVerifiedTestBase, which runs the + // verifier on module() at the end of this test. + HloInstruction* new_while_op = *std::find_if( + module().entry_computation()->instructions().begin(), + module().entry_computation()->instructions().end(), + [&](const HloInstruction* instr) { + return instr != while_op && instr->opcode() == HloOpcode::kWhile; + }); + EXPECT_TRUE( + ShapeUtil::Equal(new_while_op->shape(), + ShapeUtil::MakeTupleShape({scalar_s32, scalar_s32}))) + << ShapeUtil::HumanString(new_while_op->shape()); + EXPECT_THAT( + new_while_op->while_body()->root_instruction(), + op::Tuple( + op::Add(op::GetTupleElement(op::Parameter(0), /*tuple_index=*/0), + op::Constant()), + op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1))); + + EXPECT_THAT(new_while_op->while_condition()->root_instruction(), + op::Eq(op::Constant(), + op::GetTupleElement(op::Parameter(0), /*tuple_index=*/1))); +} + +} // namespace +} // namespace xla diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc index 8e16056b239a9e1d1776bfe91f6e36862e0feeec..b5eb81dfc6a4117909dcb18fdbe61443b1a1eb95 100644 --- a/tensorflow/compiler/xla/shape_util.cc +++ b/tensorflow/compiler/xla/shape_util.cc @@ -102,6 +102,32 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { return true; } +// Constructs and returns the new shape with the given minor_to_major order in +// its Layout. +StatusOr MakeShapeWithLayoutInternal( + PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, + tensorflow::gtl::ArraySlice minor_to_major) { + if (dimensions.size() != minor_to_major.size()) { + return InvalidArgument("Dimensions size is %ld, but layout size is %ld.", + dimensions.size(), minor_to_major.size()); + } + if (element_type == OPAQUE || element_type == TUPLE) { + return InvalidArgument("Unsupported element type: %s", + PrimitiveType_Name(element_type).c_str()); + } + Shape shape = ShapeUtil::MakeShape(element_type, dimensions); + auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); + min2maj->Clear(); + for (int64 value : minor_to_major) { + min2maj->Add(value); + } + if (!shape.has_layout()) { + return InvalidArgument("Shape has no layout."); + } + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(shape)); + return shape; +} + } // namespace /* static */ bool ShapeUtil::Equal(const Shape& lhs, const Shape& rhs) { @@ -152,16 +178,8 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { /* static */ Shape ShapeUtil::MakeShapeWithLayout( PrimitiveType element_type, tensorflow::gtl::ArraySlice dimensions, tensorflow::gtl::ArraySlice minor_to_major) { - CHECK_EQ(dimensions.size(), minor_to_major.size()); - Shape shape = MakeShape(element_type, dimensions); - auto min2maj = shape.mutable_layout()->mutable_minor_to_major(); - min2maj->Clear(); - for (int64 value : minor_to_major) { - min2maj->Add(value); - } - DCHECK(shape.has_layout()); - TF_DCHECK_OK(ValidateShape(shape)); - return shape; + return MakeShapeWithLayoutInternal(element_type, dimensions, minor_to_major) + .ValueOrDie(); } /* static */ Shape ShapeUtil::MakeShapeWithMonotonicDim0MajorLayout( @@ -254,6 +272,7 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { case U16: case U32: case U64: + case C64: case TUPLE: case OPAQUE: return false; @@ -263,6 +282,10 @@ bool CompareShapes(const Shape& lhs, const Shape& rhs, bool compare_layouts) { } } +/* static */ bool ShapeUtil::ElementIsComplex(const Shape& shape) { + return primitive_util::IsComplexType(shape.element_type()); +} + /* static */ bool ShapeUtil::ElementIsFloating(const Shape& shape) { return primitive_util::IsFloatingPointType(shape.element_type()); } @@ -499,11 +522,10 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { // Extract the layout minor-to-major and set it. TF_ASSIGN_OR_RETURN(std::vector min2maj, comma_list_to_int64s(layout_string)); - TF_RET_CHECK(dimensions.size() == min2maj.size()); - result = - ShapeUtil::MakeShapeWithLayout(primitive_type, dimensions, min2maj); + TF_ASSIGN_OR_RETURN(result, MakeShapeWithLayoutInternal( + primitive_type, dimensions, min2maj)); } - TF_DCHECK_OK(ShapeUtil::ValidateShape(result)); + TF_RETURN_IF_ERROR(ShapeUtil::ValidateShape(result)); return std::move(result); } @@ -575,6 +597,8 @@ StatusOr ParseShapeStringInternal(tensorflow::StringPiece* s) { return sizeof(float); case F64: return sizeof(double); + case C64: + return sizeof(complex64); default: LOG(FATAL) << "Unhandled primitive type " << primitive_type; } diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h index 140388f9c067fe5582df0865f0bbc7db4952c31a..8f8d4a73c9ecb3f4236f3877323ad1127bb0b9c2 100644 --- a/tensorflow/compiler/xla/shape_util.h +++ b/tensorflow/compiler/xla/shape_util.h @@ -66,6 +66,8 @@ class ShapeIndex { std::vector::iterator begin() { return indices_.begin(); } std::vector::iterator end() { return indices_.end(); } + const int64* data() const { return indices_.data(); } + const int64& operator[](size_t i) const { return indices_[i]; } int64& operator[](size_t i) { return indices_[i]; } @@ -81,20 +83,20 @@ class ShapeIndex { private: std::vector indices_; - - friend class ShapeIndexView; }; // A view into a ShapeIndex as above, with the cheap/easy ability to consume the // value at the front of the view. +// +// NB! ShapeIndexView does not own the memory backing the index array. +// The memory backing the index array should be owned by an object +// that lives longer than the ShapeIndexView instances pointing into +// it. class ShapeIndexView { public: - ShapeIndexView(const ShapeIndex& shape_index) - : ShapeIndexView(shape_index.indices_.data(), - shape_index.indices_.data() + shape_index.size()) {} - ShapeIndexView(const ShapeIndex& shape_index, int64 offset) - : ShapeIndexView(shape_index.indices_.data() + offset, - shape_index.indices_.data() + shape_index.size()) { + ShapeIndexView(const ShapeIndex& shape_index, int64 offset = 0) + : ShapeIndexView(shape_index.data() + offset, + shape_index.data() + shape_index.size()) { CHECK_LE(offset, shape_index.size()); } ShapeIndexView(std::initializer_list indices) @@ -289,6 +291,9 @@ class ShapeUtil { // Returns whether the element type of the shape is floating point. static bool ElementIsFloating(const Shape& shape); + // Returns whether the element type of the shape is complex. + static bool ElementIsComplex(const Shape& shape); + // Returns whether the element type has the given bit width. static bool ElementHasBitWidth(const Shape& shape, int bits); diff --git a/tensorflow/compiler/xla/shape_util_test.cc b/tensorflow/compiler/xla/shape_util_test.cc index 79945b9c77299b7006d014aed4507566e3c2c750..0ba542ad1bec290c35c52a8dd5177893770310fd 100644 --- a/tensorflow/compiler/xla/shape_util_test.cc +++ b/tensorflow/compiler/xla/shape_util_test.cc @@ -218,6 +218,10 @@ TEST(ShapeUtilTest, ByteSizeOfWithoutPadding) { EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(F64)); EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {}))); EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(F64, {10, 20}))); + + EXPECT_EQ(8, ShapeUtil::ByteSizeOfPrimitiveType(C64)); + EXPECT_EQ(8, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {}))); + EXPECT_EQ(1600, ShapeUtil::ByteSizeOf(ShapeUtil::MakeShape(C64, {10, 20}))); } TEST(ShapeUtilTest, ByteSizeOfWithPadding) { diff --git a/tensorflow/compiler/xla/test_helpers.h b/tensorflow/compiler/xla/test_helpers.h index 634cdb5aa29651b08090ff99f0a6cafb9facb645..17bae2e4f611268df824ce793c75ba1c95573455 100644 --- a/tensorflow/compiler/xla/test_helpers.h +++ b/tensorflow/compiler/xla/test_helpers.h @@ -62,9 +62,16 @@ inline const ::tensorflow::Status& GetStatus(const StatusOr& status) { #define EXPECT_IS_OK(expression) \ EXPECT_EQ(tensorflow::Status::OK(), \ xla::testing::internal_status::GetStatus(expression)) +#define EXPECT_IS_NOT_OK(expression) \ + EXPECT_NE(tensorflow::Status::OK(), \ + xla::testing::internal_status::GetStatus(expression)) #undef ASSERT_IS_OK #define ASSERT_IS_OK(expression) \ ASSERT_EQ(tensorflow::Status::OK(), \ xla::testing::internal_status::GetStatus(expression)) +#undef ASSERT_IS_NOT_OK +#define ASSERT_IS_NOT_OK(expression) \ + ASSERT_NE(tensorflow::Status::OK(), \ + xla::testing::internal_status::GetStatus(expression)) #endif // TENSORFLOW_COMPILER_XLA_TEST_HELPERS_H_ diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD index e45b839afd2a9666215744f904dfbed5eca0a41b..4e1be24b61cc436b0baf62cc6e28ad8d13fe71ac 100644 --- a/tensorflow/compiler/xla/tests/BUILD +++ b/tensorflow/compiler/xla/tests/BUILD @@ -23,7 +23,6 @@ filegroup( ]), ) -load("//tensorflow/compiler/xla:xla.bzl", "export_dynamic_linkopts") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "xla_test_library") load("//tensorflow/compiler/xla/tests:build_defs.bzl", "generate_backend_suites") @@ -102,28 +101,34 @@ cc_library( deps = [ ":literal_test_util", "//tensorflow/compiler/xla:shape_layout", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:statusor", "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", - "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:backend", - "//tensorflow/compiler/xla/service:compiler", "//tensorflow/compiler/xla/service:computation_layout", - "//tensorflow/compiler/xla/service:computation_placer", - "//tensorflow/compiler/xla/service:executable", "//tensorflow/compiler/xla/service:hlo", - "//tensorflow/compiler/xla/service:hlo_execution_profile", - "//tensorflow/compiler/xla/service:hlo_graph_dumper", - "//tensorflow/compiler/xla/service:transfer_manager", - "//tensorflow/core:core_cpu_internal", + "//tensorflow/compiler/xla/service:hlo_runner", "//tensorflow/core:lib", "//tensorflow/core:stream_executor_no_cuda", "//tensorflow/core:test", - "//third_party/eigen3", + ], +) + +cc_library( + name = "hlo_verified_test_base", + testonly = True, + srcs = ["hlo_verified_test_base.cc"], + hdrs = ["hlo_verified_test_base.h"], + deps = [ + ":hlo_test_base", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:status_macros", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:hlo_verifier", + "//tensorflow/core:lib", + "//tensorflow/core:test", ], ) @@ -373,6 +378,7 @@ xla_test( name = "params_test", srcs = ["params_test.cc"], shard_count = 30, + tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:array2d", "//tensorflow/compiler/xla:literal_util", @@ -914,6 +920,7 @@ xla_test( name = "reduce_window_test", timeout = "long", srcs = [], + tags = ["optonly"], xla_test_library_deps = [":reduce_window_test_library"], deps = [], ) @@ -981,13 +988,13 @@ xla_test( xla_test( name = "custom_call_test", srcs = ["custom_call_test.cc"], - linkopts = export_dynamic_linkopts, deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service/cpu:custom_call_target_registry", "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", @@ -1394,8 +1401,10 @@ xla_test( "//tensorflow/compiler/xla/tests:hlo_test_base", "//tensorflow/compiler/xla/tests:literal_test_util", "//tensorflow/compiler/xla/tests:xla_internal_test_main", + "//tensorflow/core:core_cpu_internal", "//tensorflow/core:lib", "//tensorflow/core:test", + "//third_party/eigen3", ], ) @@ -1461,6 +1470,7 @@ xla_test( xla_test( name = "local_client_execute_test", srcs = ["local_client_execute_test.cc"], + tags = ["optonly"], deps = [ "//tensorflow/compiler/xla:literal_util", "//tensorflow/compiler/xla:shape_util", diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 532e2394c0d727d77ec0e4ed23f81fdc34a950a6..0b700fbb6ffbde147c71b76d37f334a53c91f2fd 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -496,58 +496,315 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulTwoConstantU32s) { ComputeAndCompareR1(&builder, expected, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalAnd) { +XLA_TEST_F(ArrayElementwiseOpTest, AndPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({false, false, true, true}); auto b = builder.ConstantR1({false, true, false, true}); - auto out = builder.LogicalAnd(a, b); + auto out = builder.And(a, b); ComputeAndCompareR1(&builder, {false, false, false, true}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalAndZeroElement) { +XLA_TEST_F(ArrayElementwiseOpTest, AndPredR2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{false, false}, {true, true}}); + auto b = builder.ConstantR2({{false, true}, {false, true}}); + auto out = builder.And(a, b); + + Array2D expected_array({{false, false}, {false, true}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.LogicalAnd(a, b); + auto out = builder.And(a, b); ComputeAndCompareR1(&builder, {}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalOr) { +XLA_TEST_F(ArrayElementwiseOpTest, AndS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0, -1, -8}); + auto b = builder.ConstantR1({5, -7, 12}); + auto out = builder.And(a, b); + + ComputeAndCompareR1(&builder, {0, -7, 8}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndS32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{0, -5}, {-1, 5}}); + auto b = builder.ConstantR2({{1, -6}, {4, 5}}); + auto out = builder.And(a, b); + + Array2D expected_array({{0, -6}, {4, 5}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto out = builder.And(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0, 1, 8}); + auto b = builder.ConstantR1({5, 7, 12}); + auto out = builder.And(a, b); + + ComputeAndCompareR1(&builder, {0, 1, 8}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndU32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{0, 1}, {3, 8}}); + auto b = builder.ConstantR2({{1, 0}, {7, 6}}); + auto out = builder.And(a, b); + + Array2D expected_array({{0, 0}, {3, 0}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, AndZeroElementU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto out = builder.And(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({false, false, true, true}); auto b = builder.ConstantR1({false, true, false, true}); - auto out = builder.LogicalOr(a, b); + auto out = builder.Or(a, b); ComputeAndCompareR1(&builder, {false, true, true, true}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalOrZeroElement) { +XLA_TEST_F(ArrayElementwiseOpTest, OrPredR2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{false, false}, {true, true}}); + auto b = builder.ConstantR2({{false, true}, {false, true}}); + auto out = builder.Or(a, b); + + Array2D expected_array({{false, true}, {true, true}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({}); auto b = builder.ConstantR1({}); - auto out = builder.LogicalOr(a, b); + auto out = builder.Or(a, b); ComputeAndCompareR1(&builder, {}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalNot) { +XLA_TEST_F(ArrayElementwiseOpTest, OrS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0, -1, 8}); + auto b = builder.ConstantR1({5, -7, 4}); + auto out = builder.Or(a, b); + + ComputeAndCompareR1(&builder, {5, -1, 12}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrS32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{0, -1}, {8, 8}}); + auto b = builder.ConstantR2({{5, -7}, {4, 1}}); + auto out = builder.Or(a, b); + + Array2D expected_array({{5, -1}, {12, 9}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto out = builder.Or(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0, 1, 8}); + auto b = builder.ConstantR1({5, 7, 4}); + auto out = builder.Or(a, b); + + ComputeAndCompareR1(&builder, {5, 7, 12}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrU32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{0, 1}, {8, 8}}); + auto b = builder.ConstantR2({{5, 7}, {4, 1}}); + auto out = builder.Or(a, b); + + Array2D expected_array({{5, 7}, {12, 9}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, OrZeroElementU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto b = builder.ConstantR1({}); + auto out = builder.Or(a, b); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({false, true, true, false}); - auto out = builder.LogicalNot(a); + auto out = builder.Not(a); ComputeAndCompareR1(&builder, {true, false, false, true}, {}); } -XLA_TEST_F(ArrayElementwiseOpTest, LogicalNotZeroElement) { +XLA_TEST_F(ArrayElementwiseOpTest, NotPredR2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{false, true}, {true, false}}); + auto out = builder.Not(a); + + Array2D expected_array({{true, false}, {false, true}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementPredR1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR1({}); - auto out = builder.LogicalNot(a); + auto out = builder.Not(a); ComputeAndCompareR1(&builder, {}, {}); } +XLA_TEST_F(ArrayElementwiseOpTest, NotS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({-1, 0, 1}); + auto out = builder.Not(a); + + ComputeAndCompareR1(&builder, {0, -1, -2}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotS32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{-1, 0}, {1, 8}}); + auto out = builder.Not(a); + + Array2D expected_array({{0, -1}, {-2, -9}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementS32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto out = builder.Not(a); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0, 4294967295}); + auto out = builder.Not(a); + + ComputeAndCompareR1(&builder, {4294967295, 0}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotU32R2) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR2({{0, 4294967295}, {1, 4294967294}}); + auto out = builder.Not(a); + + Array2D expected_array({{4294967295, 0}, {4294967294, 1}}); + ComputeAndCompareR2(&builder, expected_array, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NotZeroElementU32R1) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({}); + auto out = builder.Not(a); + + ComputeAndCompareR1(&builder, {}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftS32) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR1({static_cast(0x12345678), + static_cast(0xF0001000), 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 15}); + auto out = builder.ShiftLeft(a, b); + + ComputeAndCompareR1( + &builder, + {static_cast(0x23456780), 0x00100000, 0x4, 0x180, 2523136}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticS32) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR1({static_cast(0x92345678), + static_cast(0x10001000), 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 2}); + auto out = builder.ShiftRightArithmetic(a, b); + + ComputeAndCompareR1(&builder, + {static_cast(0xF9234567), + static_cast(0x00100010), 0, 0, 19}, + {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalS32) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR1({static_cast(0x92345678), + static_cast(0x10001000), 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 5}); + auto out = builder.ShiftRightLogical(a, b); + + ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftLeftU32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0x12345678, 0xF0001000, 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 15}); + auto out = builder.ShiftLeft(a, b); + + ComputeAndCompareR1( + &builder, {0x23456780, 0x00100000, 0x4, 0x180, 2523136}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightArithmeticU32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0x92345678, 0x10001000, 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 2}); + auto out = builder.ShiftRightArithmetic(a, b); + + ComputeAndCompareR1(&builder, {0xF9234567, 0x00100010, 0, 0, 19}, {}); +} + +XLA_TEST_F(ArrayElementwiseOpTest, ShiftRightLogicalU32) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({0x92345678, 0x10001000, 1, 3, 77}); + auto b = builder.ConstantR1({4, 8, 2, 7, 5}); + auto out = builder.ShiftRightLogical(a, b); + + ComputeAndCompareR1(&builder, {0x09234567, 0x00100010, 0, 0, 2}, {}); +} + XLA_TEST_F(ArrayElementwiseOpTest, CompareEqF32s) { SetFastMathDisabled(true); ComputationBuilder builder(client_, TestName()); @@ -1770,7 +2027,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ne) { const string expected = R"(pred[2,2] { { 00 }, - { 01 }, + { 01 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -1784,7 +2041,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Ge) { const string expected = R"(pred[2,4] { { 1100 }, - { 0001 }, + { 0001 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -1798,7 +2055,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Gt) { const string expected = R"(pred[2,4] { { 0100 }, - { 0000 }, + { 0000 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -1812,7 +2069,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Le) { const string expected = R"(pred[2,4] { { 1011 }, - { 1111 }, + { 1111 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -1826,7 +2083,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Lt) { const string expected = R"(pred[2,4] { { 0011 }, - { 1110 }, + { 1110 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -2142,6 +2399,33 @@ XLA_TEST_F(ArrayElementwiseOpTest, CannotAddOpaques) { "Expected non-opaque argument for lhs of binary operation")); } +XLA_TEST_F(ArrayElementwiseOpTest, IdentityBroadcastOfSameRankIsAllowed) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto b = + builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); + auto add = builder.Add(a, b, /*broadcast_dimensions=*/{0, 1}); + + Array2D expected_array( + {{-4.0f, 11.28f, 43.0f}, {1.25f, -14.0f, 8.88f}}); + ComputeAndCompareR2(&builder, expected_array, {}, error_spec_); +} + +XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) { + ComputationBuilder builder(client_, TestName()); + auto a = + builder.ConstantR2({{-2.5f, 3.14f, 1.0f}, {2.25f, -10.0f, 3.33f}}); + auto b = + builder.ConstantR2({{-1.5f, 8.14f, 42.0}, {-1.0f, -4.0f, 5.55f}}); + auto add = builder.Add(a, b, /*broadcast_dimensions=*/{1, 0}); + + StatusOr computation_status = builder.Build(); + ASSERT_FALSE(computation_status.ok()); + EXPECT_THAT(computation_status.status().error_message(), + ::testing::ContainsRegex("must.*be the identity")); +} + // Regression test for b/31927799. "slice - y" is fused and requires implicit // broadcast. XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) { diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc index 4f26bf47ae6d29f525351692612648d6432f9518..03f5e08315bfed2bcb43ebb7098aaa0b97228605 100644 --- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc +++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc @@ -96,7 +96,7 @@ class BroadcastSimpleTest : public ClientLibraryTestBase { } default: { // Default to Add - CHECK(false); + LOG(FATAL); } } } @@ -159,7 +159,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) { } // Tests implicit broadcasting of PREDs. -XLA_TEST_F(BroadcastSimpleTest, LogicalAnd2DTo3D_Pred) { +XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) { ComputationBuilder b(client_, TestName()); Array2D x_vals(2, 1); @@ -174,7 +174,7 @@ XLA_TEST_F(BroadcastSimpleTest, LogicalAnd2DTo3D_Pred) { ComputationDataHandle x, y; auto x_data = CreateR2Parameter(x_vals, 0, "x", &b, &x); auto y_data = CreateR3Parameter(y_vals, 1, "y", &b, &y); - b.LogicalAnd(x, y, /*broadcast_dimensions=*/{1, 2}); + b.And(x, y, /*broadcast_dimensions=*/{1, 2}); Array3D expected(2, 2, 1); expected(0, 0, 0) = false; diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc index 9f3b66e256dbb351b76a2e66912d3100495101be..065bce7e3146c93568bbce2b0e7e23ddddc4ea31 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.cc +++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc @@ -40,7 +40,7 @@ namespace { Client* GetOrCreateLocalClientOrDie(const LocalClientOptions& client_options) { StatusOr result = ClientLibrary::GetOrCreateLocalClient(client_options); - TF_CHECK_OK(result.status()) << "could not create local client for testing"; + TF_CHECK_OK(result.status()) << " could not create local client for testing"; return result.ValueOrDie(); } } // namespace @@ -254,7 +254,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( tensorflow::gtl::ArraySlice arguments, const Shape* shape_with_layout) { TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); - if (ShapeUtil::ElementIsFloating(expected.shape())) { + if (ShapeUtil::ElementIsFloating(expected.shape()) || + ShapeUtil::ElementIsComplex(expected.shape())) { LOG(WARNING) << "performing exact comparison of floating point numbers"; } else { TF_RET_CHECK(ShapeUtil::ElementIsIntegral(expected.shape()) || @@ -282,7 +283,8 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus( ComputationBuilder* builder, const Literal& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error, const Shape* shape_with_layout) { - TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape())); + TF_RET_CHECK(ShapeUtil::ElementIsFloating(expected.shape()) || + ShapeUtil::ElementIsComplex(expected.shape())); TF_ASSIGN_OR_RETURN(auto computation, builder->Build()); auto expect_near = [&](const Literal& actual, const string& error_message) { LiteralTestUtil::ExpectNear(expected, actual, error, error_message); diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h index 7fe1445b94097f762b777fc6936a0a1ab5a726c8..7cfc276ec19e3b177f87a08e716cb34b7676dd6b 100644 --- a/tensorflow/compiler/xla/tests/client_library_test_base.h +++ b/tensorflow/compiler/xla/tests/client_library_test_base.h @@ -361,8 +361,9 @@ void ClientLibraryTestBase::ComputeAndCompareR2( ComputationBuilder* builder, const Array2D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || - std::is_same::value, - "Floating point type required when specifying an ErrorSpec"); + std::is_same::value || + std::is_same::value, + "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = Literal::CreateR2FromArray2D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -384,8 +385,9 @@ void ClientLibraryTestBase::ComputeAndCompareR3( ComputationBuilder* builder, const Array3D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || - std::is_same::value, - "Floating point type required when specifying an ErrorSpec"); + std::is_same::value || + std::is_same::value, + "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = Literal::CreateR3FromArray3D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, @@ -407,8 +409,9 @@ void ClientLibraryTestBase::ComputeAndCompareR4( ComputationBuilder* builder, const Array4D& expected, tensorflow::gtl::ArraySlice arguments, ErrorSpec error) { static_assert(std::is_same::value || - std::is_same::value, - "Floating point type required when specifying an ErrorSpec"); + std::is_same::value || + std::is_same::value, + "Float or complex type required when specifying an ErrorSpec"); std::unique_ptr expected_literal = Literal::CreateR4FromArray4D(expected); ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal, diff --git a/tensorflow/compiler/xla/tests/compute_constant_test.cc b/tensorflow/compiler/xla/tests/compute_constant_test.cc index b2e9743af79d0e4658451e7a9522c338036851ba..d423c78476dde18d209b5efac9e8f77da41bfeb4 100644 --- a/tensorflow/compiler/xla/tests/compute_constant_test.cc +++ b/tensorflow/compiler/xla/tests/compute_constant_test.cc @@ -71,24 +71,27 @@ class ComputeConstantTest : public ::testing::Test { StatusOr> ComputeConstantLiteral( Client* client, const ComputationDataHandle& operand, - ComputationBuilder* builder, Layout* output_layout = nullptr) { - TF_ASSIGN_OR_RETURN(auto computed, - builder->ComputeConstant(operand, output_layout)); + ComputationBuilder* builder, Layout* output_layout = nullptr, + tensorflow::gtl::ArraySlice parameters = {}) { + TF_ASSIGN_OR_RETURN(auto computed, builder->ComputeConstant( + operand, output_layout, parameters)); return std::move(computed); } template - StatusOr ComputeConstantScalar(Client* client, - const ComputationDataHandle& operand, - ComputationBuilder* builder) { - TF_ASSIGN_OR_RETURN(auto literal, - ComputeConstantLiteral(client, operand, builder)); + StatusOr ComputeConstantScalar( + Client* client, const ComputationDataHandle& operand, + ComputationBuilder* builder, + tensorflow::gtl::ArraySlice parameters = {}) { + TF_ASSIGN_OR_RETURN( + auto literal, + ComputeConstantLiteral(client, operand, builder, nullptr, parameters)); return literal->Get({}); } bool IsConstant(const ComputationDataHandle& operand, - ComputationBuilder* builder) { - StatusOr result = builder->IsConstant(operand); + ComputationBuilder* builder, int64 num_parameters = 0) { + StatusOr result = builder->IsConstant(operand, num_parameters); EXPECT_TRUE(result.ok()) << result.status(); return result.ok() ? result.ValueOrDie() : false; } @@ -138,7 +141,25 @@ TEST_F(ComputeConstantTest, ScalarRng) { } } -TEST_F(ComputeConstantTest, DirectParam) { +TEST_F(ComputeConstantTest, Param) { + for (ClientType client_type : client_types) { + Client* client = ClientOrDie(platform_, client_type); + ComputationBuilder b(client, TestName()); + auto param = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "lhs"); + auto computation = b.Add(param, b.ConstantR0(1.5f)); + + std::vector arguments; + arguments.emplace_back(*Literal::CreateR0(42.5f)); + EXPECT_TRUE(IsConstant(computation, &b, arguments.size())); + + auto value = + ComputeConstantScalar(client, computation, &b, arguments); + ASSERT_TRUE(value.ok()) << value.status(); + EXPECT_EQ(value.ValueOrDie(), 44.0f); + } +} + +TEST_F(ComputeConstantTest, DirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); ComputationBuilder b(client, TestName()); @@ -152,7 +173,7 @@ TEST_F(ComputeConstantTest, DirectParam) { } } -TEST_F(ComputeConstantTest, IndirectParam) { +TEST_F(ComputeConstantTest, IndirectParamMissing) { for (ClientType client_type : client_types) { Client* client = ClientOrDie(platform_, client_type); ComputationBuilder b(client, TestName()); diff --git a/tensorflow/compiler/xla/tests/convert_test.cc b/tensorflow/compiler/xla/tests/convert_test.cc index 12b5e8426a78dc3a00794abdb892c0dcc1d15927..f66e3b57bf45fbc9f8ea786146d6fffe5d55a262 100644 --- a/tensorflow/compiler/xla/tests/convert_test.cc +++ b/tensorflow/compiler/xla/tests/convert_test.cc @@ -176,7 +176,7 @@ TEST_F(ConvertTest, ConvertMapToS32) { auto param = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "in"); b->ConvertElementType(param, S32); auto a = builder.ConstantR1({42.0f, 64.0f}); - builder.Map({a}, b->BuildAndNoteError()); + builder.Map({a}, b->BuildAndNoteError(), {0}); std::vector expected = {42, 64}; ComputeAndCompareR1(&builder, expected, {}); @@ -188,7 +188,7 @@ TEST_F(ConvertTest, ConvertMapToF32) { auto param = b->Parameter(0, ShapeUtil::MakeShape(S32, {}), "in"); b->ConvertElementType(param, F32); auto a = builder.ConstantR1({42, 64}); - builder.Map({a}, b->BuildAndNoteError()); + builder.Map({a}, b->BuildAndNoteError(), {0}); std::vector expected = {42.0f, 64.0f}; ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 83882ca75e93ee9edec8e292991b53f1af57bb62..b0a63bccbb93f226175beff2e30e2a243fdca1d3 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -39,7 +39,8 @@ class ConvolutionDimensionNumbersTest : public ClientLibraryTestBase {}; // Tests the convolution operation with invalid input dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3); + ComputationBuilder::CreateConvDimensionNumbers(0, 2, 0, 2, 2, 3, 0, 1, 2, + 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("input are not unique")); @@ -48,7 +49,8 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { // Tests the convolution operation with invalid weight dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 2, 3, 2, 3); + ComputationBuilder::CreateConvDimensionNumbers(0, 1, 0, 1, 2, 3, 2, 3, 2, + 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("weight are not unique")); @@ -73,14 +75,18 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, ConvolutionDimensionNumbers dim_nums = ComputationBuilder::CreateDefaultConvDimensionNumbers(); // Swap batch_dimension and feature_dimension. - int64 tmp = dim_nums.batch_dimension(); - dim_nums.set_batch_dimension(dim_nums.feature_dimension()); - dim_nums.set_feature_dimension(tmp); + int64 old_input_batch_dim = dim_nums.input_batch_dimension(); + int64 old_output_batch_dim = dim_nums.output_batch_dimension(); + dim_nums.set_input_batch_dimension(dim_nums.input_feature_dimension()); + dim_nums.set_output_batch_dimension(dim_nums.output_feature_dimension()); + dim_nums.set_input_feature_dimension(old_input_batch_dim); + dim_nums.set_output_feature_dimension(old_output_batch_dim); // Swap kernel_input_feature_dimension and kernel_output_feature_dimension. - tmp = dim_nums.kernel_input_feature_dimension(); + int64 old_kernel_input_feature_dim = + dim_nums.kernel_input_feature_dimension(); dim_nums.set_kernel_input_feature_dimension( dim_nums.kernel_output_feature_dimension()); - dim_nums.set_kernel_output_feature_dimension(tmp); + dim_nums.set_kernel_output_feature_dimension(old_kernel_input_feature_dim); builder.ConvWithGeneralDimensions(input, conv1, {1, 1}, Padding::kValid, dim_nums); diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 7d06cce0c8f82e4a1c4fb847638613594257b80f..0cc2e5fb7e655884f3334426a684dd3ce00d4052 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -418,11 +418,13 @@ XLA_TEST_F(ConvolutionTest, Convolve3D_1x4x2x3x3_2x2x2x3x3_Valid) { // Tensorflow dimension numbers for 3D convolution. ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); dnums.add_spatial_dimensions(3); - dnums.set_feature_dimension(4); + dnums.set_input_feature_dimension(4); + dnums.set_output_feature_dimension(4); dnums.add_kernel_spatial_dimensions(0); dnums.add_kernel_spatial_dimensions(1); dnums.add_kernel_spatial_dimensions(2); @@ -469,10 +471,12 @@ XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) { // Tensorflow dimension numbers for 2D convolution. ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); dnums.add_kernel_spatial_dimensions(0); dnums.add_kernel_spatial_dimensions(1); dnums.set_kernel_input_feature_dimension(2); @@ -504,25 +508,41 @@ XLA_TEST_F(ConvolutionTest, Convolve2D_1x3x3x5_3x3x5x5_Valid) { error_spec_); } -XLA_TEST_F(ConvolutionTest, Convolve1D_Valid) { +struct Convolve1DTestParam { + int64 input_feature; + int64 output_feature; + int64 batch; + int64 window_size; + int64 num_windows; +}; + +class Convolve1D1WindowTest + : public ConvolutionTest, + public ::testing::WithParamInterface {}; + +XLA_TEST_P(Convolve1D1WindowTest, Convolve1D1Window) { ComputationBuilder builder(client_, TestName()); - int64 output_feature = 1; - int64 input_feature = 64; - int64 batch = 1; - int64 length = 1; - std::vector input_dims = {batch, 4 + length - 1, input_feature}; - std::vector filter_dims = {4, input_feature, output_feature}; + int64 input_feature = GetParam().input_feature; + int64 output_feature = GetParam().output_feature; + int64 batch = GetParam().batch; + int64 num_windows = GetParam().num_windows; + int64 window_size = GetParam().window_size; + std::vector input_dims = {batch, window_size + num_windows - 1, + input_feature}; + std::vector filter_dims = {window_size, input_feature, output_feature}; Shape input_shape = ShapeUtil::MakeShape(F32, input_dims); Shape filter_shape = ShapeUtil::MakeShape(F32, filter_dims); { auto input = builder.Parameter(0, input_shape, "input"); auto filter = builder.Parameter(1, filter_shape, "filter"); - // Tensorflow dimension numbers for 2D convolution. + // Tensorflow dimension numbers for 1D convolution. ConvolutionDimensionNumbers dnums; - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); - dnums.set_feature_dimension(2); + dnums.set_input_feature_dimension(2); + dnums.set_output_feature_dimension(2); dnums.add_kernel_spatial_dimensions(0); dnums.set_kernel_input_feature_dimension(1); dnums.set_kernel_output_feature_dimension(2); @@ -532,28 +552,57 @@ XLA_TEST_F(ConvolutionTest, Convolve1D_Valid) { } std::vector input_elems(ShapeUtil::ElementsIn(input_shape), 1.0); - // std::iota(input_elems.begin(), input_elems.end(), 1.0f); auto input_r1 = Literal::CreateR1(input_elems); - auto input_r4 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); + auto input_r3 = input_r1->Reshape(input_dims).ConsumeValueOrDie(); std::vector filter_elems(ShapeUtil::ElementsIn(filter_shape), 1.0); - // std::iota(filter_elems.begin(), filter_elems.end(), 1.0f); auto filter_r1 = Literal::CreateR1(filter_elems); - auto filter_r4 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); + auto filter_r3 = filter_r1->Reshape(filter_dims).ConsumeValueOrDie(); - std::vector expect_elems(batch * output_feature * length, 256); + std::vector expect_elems(batch * output_feature * num_windows, + window_size * input_feature); auto expected_r1 = Literal::CreateR1(expect_elems); - auto expected_r4 = - expected_r1->Reshape({batch, length, output_feature}).ConsumeValueOrDie(); + auto expected_r3 = expected_r1->Reshape({batch, num_windows, output_feature}) + .ConsumeValueOrDie(); - auto input_literal = client_->TransferToServer(*input_r4).ConsumeValueOrDie(); + auto input_literal = client_->TransferToServer(*input_r3).ConsumeValueOrDie(); auto filter_literal = - client_->TransferToServer(*filter_r4).ConsumeValueOrDie(); - ComputeAndCompareLiteral(&builder, *expected_r4, + client_->TransferToServer(*filter_r3).ConsumeValueOrDie(); + ComputeAndCompareLiteral(&builder, *expected_r3, {input_literal.get(), filter_literal.get()}, error_spec_); } +INSTANTIATE_TEST_CASE_P( + Convolve1D1WindowTest_Instantiation, Convolve1D1WindowTest, + ::testing::Values(Convolve1DTestParam{1, 1, 1, 1, 2}, + Convolve1DTestParam{160, 1, 1, 5, 1}, + Convolve1DTestParam{24, 1, 1, 20, 1}, + Convolve1DTestParam{30, 1, 1, 20, 1}, + Convolve1DTestParam{23, 1, 1, 20, 20}, + Convolve1DTestParam{25, 1, 1, 20, 1}, + Convolve1DTestParam{24, 1, 1, 10, 5}, + Convolve1DTestParam{160, 1, 1, 10, 1}, + Convolve1DTestParam{255, 1, 1, 3, 1}, + Convolve1DTestParam{130, 1, 1, 1, 3}, + Convolve1DTestParam{64, 1, 1, 1, 1}, + Convolve1DTestParam{128, 1, 1, 1, 1}, + Convolve1DTestParam{139, 1, 1, 128, 1}, + Convolve1DTestParam{1, 10, 10, 1, 10}, + Convolve1DTestParam{1, 10, 130, 1, 2}, + Convolve1DTestParam{1, 10, 130, 1, 1}, + Convolve1DTestParam{1, 64, 64, 1, 10}, + Convolve1DTestParam{1, 65, 65, 1, 1}, + Convolve1DTestParam{1, 128, 128, 1, 1}, + Convolve1DTestParam{128, 128, 128, 128, 1}, + Convolve1DTestParam{1, 128, 128, 1, 1}, + Convolve1DTestParam{2, 2, 2, 2, 1}, + Convolve1DTestParam{161, 1, 1, 10, 1}, + Convolve1DTestParam{900, 1, 1, 10, 1}, + Convolve1DTestParam{640, 3, 3, 128, 1}) + +); + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 145918db3e5e57c39054706d53bbfb7648af3143..9b36e3722b8f8a5d01c426425fdfb0c4b9ae3a16 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -974,10 +974,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { ConvolutionDimensionNumbers dnums; // NHWC input format. - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); // Tensorflow filter shape: [ H, W, inC, outC ] dnums.add_kernel_spatial_dimensions(0); @@ -1014,10 +1016,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { ConvolutionDimensionNumbers dnums; // NHWC input format. - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); // Tensorflow filter shape: [ H, W, inC, outC ] dnums.add_kernel_spatial_dimensions(0); @@ -1054,10 +1058,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { ConvolutionDimensionNumbers dnums; // NHWC input format. - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); // Tensorflow filter shape: [ H, W, inC, outC ] dnums.add_kernel_spatial_dimensions(0); @@ -1091,10 +1097,12 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { ConvolutionDimensionNumbers dnums; // NHWC input format. - dnums.set_batch_dimension(0); + dnums.set_input_batch_dimension(0); + dnums.set_output_batch_dimension(0); dnums.add_spatial_dimensions(1); dnums.add_spatial_dimensions(2); - dnums.set_feature_dimension(3); + dnums.set_input_feature_dimension(3); + dnums.set_output_feature_dimension(3); // Tensorflow filter shape: [ H, W, inC, outC ] dnums.add_kernel_spatial_dimensions(0); diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc index 342478bc744273be9deb8b750b5a6a47b7d9f91b..74f73a1ddc15be033e52b0b45f9961e5dc3a1ecb 100644 --- a/tensorflow/compiler/xla/tests/custom_call_test.cc +++ b/tensorflow/compiler/xla/tests/custom_call_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/cpu/custom_call_target_registry.h" #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" @@ -31,19 +32,19 @@ limitations under the License. #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/test.h" - -extern "C" void TF_EXPORT R0F32Add2(float* out, float** in) { +namespace { +void R0F32Add2(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float*)); *out = **in + 2.0f; } -extern "C" void TF_EXPORT R2F32ReduceSum(float* out, float** in) { +void R2F32ReduceSum(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); float* array = in[0]; *out = array[0] + array[1] + array[2] + array[3]; } -extern "C" void TF_EXPORT Add1ToValues(float* out, float** in) { +void Add1ToValues(float* out, float** in) { TF_ANNOTATE_MEMORY_IS_INITIALIZED(in, sizeof(float) * 4); float* array = in[0]; out[0] = array[0] + 1; @@ -51,6 +52,11 @@ extern "C" void TF_EXPORT Add1ToValues(float* out, float** in) { out[2] = array[2] + 1; out[3] = array[3] + 1; } +} // namespace + +REGISTER_CUSTOM_CALL_TARGET(R0F32Add2); +REGISTER_CUSTOM_CALL_TARGET(R2F32ReduceSum); +REGISTER_CUSTOM_CALL_TARGET(Add1ToValues); namespace xla { namespace { diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc index 224aa57899d04eb8309b2337bb8fc936a81d350f..cf089d748dcd4f5db637ff9087c5fbc504c82572 100644 --- a/tensorflow/compiler/xla/tests/dot_operation_test.cc +++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc @@ -347,7 +347,7 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTF) { TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); } -TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) { +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF32MajorToMinorTT) { constexpr bool kLhsRowMajor = true; constexpr bool kRhsRowMajor = true; TestNonsquareMatrixDot(kLhsRowMajor, kRhsRowMajor); @@ -357,7 +357,11 @@ XLA_TEST_F(DotOperationTest, NonsquareMatrixDotF64) { TestNonsquareMatrixDot(); } -TEST_F(DotOperationTest, ConcurrentMatMul) { +XLA_TEST_F(DotOperationTest, NonsquareMatrixDotC64) { + TestNonsquareMatrixDot(); +} + +XLA_TEST_F(DotOperationTest, ConcurrentMatMul) { ComputationBuilder builder(client_, TestName()); auto matrix1 = builder.ConstantR2({{1.0, 2.0}, {3.0, 4.0}}); auto matrix2 = builder.ConstantR2({{5.0, 6.0}, {7.0, 8.0}}); diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc index b32c9e160408d28ee679bd445db9a03aec86ffff..19252f50f25eee42e4e492b7f0e2ec3960c62126 100644 --- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc +++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc @@ -555,8 +555,7 @@ void BM_DynamicSlice(int num_iters) { auto computation = builder.Build().ConsumeValueOrDie(); // Initialize and transfer parameter buffer. - auto buffer = ScopedShapedBuffer::MakeScopedShapedBuffer(start_indices_shape, - &allocator, 0) + auto buffer = ScopedShapedBuffer::Allocate(start_indices_shape, &allocator, 0) .ConsumeValueOrDie(); auto start_indices_literal = Literal::CreateR1({0, 1, 2, 3}); diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc index 2be409561ab3e23d9ea2e49aac381a90395380d0..a8f6488996087b57e3121ce2c7de918070950c72 100644 --- a/tensorflow/compiler/xla/tests/fusion_test.cc +++ b/tensorflow/compiler/xla/tests/fusion_test.cc @@ -17,8 +17,12 @@ limitations under the License. #include #include #include +#include #include +#define EIGEN_USE_THREADS + +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" #include "tensorflow/compiler/xla/array2d.h" #include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/computation.h" @@ -37,6 +41,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/literal_test_util.h" #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/protobuf.h" @@ -250,6 +255,42 @@ XLA_TEST_F(FusionTest, Parameter) { ErrorSpec(1e-4)); } +XLA_TEST_F(FusionTest, RandomizedParallelPartition) { + // Tests parallel partitioning of a fusion instruction. + // Create shape with random outer dimension size to generate random parallel + // partition counts for each test run. + const int seed = tensorflow::testing::RandomSeed(); + LOG(INFO) << "RandomizedParallelPartition seed: " << seed; + std::mt19937 generator(seed); + std::uniform_int_distribution distribution(128, 1024); + const int64 rand_dim0_size = distribution(generator); + const int64 dim1_size = 1024; + Shape shape = + ShapeUtil::MakeShapeWithLayout(F32, {rand_dim0_size, dim1_size}, {1, 0}); + // Build simple fusion computation: y = x^2 (elementwise). + auto builder = HloComputation::Builder(TestName()); + auto hlo_module = CreateNewModule(); + + auto two = builder.AddInstruction( + HloInstruction::CreateConstant(Literal::CreateR0(2.0))); + auto x = + builder.AddInstruction(HloInstruction::CreateBroadcast(shape, two, {})); + auto y = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kMultiply, x, x)); + + hlo_module->AddEntryComputation(builder.Build()) + ->CreateFusionInstruction(/*instructions_to_fuse=*/{y, x, two}, + HloInstruction::FusionKind::kLoop); + // Compute result. + auto result = ExecuteAndTransfer(std::move(hlo_module), {}); + // Every element of result should be y = x^2 = 4.0. + for (int i = 0; i < rand_dim0_size; ++i) { + for (int j = 0; j < dim1_size; ++j) { + EXPECT_EQ(4.0, result->Get({i, j})); + } + } +} + XLA_TEST_F(FusionTest, BroadcastIntoBinaryOp) { auto builder = HloComputation::Builder(TestName()); auto hlo_module = CreateNewModule(); @@ -655,10 +696,10 @@ XLA_TEST_F(FusionTest, SharedConstant) { HloComputation* entry_comp = hlo_module->entry_computation(); // entry computation contains the constant(0) and the fusion - EXPECT_EQ(entry_comp->instructions().size(), 2); + EXPECT_EQ(entry_comp->instruction_count(), 2); // fused instruction contains the constant(2), the parameter, and 4 adds - EXPECT_EQ(entry_comp->root_instruction()->fused_instructions().size(), 6); + EXPECT_EQ(entry_comp->root_instruction()->fused_instruction_count(), 6); LiteralTestUtil::ExpectEqual(*Literal::CreateR1({8}), *ExecuteAndTransfer(std::move(hlo_module), {})); @@ -722,47 +763,104 @@ void BM_ParallelFusion(int num_iters) { auto executors = PlatformUtil::GetStreamExecutors(platform).ValueOrDie(); StreamExecutorMemoryAllocator allocator(platform, executors); - const int64 intra_op_parallelism_threads = 16; + const int64 intra_op_parallelism_threads = 24; xla::LocalClientOptions client_options; client_options.set_platform(platform); client_options.set_intra_op_parallelism_threads(intra_op_parallelism_threads); auto client = ClientLibrary::GetOrCreateLocalClient(client_options).ValueOrDie(); - const int64 dim_size = 1024; - // Create a simple fusable elementwise computation. + auto* transfer_manager = + TransferManager::GetForPlatform(platform).ValueOrDie(); + int device_ordinal = client->default_device_ordinal(); + + // Computation shape parameters. + const int64 param0_dim0 = 1024; + const int64 param0_dim1 = 1024; + const int64 param1_dim0 = 1024; + const int64 param1_dim1 = 1024; + const int64 param2_dim0 = 1024; + const int64 param2_dim1 = 1024; + + // Create computation. ComputationBuilder builder(client, "ParallelFusion"); - Shape input_shape = ShapeUtil::MakeShape(F32, {dim_size, dim_size}); - auto input0 = builder.Broadcast(builder.ConstantR0(1.5f), - AsInt64Slice(input_shape.dimensions())); - auto input1 = builder.Broadcast(builder.ConstantR0(2.0f), - AsInt64Slice(input_shape.dimensions())); - auto input2 = builder.Broadcast(builder.ConstantR0(3.0f), - AsInt64Slice(input_shape.dimensions())); - auto x = builder.Mul(input0, input1); - auto y = builder.Add(x, input2); + Shape shape0 = ShapeUtil::MakeShape(F32, {param0_dim0, param0_dim1}); + auto param0 = builder.Parameter(0, shape0, "param0"); + Shape shape1 = ShapeUtil::MakeShape(F32, {param1_dim0, param1_dim1}); + auto param1 = builder.Parameter(1, shape1, "param1"); + Shape shape2 = ShapeUtil::MakeShape(F32, {param2_dim0, param2_dim1}); + auto param2 = builder.Parameter(2, shape2, "param2"); + + auto x = builder.Mul(param0, param1); + auto y = builder.Add(x, param2); auto computation = builder.Build().ConsumeValueOrDie(); + // Transfer literals to device. + auto buffer0 = + ScopedShapedBuffer::Allocate(shape0, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); + auto param0_literal = + Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + executors[device_ordinal], *param0_literal, buffer0->mutable_buffer({}))); + + auto buffer1 = + ScopedShapedBuffer::Allocate(shape1, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); + auto param1_literal = + Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + executors[device_ordinal], *param1_literal, buffer1->mutable_buffer({}))); + + auto buffer2 = + ScopedShapedBuffer::Allocate(shape2, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); + auto param2_literal = + Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1); + ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( + executors[device_ordinal], *param2_literal, buffer2->mutable_buffer({}))); + + // Build executable. std::unique_ptr executable = - client->Compile(computation, {}, ExecutableBuildOptions()) + client + ->Compile(computation, + {&buffer0->shape(), &buffer1->shape(), &buffer2->shape()}, + ExecutableBuildOptions()) .ConsumeValueOrDie(); - // Run some warm-up executions. + se::Stream stream(executors[client->default_device_ordinal()]); + stream.Init(); + + // Initialize thread pool. + tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "XLAEigen", + intra_op_parallelism_threads); + tensorflow::EigenThreadPoolWrapper tp(&pool); + Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); + + // Initialize ExecutableRunOptions. ExecutableRunOptions options; - options.set_allocator(&allocator); + options.set_allocator(&allocator).set_stream(&stream); + options.set_intra_op_thread_pool(&device); + + // Run some warm-up executions. const int kWarmups = 2; for (int i = 0; i < kWarmups; ++i) { - auto result = executable->Run({}, options); + auto result = + executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options); ASSERT_TRUE(result.ok()); } // Run benchmark. - tensorflow::testing::BytesProcessed(static_cast(num_iters) * dim_size * - dim_size * sizeof(float)); + const int64 total_bytes = param0_dim0 * param0_dim0 + + param1_dim0 * param1_dim0 + + param2_dim0 * param2_dim0; + tensorflow::testing::BytesProcessed(static_cast(num_iters) * + total_bytes * sizeof(float)); tensorflow::testing::UseRealTime(); tensorflow::testing::StartTiming(); for (int i = 0; i < num_iters; ++i) { - auto result = executable->Run({}, options); + auto result = + executable->Run({buffer0.get(), buffer1.get(), buffer2.get()}, options); ASSERT_TRUE(result.ok()); } } diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.cc b/tensorflow/compiler/xla/tests/hlo_test_base.cc index 26513d6ce8e0b8896e9f9838ecf28f1ed5bbb383..d73c05ff92578209143e0679558848160cae99bd 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.cc +++ b/tensorflow/compiler/xla/tests/hlo_test_base.cc @@ -19,24 +19,9 @@ limitations under the License. #include #include -#define EIGEN_USE_THREADS - -#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" -#include "tensorflow/compiler/xla/layout_util.h" #include "tensorflow/compiler/xla/legacy_flags/debug_options_flags.h" #include "tensorflow/compiler/xla/ptr_util.h" -#include "tensorflow/compiler/xla/service/backend.h" -#include "tensorflow/compiler/xla/service/computation_layout.h" -#include "tensorflow/compiler/xla/service/executable.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_execution_profile.h" -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/transfer_manager.h" -#include "tensorflow/compiler/xla/shape_layout.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" -#include "tensorflow/core/common_runtime/eigen_thread_pool.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" @@ -45,22 +30,6 @@ namespace se = ::perftools::gputools; namespace xla { -// Define this in .cc file to avoid having to include eigen or forward declare -// these types in the header. -struct HloTestBase::EigenThreadPoolWrapper { - std::unique_ptr pool; - std::unique_ptr device; -}; - -HloTestBase::HloTestBase() {} - -HloTestBase::~HloTestBase() { - // Deallocate all the memory allocated during the tests. - for (auto& allocation : allocations_) { - backend().default_stream_executor()->Deallocate(&allocation); - } -} - /* static */ std::unique_ptr HloTestBase::CreateNewModule() { HloModuleConfig config; @@ -80,98 +49,25 @@ StatusOr HloTestBase::Execute( tensorflow::gtl::ArraySlice arguments, Shape* result_shape) { - TF_ASSIGN_OR_RETURN( - std::unique_ptr executable, - backend().compiler()->Compile(std::move(module), - backend().default_stream_executor())); - - se::Stream stream(backend().default_stream_executor()); - stream.Init(); - - ExecutableRunOptions run_options; - run_options.set_stream(&stream); - run_options.set_allocator(backend().memory_allocator()); - run_options.set_inter_op_thread_pool(backend().inter_op_thread_pool()); - run_options.set_intra_op_thread_pool( - backend().eigen_intra_op_thread_pool_device()); - - HloExecutionProfile hlo_execution_profile; - ServiceExecutableRunOptions service_run_options( - run_options, backend().StreamBorrower(), - backend().inter_op_thread_pool()); - TF_ASSIGN_OR_RETURN( - se::DeviceMemoryBase result, - executable->ExecuteOnStream(&service_run_options, arguments, - &hlo_execution_profile)); - TF_RET_CHECK(stream.BlockHostUntilDone()); - - allocations_.push_back(result); - - *result_shape = executable->result_shape(); - - if (ShapeUtil::IsTuple(*result_shape)) { - // We must record element buffers of tuples as well to avoid leaks. - DCHECK(!ShapeUtil::IsNestedTuple(*result_shape)); - TF_ASSIGN_OR_RETURN( - std::vector element_buffers, - backend().transfer_manager()->ShallowCopyTupleFromDevice( - backend().default_stream_executor(), result, *result_shape)); - - // A tuple may contain the same buffer in more than one element. Keep track - // of the buffers already added to avoid duplicates in allocations_. - std::set added_opaques; - for (auto element_buffer : element_buffers) { - if (added_opaques.count(element_buffer.opaque()) == 0) { - CHECK(element_buffer.opaque() != nullptr); - added_opaques.insert(element_buffer.opaque()); - allocations_.push_back(element_buffer); - } - } - } - - return result; + return runner_.Execute(std::move(module), arguments, result_shape); } se::DeviceMemoryBase HloTestBase::TransferToDevice(const Literal& literal) { - // Allocate memory on the device using the stream executor. - int64 allocation_size = - backend().transfer_manager()->GetByteSizeRequirement(literal.shape()); - se::DeviceMemoryBase allocation = - backend().default_stream_executor()->AllocateArray( - allocation_size); - allocations_.push_back(allocation); - - TF_CHECK_OK(backend().transfer_manager()->TransferLiteralToDevice( - backend().default_stream_executor(), literal, &allocation)); - - return allocation; + return runner_.TransferToDevice(literal).ValueOrDie(); } std::unique_ptr HloTestBase::TransferFromDevice( const Shape& shape, se::DeviceMemoryBase device_base) { - auto literal = MakeUnique(); - TF_CHECK_OK(backend().transfer_manager()->TransferLiteralFromDevice( - backend().default_stream_executor(), device_base, shape, shape, - literal.get())); - return literal; + return runner_.TransferFromDevice(shape, device_base).ValueOrDie(); } std::unique_ptr HloTestBase::ExecuteAndTransfer( std::unique_ptr module, tensorflow::gtl::ArraySlice arguments) { - Shape result_shape; - se::DeviceMemoryBase device_base = - Execute(std::move(module), arguments, &result_shape).ValueOrDie(); - return TransferFromDevice(result_shape, device_base); + return runner_.ExecuteAndTransfer(std::move(module), arguments).ValueOrDie(); } -Backend& HloTestBase::backend() { - if (!backend_) { - backend_ = Backend::CreateDefaultBackend().ConsumeValueOrDie(); - VLOG(1) << "executing on platform " << backend().platform()->Name(); - } - return *backend_; -} +Backend& HloTestBase::backend() { return runner_.backend(); } /* static */ string HloTestBase::TestName() { diff --git a/tensorflow/compiler/xla/tests/hlo_test_base.h b/tensorflow/compiler/xla/tests/hlo_test_base.h index 275f1f5c7baa11245186d119f5b38b4d02b84566..7f068dce36be3546298de2f06bf6d33446d07ca2 100644 --- a/tensorflow/compiler/xla/tests/hlo_test_base.h +++ b/tensorflow/compiler/xla/tests/hlo_test_base.h @@ -21,12 +21,12 @@ limitations under the License. #include #include "tensorflow/compiler/xla/service/backend.h" -#include "tensorflow/compiler/xla/service/compiler.h" -#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/computation_layout.h" #include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/service/hlo_runner.h" +#include "tensorflow/compiler/xla/shape_layout.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/tests/literal_test_util.h" -#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" @@ -39,10 +39,9 @@ namespace xla { // building a graph of HLO instructions to run. class HloTestBase : public ::testing::Test { protected: - struct EigenThreadPoolWrapper; - HloTestBase(); + HloTestBase() {} - ~HloTestBase() override; + ~HloTestBase() override {} // Creates a new HLO module for a test. The module created will have // TestName() for its name; it will also automatically populate its debug @@ -102,23 +101,12 @@ class HloTestBase : public ::testing::Test { static string TestName(); - // Creates (if necessary) and returns the default backend. If creation fails, - // crashes the program. - // - // This creates the backend lazily so it's possible to instantiate an - // HloTestBase in a program without any backends linked in. + // Returns the backend owned by the HloRunner. Backend& backend(); - // This vector contains handles of all the device memory allocations performed - // by the test. These are deallocated on destruction of the test object. - std::vector allocations_; + HloRunner runner_; ErrorSpec error_spec_{0.0001}; - - std::unique_ptr thread_pool_wrapper_; - - private: - std::unique_ptr backend_; // Lazily populated. Access via backend(). }; } // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc new file mode 100644 index 0000000000000000000000000000000000000000..31060b9e80fcd50aefdedca27c70ec8a9b8be743 --- /dev/null +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.cc @@ -0,0 +1,69 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h" + +#include "tensorflow/compiler/xla/service/hlo_verifier.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/status_macros.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { + +/*static*/ int64 HloVerifiedTestBase::DefaultShapeSize(const Shape& shape) { + constexpr int64 kPointerSize = sizeof(void*); + if (ShapeUtil::IsOpaque(shape)) { + return kPointerSize; + } + return ShapeUtil::ByteSizeOf(shape, kPointerSize); +} + +HloVerifiedTestBase::HloVerifiedTestBase() : shape_size_fn_(DefaultShapeSize) {} + +HloVerifiedTestBase::~HloVerifiedTestBase() { + // We can't call the ASSERT or EXPECT test macros in destructors, so we + // perform HLO verification in TearDown, and use the CHECK here to ensure + // users don't accidentally override the verification. + CHECK(tear_down_called_) + << "TearDown was never called; subclasses of HloVerifiedTestBase that " + << "override TearDown must call the superclass TearDown."; +} + +void HloVerifiedTestBase::TearDown() { + EXPECT_FALSE(tear_down_called_) + << "TearDown called more than once; it should be called exactly once."; + tear_down_called_ = true; + if (module_) { + HloVerifier verifier(shape_size_fn_); + xla::StatusOr mutated = verifier.Run(module_.get()); + if (!mutated.ok()) { + ADD_FAILURE() << "HloVerifier failed: " << mutated.status(); + } else { + EXPECT_FALSE(mutated.ValueOrDie()) + << "HloVerifier should never mutate the HloModule"; + } + } + HloTestBase::TearDown(); +} + +HloModule& HloVerifiedTestBase::module() { + if (!module_) { + module_ = CreateNewModule(); + } + return *module_; +} + +} // namespace xla diff --git a/tensorflow/compiler/xla/tests/hlo_verified_test_base.h b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h new file mode 100644 index 0000000000000000000000000000000000000000..b3d6b5af3b46f932707abf309669d23c327d1334 --- /dev/null +++ b/tensorflow/compiler/xla/tests/hlo_verified_test_base.h @@ -0,0 +1,63 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ +#define TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/tests/hlo_test_base.h" + +namespace xla { + +// A base class for HLO tests that stores a default HloModule, and automatically +// performs verification on that module on tear-down. +class HloVerifiedTestBase : public HloTestBase { + public: + // Returns the size in bytes of the given shape, using a default pointer size. + static int64 DefaultShapeSize(const Shape& shape); + + protected: + HloVerifiedTestBase(); + ~HloVerifiedTestBase() override; + + // Performs verification on the default HloModule returned by module(). + // Automatically called by the testing framework for each test. + // + // REQUIRED: subclasses that override TearDown() must call this explicitly. + void TearDown() override; + + // Returns the default HloModule, lazily creating it if necessary via + // HloTestBase::CreateNewModule(). + HloModule& module(); + + // Sets the shape-size function used during hlo verification. If this isn't + // called, DefaultShapeSize is used instead. + void SetShapeSizeFn(std::function shape_size_fn) { + shape_size_fn_ = std::move(shape_size_fn); + } + + private: + std::unique_ptr module_; // Lazily populated. Access via module(). + std::function shape_size_fn_; + bool tear_down_called_ = false; +}; + +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TESTS_HLO_VERIFIED_TEST_BASE_H_ diff --git a/tensorflow/compiler/xla/tests/literal_test_util.cc b/tensorflow/compiler/xla/tests/literal_test_util.cc index 4d8b50fbbf715e8d491667ecb4f4f336ef2d8a68..95a52ecd2f5cfc97ec1ccba7d1b7ca6257a8267e 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.cc +++ b/tensorflow/compiler/xla/tests/literal_test_util.cc @@ -39,28 +39,60 @@ limitations under the License. namespace xla { -/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, - const Shape& actual) { - ASSERT_EQ(ShapeUtil::IsTuple(expected), ShapeUtil::IsTuple(actual)); +/* static */ ::testing::AssertionResult LiteralTestUtil::EqualShapes( + const Shape& expected, const Shape& actual) { + if (ShapeUtil::IsTuple(expected) != ShapeUtil::IsTuple(actual)) { + return ::testing::AssertionFailure() + << "tupleness-mismatch! want: " << ShapeUtil::HumanString(expected) + << " got: " << ShapeUtil::HumanString(actual); + } if (ShapeUtil::IsTuple(expected)) { - ASSERT_EQ(ShapeUtil::TupleElementCount(expected), - ShapeUtil::TupleElementCount(actual)); + if (ShapeUtil::TupleElementCount(expected) != + ShapeUtil::TupleElementCount(actual)) { + return ::testing::AssertionFailure() + << "want tuple element count: " + << ShapeUtil::TupleElementCount(expected) + << " got tuple element count: " + << ShapeUtil::TupleElementCount(actual); + } for (int i = 0; i < expected.tuple_shapes_size(); ++i) { - AssertEqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + ::testing::AssertionResult result = + EqualShapes(expected.tuple_shapes(i), actual.tuple_shapes(i)); + if (!result) { + return result; + } } } else { - ASSERT_EQ(ShapeUtil::Rank(expected), ShapeUtil::Rank(actual)); - ASSERT_EQ(expected.element_type(), actual.element_type()) - << PrimitiveType_Name(expected.element_type()) << " vs " - << PrimitiveType_Name(actual.element_type()); - ASSERT_EQ(expected.dimensions_size(), actual.dimensions_size()); + if (ShapeUtil::Rank(expected) != ShapeUtil::Rank(actual)) { + return ::testing::AssertionFailure() + << "want rank of: " << ShapeUtil::HumanString(expected) + << " got rank of: " << ShapeUtil::HumanString(actual); + } + if (expected.element_type() != actual.element_type()) { + return ::testing::AssertionFailure() + << PrimitiveType_Name(expected.element_type()) << " vs " + << PrimitiveType_Name(actual.element_type()); + } + if (expected.dimensions_size() != actual.dimensions_size()) { + return ::testing::AssertionFailure() + << "want dimensions_size " << expected.dimensions_size() + << " got dimensions_size " << actual.dimensions_size(); + } for (int i = 0; i < expected.dimensions_size(); ++i) { - ASSERT_EQ(expected.dimensions(i), actual.dimensions(i)) - << "mismatch in dimension #" << i - << " expected: " << ShapeUtil::HumanString(expected) - << " actual: " << ShapeUtil::HumanString(actual); + if (expected.dimensions(i) != actual.dimensions(i)) { + return ::testing::AssertionFailure() + << "mismatch in dimension #" << i + << " expected: " << ShapeUtil::HumanString(expected) + << " actual: " << ShapeUtil::HumanString(actual); + } } } + return ::testing::AssertionSuccess(); +} + +/* static */ void LiteralTestUtil::AssertEqualShapes(const Shape& expected, + const Shape& actual) { + ASSERT_TRUE(EqualShapes(expected, actual)); } /* static */ void LiteralTestUtil::AssertEqualShapesAndLayouts( @@ -124,6 +156,15 @@ template <> ::testing::AssertionResult CompareEqual(double lhs, double rhs) { return CompareFloatsBitwiseEqual(lhs, rhs); } +template <> +::testing::AssertionResult CompareEqual(complex64 lhs, + complex64 rhs) { + auto res = CompareEqual(lhs.real(), rhs.real()); + if (!res) { + return res; + } + return CompareEqual(lhs.imag(), rhs.imag()); +} // A recursive function which iterates through every index of expected and // actual literal and compares their values elementwise. Returns true if all @@ -203,6 +244,9 @@ bool ExpectLiteralsEqual(const Literal& expected, const Literal& actual, case F64: match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); break; + case C64: + match = ExpectLiteralsEqual(expected, actual, &multi_index, 0); + break; case TUPLE: { bool tuple_match = true; for (int i = 0; i < actual.tuple_literals_size(); ++i) { @@ -263,7 +307,14 @@ class NearComparator { VLOG(1) << "actual:"; XLA_VLOG_LINES(1, actual.ToString()); - LiteralTestUtil::AssertEqualShapes(expected.shape(), actual.shape()); + // If the shapes mismatch, we simply fail the expectation instead of + // printing out data, as it's a type error rather than a value error. + ::testing::AssertionResult equal_shapes = + LiteralTestUtil::EqualShapes(expected.shape(), actual.shape()); + if (!equal_shapes) { + EXPECT_TRUE(equal_shapes); + return false; + } // Set up members used during the comparison. num_miscompares_ = 0; @@ -286,6 +337,9 @@ class NearComparator { case F64: ExpectLiteralsNear(expected, actual, 0); break; + case C64: + ExpectLiteralsNear(expected, actual, 0); + break; default: LOG(FATAL) << "Unsupported primitive type in near comparator: " << PrimitiveType_Name(expected.shape().element_type()) @@ -326,6 +380,19 @@ class NearComparator { } private: + template + bool NanMismatch(NativeT lhs, NativeT rhs) { + return std::isnan(lhs) != std::isnan(rhs); + } + + template + void ExpectNear(NativeT expected, NativeT actual, + const ::testing::Message& message) { + EXPECT_NEAR(expected, actual, error_.abs) + << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" + << message; + } + // EXPECTs that the two given scalar values are within the error bound. Keeps // track of how many mismatches have occurred to keep the size of the output // manageable. @@ -351,7 +418,7 @@ class NearComparator { "index %s abs_diff %f rel_err %f", LiteralTestUtil::MultiIndexAsString(multi_index_).c_str(), abs_diff, rel_err); - bool nan_mismatch = std::isnan(actual) != std::isnan(expected); + bool nan_mismatch = NanMismatch(expected, actual); bool mismatch = (nan_mismatch || (abs_diff >= error_.abs && rel_err >= error_.rel)); if (mismatch) { @@ -359,11 +426,12 @@ class NearComparator { abs_expected_miscompare_sum_ += std::abs(expected); const int64 kMaxFailures = 2; if (num_miscompares_ < kMaxFailures) { - EXPECT_NEAR(expected, actual, error_.abs) - << "mismatch at index " + ::testing::Message msg; + msg << "mismatch at index " << LiteralTestUtil::MultiIndexAsString(multi_index_) << " abs diff " << abs_diff << " rel err " << rel_err << " failure #" << num_miscompares_; + ExpectNear(expected, actual, msg); } else if (num_miscompares_ == kMaxFailures) { LOG(ERROR) << "reached max 'loud' failure count; silently proceeding..."; @@ -431,6 +499,23 @@ class NearComparator { std::vector max_abs_multi_index_; }; +template <> +bool NearComparator::NanMismatch(complex64 lhs, complex64 rhs) { + return std::isnan(lhs.real()) != std::isnan(rhs.real()) || + std::isnan(lhs.imag()) != std::isnan(rhs.imag()); +} + +template <> +void NearComparator::ExpectNear(complex64 expected, complex64 actual, + const ::testing::Message& message) { + EXPECT_NEAR(expected.real(), actual.real(), error_.abs) + << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" + << message; + EXPECT_NEAR(expected.imag(), actual.imag(), error_.abs) + << "expected:\n " << expected << "\n\tvs actual:\n " << actual << "\n" + << message; +} + } // namespace /* static */ ::testing::AssertionResult LiteralTestUtil::Near( diff --git a/tensorflow/compiler/xla/tests/literal_test_util.h b/tensorflow/compiler/xla/tests/literal_test_util.h index f645c4e8dcda73806a4204876716b93aa5fb7185..467d44b857b74d2a38e9b3f8a32a9b1d39a4a10d 100644 --- a/tensorflow/compiler/xla/tests/literal_test_util.h +++ b/tensorflow/compiler/xla/tests/literal_test_util.h @@ -50,6 +50,8 @@ class LiteralTestUtil { public: // Asserts that the given shapes have the same rank, dimension sizes, and // primitive types. + static ::testing::AssertionResult EqualShapes(const Shape& expected, + const Shape& actual); static void AssertEqualShapes(const Shape& expected, const Shape& actual); // Asserts that the provided shapes are equal as defined in AssertEqualShapes diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc index 6897f0291ab2244638a59bed6be06444bf7d1d98..3d30ceeaf1b0369b6fdc0cd9620c04aae287941c 100644 --- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc @@ -44,8 +44,8 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) { TestAllocator* allocator = GetOrCreateAllocator(local_client_->platform()); - auto x_array = LiteralToScopedShapedBuffer( - *Literal::CreateR1({0.0f, 1.0f, 2.0f})); + auto x_array = + LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); int64 allocation_count_before = allocator_->allocation_count(); diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc index ef2592e2923fe711f8d7550f3c72dc2b7ed4a761..329b53012f58c8d084cc05f9a567a8aa432c4a3a 100644 --- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc +++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc @@ -37,6 +37,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tests/test_macros.h" #include "tensorflow/compiler/xla/tests/test_utils.h" #include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/test.h" @@ -71,7 +72,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) { auto y = builder.ConstantR0(123.0f); builder.Add(x, y); - auto x_value = LiteralToScopedShapedBuffer(*Literal::CreateR0(42.0f)); + auto x_value = LiteralToShapedBuffer(*Literal::CreateR0(42.0f)); std::unique_ptr result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {x_value.get()}); @@ -85,7 +86,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) { auto y = builder.ConstantR1({}); builder.Add(x, y); - auto x_array = LiteralToScopedShapedBuffer(*Literal::CreateR1({})); + auto x_array = LiteralToShapedBuffer(*Literal::CreateR1({})); std::unique_ptr result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {x_array.get()}); @@ -99,8 +100,8 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) { auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); builder.Add(x, y); - auto x_array = LiteralToScopedShapedBuffer( - *Literal::CreateR1({0.0f, 1.0f, 2.0f})); + auto x_array = + LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); std::unique_ptr result = ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {x_array.get()}); @@ -114,8 +115,8 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) { auto y = builder.ConstantR1({2.0f, 3.0f, 4.0f}); builder.Add(x, y); - auto x_array = LiteralToScopedShapedBuffer( - *Literal::CreateR1({0.0f, 1.0f, 2.0f})); + auto x_array = + LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); ExecutionProfile profile; std::unique_ptr result = ExecuteLocallyOrDie( builder.Build().ValueOrDie(), {x_array.get()}, @@ -135,14 +136,14 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) { auto computation = builder.Build().ConsumeValueOrDie(); // Create x as a col-major array. - auto x_array = LiteralToScopedShapedBuffer( + auto x_array = LiteralToShapedBuffer( *test_utils::CreateR2LiteralWithLayout({{1.0f, 2.0f}, {3.0f, 4.0f}}, /*minor_to_major=*/{0, 1})); EXPECT_TRUE(LayoutUtil::Equal(x_array->shape().layout(), LayoutUtil::MakeLayout({0, 1}))); // Create y as a row-major array. - auto y_array = LiteralToScopedShapedBuffer( + auto y_array = LiteralToShapedBuffer( *test_utils::CreateR2LiteralWithLayout({{10.0f, 20.0f}, {30.0f, 40.0f}}, /*minor_to_major=*/{1, 0})); EXPECT_TRUE(LayoutUtil::Equal(y_array->shape().layout(), @@ -169,9 +170,9 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) { builder.Add(x, y); auto computation = builder.Build().ConsumeValueOrDie(); - auto x_array = LiteralToScopedShapedBuffer( + auto x_array = LiteralToShapedBuffer( *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); - auto y_array = LiteralToScopedShapedBuffer( + auto y_array = LiteralToShapedBuffer( *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); // Run with col-major result layout. @@ -206,9 +207,9 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) { builder.Tuple({x, y, x}); auto computation = builder.Build().ConsumeValueOrDie(); - auto x_array = LiteralToScopedShapedBuffer( + auto x_array = LiteralToShapedBuffer( *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); - auto y_array = LiteralToScopedShapedBuffer( + auto y_array = LiteralToShapedBuffer( *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); std::unique_ptr result = @@ -234,9 +235,9 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) { builder.Tuple({inner_tuple, x}); auto computation = builder.Build().ConsumeValueOrDie(); - auto x_array = LiteralToScopedShapedBuffer( + auto x_array = LiteralToShapedBuffer( *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); - auto y_array = LiteralToScopedShapedBuffer( + auto y_array = LiteralToShapedBuffer( *Literal::CreateR2({{10.0f, 20.0f}, {30.0f, 40.0f}})); std::unique_ptr result = @@ -264,7 +265,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y"); builder.Tuple({x, y}); - auto array = LiteralToScopedShapedBuffer( + auto array = LiteralToShapedBuffer( *Literal::CreateR2({{1.0f, 2.0f}, {3.0f, 4.0f}})); ExecutableBuildOptions options = DefaultExecutableBuildOptions(); @@ -285,6 +286,283 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) { result_literal->tuple_literals(1)); } +XLA_TEST_F(LocalClientExecuteTest, TupleArguments) { + const Shape array_shape = ShapeUtil::MakeShape(F32, {2, 2}); + const Shape vector_shape = ShapeUtil::MakeShape(F32, {3}); + + const Shape tuple_shape0 = + ShapeUtil::MakeTupleShape({array_shape, vector_shape}); + const Shape tuple_shape1 = + ShapeUtil::MakeTupleShape({vector_shape, array_shape}); + + // Computation adds the respective array and vector elements from each tuple + // argument and returns the results as a tuple. + ComputationBuilder builder(local_client_, TestName()); + auto x = builder.Parameter(0, tuple_shape0, "x"); + auto y = builder.Parameter(1, tuple_shape1, "y"); + auto x_0 = builder.GetTupleElement(x, 0); + auto x_1 = builder.GetTupleElement(x, 1); + auto y_0 = builder.GetTupleElement(y, 0); + auto y_1 = builder.GetTupleElement(y, 1); + auto array_sum = builder.Add(x_0, y_1); + auto vector_diff = builder.Sub(x_1, y_0); + builder.Tuple({array_sum, vector_diff}); + auto computation = builder.Build().ConsumeValueOrDie(); + + auto x_literal = Literal::MakeTuple( + {Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), + Literal::CreateR1({42.0, 75.0, 123.0}).get()}); + auto y_literal = Literal::MakeTuple( + {Literal::CreateR1({2.0, 4.0, 6.0}).get(), + Literal::CreateR2({{55.0, 44.0}, {33.0, 22.0}}).get()}); + + auto x_buffer = LiteralToShapedBuffer(*x_literal); + auto y_buffer = LiteralToShapedBuffer(*y_literal); + + std::unique_ptr result = + ExecuteLocallyOrDie(computation, {x_buffer.get(), y_buffer.get()}); + + EXPECT_TRUE(ShapeUtil::IsTuple(result->shape())); + EXPECT_EQ(2, ShapeUtil::TupleElementCount(result->shape())); + + std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + LiteralTestUtil::ExpectR2Equal({{56.0f, 46.0f}, {36.0f, 26.0f}}, + result_literal->tuple_literals(0)); + LiteralTestUtil::ExpectR1Equal({40.0f, 71.0f, 117.0f}, + result_literal->tuple_literals(1)); +} + +XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) { + const Shape array_shape = ShapeUtil::MakeShape(F32, {2, 2}); + const Shape vector_shape = ShapeUtil::MakeShape(F32, {3}); + + const Shape inner_tuple_shape = + ShapeUtil::MakeTupleShape({array_shape, vector_shape}); + const Shape nested_tuple_shape = + ShapeUtil::MakeTupleShape({inner_tuple_shape, vector_shape}); + + // Computation negates the array element and sums the two vector elements in + // the nested tuple. The resulting array and vector are returned as a tuple. + ComputationBuilder builder(local_client_, TestName()); + auto param = builder.Parameter(0, nested_tuple_shape, "param"); + auto inner_tuple = builder.GetTupleElement(param, 0); + auto inner_array = builder.GetTupleElement(inner_tuple, 0); + auto inner_vector = builder.GetTupleElement(inner_tuple, 1); + auto outer_vector = builder.GetTupleElement(param, 1); + + auto negate_array = builder.Neg(inner_array); + auto vector_sum = builder.Add(inner_vector, outer_vector); + builder.Tuple({negate_array, vector_sum}); + auto computation = builder.Build().ConsumeValueOrDie(); + + auto arg_literal = Literal::MakeTuple( + {Literal::MakeTuple( + {Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), + Literal::CreateR1({42.0, 75.0, 123.0}).get()}) + .get(), + Literal::CreateR1({222.0, -2.0, 10.0}).get()}); + auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + + std::unique_ptr result = + ExecuteLocallyOrDie(computation, {arg_buffer.get()}); + + std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4}}, + result_literal->tuple_literals(0)); + LiteralTestUtil::ExpectR1Equal({264.0, 73.0, 133.0}, + result_literal->tuple_literals(1)); +} + +XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) { + // Construct a computation which takes and returns the same shape (a + // tuple). Feed the result of the computation back into the input. This + // provides additional verification that the returned tuple is properly + // constructed. + const Shape array_shape = ShapeUtil::MakeShape(F32, {2, 2}); + const Shape tuple_shape = + ShapeUtil::MakeTupleShape({array_shape, array_shape}); + + ComputationBuilder builder(local_client_, TestName()); + auto param = builder.Parameter(0, tuple_shape, "param"); + auto element_0 = builder.GetTupleElement(param, 0); + auto element_1 = builder.GetTupleElement(param, 1); + builder.Tuple({builder.Neg(element_0), builder.Add(element_1, element_1)}); + auto computation = builder.Build().ConsumeValueOrDie(); + + auto arg_literal = Literal::MakeTuple( + {Literal::CreateR2({{1.0, 2.0}, {3.0, 4.0}}).get(), + Literal::CreateR2({{11.0, 3.0}, {4.0, 5.0}}).get()}); + auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + + std::unique_ptr result_0 = + ExecuteLocallyOrDie(computation, {arg_buffer.get()}); + std::unique_ptr result_0_literal = ShapedBufferToLiteral(*result_0); + LiteralTestUtil::ExpectR2Equal({{-1.0, -2.0}, {-3.0, -4.0}}, + result_0_literal->tuple_literals(0)); + LiteralTestUtil::ExpectR2Equal({{22.0, 6.0}, {8.0, 10}}, + result_0_literal->tuple_literals(1)); + + std::unique_ptr result_1 = + ExecuteLocallyOrDie(computation, {result_0.get()}); + std::unique_ptr result_1_literal = ShapedBufferToLiteral(*result_1); + LiteralTestUtil::ExpectR2Equal({{1.0, 2.0}, {3.0, 4.0}}, + result_1_literal->tuple_literals(0)); + LiteralTestUtil::ExpectR2Equal({{44.0, 12.0}, {16.0, 20}}, + result_1_literal->tuple_literals(1)); +} + +XLA_TEST_F(LocalClientExecuteTest, LargeTuple) { + // Construct a computation which takes a tuple parameter with a very large + // number of elements. + + // A larger number of elements would make for a better, more strenuous test, + // but: + // TODO(b/66959878): On cpu a large number of elements results in long + // compilation time. + // TODO(b/66954197): On gpu a large number of elements OOMs. + const int kElementCount = 100; + + // Each element is a 2-element vector. + const Shape element_shape = ShapeUtil::MakeShape(F32, {2}); + std::vector element_shapes(kElementCount, element_shape); + const Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes); + + ComputationBuilder builder(local_client_, TestName()); + auto param = builder.Parameter(0, tuple_shape, "param"); + + // Add each element's tuple index value to every element. + std::vector result_elements; + for (int i = 0; i < kElementCount; ++i) { + auto element = builder.GetTupleElement(param, i); + result_elements.push_back( + builder.Add(element, builder.ConstantR0(i))); + } + builder.Tuple(result_elements); + auto computation = builder.Build().ConsumeValueOrDie(); + + // Feed in a tuple where each two-element vector element is {tuple_index, + // -tuple_index}. + std::vector> arg_elements; + for (int i = 0; i < kElementCount; ++i) { + arg_elements.push_back(Literal::CreateR1({1.0f * i, -1.0f * i})); + } + std::unique_ptr arg_literal = + Literal::MakeTupleOwned(std::move(arg_elements)); + auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + + std::unique_ptr result = + ExecuteLocallyOrDie(computation, {arg_buffer.get()}); + + std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + + for (int i = 0; i < kElementCount; ++i) { + LiteralTestUtil::ExpectR1Near( + {2.0f * i, 0.0f}, result_literal->tuple_literals(i), error_spec_); + } +} + +// TODO(b/66968986): Test times out on CPU parallel backend. Disabled +// 2017-09-26. +XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_CPU_PARALLEL(LargeNestedTuple)) { + // Construct and run a computation which takes a two-level nested tuple + // parameter with a large fanout. + const int kFanout = 40; + + // Tuple shape is full two-level tree with the given fanout. + const Shape element_shape = ShapeUtil::MakeShape(F32, {}); + std::vector element_shapes(kFanout, element_shape); + const Shape inner_tuple_shape = ShapeUtil::MakeTupleShape(element_shapes); + std::vector inner_tuple_shapes(kFanout, inner_tuple_shape); + const Shape tuple_shape = ShapeUtil::MakeTupleShape(inner_tuple_shapes); + + ComputationBuilder builder(local_client_, TestName()); + auto param = builder.Parameter(0, tuple_shape, "param"); + + // The computation increments each leaf value by an amount equal to the leaf's + // ordinal position in a traversal of the tuple. + std::vector result_elements; + for (int i = 0; i < kFanout; ++i) { + auto outer_element = builder.GetTupleElement(param, i); + std::vector inner_result_elements; + for (int j = 0; j < kFanout; ++j) { + auto inner_element = builder.GetTupleElement(outer_element, j); + inner_result_elements.push_back(builder.Add( + inner_element, builder.ConstantR0(i * kFanout + j))); + } + result_elements.push_back(builder.Tuple(inner_result_elements)); + } + builder.Tuple(result_elements); + auto computation = builder.Build().ConsumeValueOrDie(); + + // Construct the argument to pass to the computation. + std::vector> outer_tuple_elements; + for (int i = 0; i < kFanout; ++i) { + std::vector> inner_tuple_elements; + for (int j = 0; j < kFanout; ++j) { + inner_tuple_elements.push_back(Literal::CreateR0(i + j)); + } + outer_tuple_elements.push_back( + Literal::MakeTupleOwned(std::move(inner_tuple_elements))); + } + auto arg_literal = Literal::MakeTupleOwned(std::move(outer_tuple_elements)); + auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + + std::unique_ptr result = + ExecuteLocallyOrDie(computation, {arg_buffer.get()}); + std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + + for (int i = 0; i < kFanout; ++i) { + for (int j = 0; j < kFanout; ++j) { + LiteralTestUtil::ExpectR0Near( + i + j + i * kFanout + j, + result_literal->tuple_literals(i).tuple_literals(j), error_spec_); + } + } +} + +XLA_TEST_F(LocalClientExecuteTest, DeepTuple) { + // Construct and run a computation which takes a very deep tuple. The tuple + // has no fan out and a single scalar element at the bottom. + const int kTupleDepth = 100; + + // Tuple shape is full two-level tree with the given fanout. + Shape shape = ShapeUtil::MakeShape(F32, {}); + for (int i = 0; i < kTupleDepth; ++i) { + shape = ShapeUtil::MakeTupleShape({shape}); + } + + ComputationBuilder builder(local_client_, TestName()); + auto element = builder.Parameter(0, shape, "param"); + for (int i = 0; i < kTupleDepth; ++i) { + element = builder.GetTupleElement(element, 0); + } + + auto output = builder.Add(element, builder.ConstantR0(42.0)); + for (int i = 0; i < kTupleDepth; ++i) { + output = builder.Tuple({output}); + } + auto computation = builder.Build().ConsumeValueOrDie(); + + // Construct the argument to pass to the computation. + std::unique_ptr arg_literal = Literal::CreateR0(123.0); + for (int i = 0; i < kTupleDepth; ++i) { + std::vector> arg_vector; + arg_vector.push_back(std::move(arg_literal)); + arg_literal = Literal::MakeTupleOwned(std::move(arg_vector)); + } + auto arg_buffer = LiteralToShapedBuffer(*arg_literal); + + std::unique_ptr result = + ExecuteLocallyOrDie(computation, {arg_buffer.get()}); + std::unique_ptr result_literal = ShapedBufferToLiteral(*result); + + const Literal* result_element = result_literal.get(); + for (int i = 0; i < kTupleDepth; ++i) { + result_element = &result_element->tuple_literals(0); + } + LiteralTestUtil::ExpectR0Equal(165.0, *result_element); +} + XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { // Test passing in an invalid number of arguments. ComputationBuilder builder(local_client_, TestName()); @@ -292,8 +570,8 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) { auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {3}), "y"); builder.Add(x, y); - auto x_array = LiteralToScopedShapedBuffer( - *Literal::CreateR1({1.0f, 2.0f, 3.0f})); + auto x_array = + LiteralToShapedBuffer(*Literal::CreateR1({1.0f, 2.0f, 3.0f})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {x_array.get()}); @@ -308,7 +586,7 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) { auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x"); builder.Neg(x); - auto x_array = LiteralToScopedShapedBuffer( + auto x_array = LiteralToShapedBuffer( *Literal::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally(builder.Build().ValueOrDie(), {x_array.get()}); @@ -325,7 +603,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) { auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x"); builder.Neg(x); - auto x_array = LiteralToScopedShapedBuffer( + auto x_array = LiteralToShapedBuffer( *Literal::CreateR2({{0.0f, 1.0f}, {2.0f, 3.0f}})); auto execute_status = ExecuteLocally( builder.Build().ValueOrDie(), {x_array.get()}, @@ -508,12 +786,11 @@ XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) { std::unique_ptr executable = executable_status.ConsumeValueOrDie(); - auto x_array = LiteralToScopedShapedBuffer( - *Literal::CreateR1({0.0f, 1.0f, 2.0f})); - std::unique_ptr result = ShapedBufferToScopedShapedBuffer( + auto x_array = + LiteralToShapedBuffer(*Literal::CreateR1({0.0f, 1.0f, 2.0f})); + std::unique_ptr result = executable->Run({x_array.get()}, DefaultExecutableRunOptions()) - .ConsumeValueOrDie(), - allocator_); + .ConsumeValueOrDie(); LiteralTestUtil::ExpectR1Near( {2.0f, 4.0f, 6.0f}, *ShapedBufferToLiteral(*result), error_spec_); @@ -526,7 +803,7 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { TF_ASSERT_OK_AND_ASSIGN( auto shaped_buffer, local_client_->LiteralToShapedBuffer( - literal, allocator_, local_client_->default_device_ordinal())); + literal, local_client_->default_device_ordinal(), allocator_)); TF_ASSERT_OK_AND_ASSIGN( auto transferred_literal, local_client_->ShapedBufferToLiteral(*shaped_buffer)); @@ -538,7 +815,7 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { test_to_device_and_back(*Literal::CreateR0(true)); test_to_device_and_back(*Literal::CreateR1({1.0, 42.0, 744.4})); test_to_device_and_back( - *Literal::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + *Literal::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); test_to_device_and_back(*Literal::CreateR2({{2, 1}, {4444, 56}})); // Null shape (empty tuple). @@ -559,6 +836,55 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion) { Literal::CreateR0(false).get()})); } +XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) { + // Test copying Literals to the device as ShapedBuffers, then copying them + // back again to Literals for 64-bit values. + auto test_to_device_and_back = [this](const Literal& literal) { + TF_ASSERT_OK_AND_ASSIGN( + auto shaped_buffer, + local_client_->LiteralToShapedBuffer( + literal, local_client_->default_device_ordinal(), allocator_)); + TF_ASSERT_OK_AND_ASSIGN( + auto transferred_literal, + local_client_->ShapedBufferToLiteral(*shaped_buffer)); + EXPECT_EQ(literal, *transferred_literal); + }; + + test_to_device_and_back( + *Literal::CreateR2({{1.0, 2.0, 3.0}, {44.0, 0.1, -3}})); + test_to_device_and_back(*Literal::CreateR2({{2, 1}, {4444, 56}})); + test_to_device_and_back( + *Literal::CreateR2({{20000000000ULL, 1}, {4444, 56}})); + test_to_device_and_back( + *Literal::MakeTuple({Literal::CreateR1({1.0, -42.0}).get(), + Literal::CreateR0(123456789000LL).get()})); +} + +// TODO(b/34359662): Support infeed/outfeed on GPU and CPU parallel. +// 2017-10-18. +XLA_TEST_F(LocalClientExecuteTest, + DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(InfeedOutfeedTest))) { + ComputationBuilder builder(local_client_, TestName()); + const Shape shape = ShapeUtil::MakeShape(F32, {3}); + auto in = builder.Infeed(shape); + auto constant = builder.ConstantR1({1.0f, 2.0f, 3.0f}); + auto sum = builder.Add(in, constant); + builder.Outfeed(sum, shape, /*outfeed_config=*/""); + + std::unique_ptr thread( + tensorflow::Env::Default()->StartThread( + tensorflow::ThreadOptions(), "execute_thread", + [&] { ExecuteLocallyOrDie(builder.Build().ValueOrDie(), {}); })); + + ASSERT_IS_OK(local_client_->TransferToInfeed( + *Literal::CreateR1({-5.0, 123.0, 42.0}))); + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr result, + local_client_->TransferFromOutfeed(&shape)); + + LiteralTestUtil::ExpectR1Equal({-4.0, 125.0, 45.0}, *result); +} + // Benchmark that measures the overhead of the LocalClient API when running a // trivial computation void BM_LocalClientOverhead(int num_iters) { @@ -580,8 +906,9 @@ void BM_LocalClientOverhead(int num_iters) { builder.Add(x, x); auto computation = builder.Build().ConsumeValueOrDie(); - auto buffer = ScopedShapedBuffer::MakeScopedShapedBuffer(shape, &allocator, 0) - .ConsumeValueOrDie(); + auto buffer = + ScopedShapedBuffer::Allocate(shape, &allocator, /*device_ordinal=*/0) + .ConsumeValueOrDie(); auto literal = Literal::CreateR2({{0, 0, 0}, {0, 0, 0}}); ASSERT_IS_OK(transfer_manager->TransferLiteralToDevice( executors[device_ordinal], *literal, buffer->mutable_buffer({}))); diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc index 49207356e3027cff52a29f962fedbd3593a4925e..c11e1df0a7890a6c3aada5ff47494b42fdaf3b9d 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.cc +++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc @@ -90,6 +90,9 @@ int64 TestAllocator::deallocation_count(int device_ordinal) const { /* static */ TestAllocator* LocalClientTestBase::GetOrCreateAllocator( perftools::gputools::Platform* platform) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + tensorflow::mutex_lock lock(mu); + if (allocator_ == nullptr) { allocator_ = new TestAllocator( platform == nullptr ? PlatformUtil::GetDefaultPlatform().ValueOrDie() @@ -126,27 +129,11 @@ LocalClientTestBase::LocalClientTestBase( LocalClientTestBase::~LocalClientTestBase() {} -std::unique_ptr -LocalClientTestBase::LiteralToScopedShapedBuffer(const Literal& literal) { - return LiteralToScopedShapedBuffer(literal, - local_client_->default_device_ordinal()); -} - -std::unique_ptr -LocalClientTestBase::LiteralToScopedShapedBuffer(const Literal& literal, - int device_ordinal) { - CHECK(!ShapeUtil::IsTuple(literal.shape())); - auto scoped_buffer = - ScopedShapedBuffer::MakeScopedShapedBuffer( - literal.shape(), GetOrCreateAllocator(local_client_->platform()), - device_ordinal) - .ConsumeValueOrDie(); - // The creation of the scoped shaped buffer should allocate the buffer. - CHECK(!scoped_buffer->buffer(/*index=*/{}).is_null() || - ShapeUtil::HasZeroElements(literal.shape())); - TF_CHECK_OK(transfer_manager_->TransferLiteralToDevice( - stream_executor_, literal, scoped_buffer->mutable_buffer(/*index=*/{}))); - return scoped_buffer; +std::unique_ptr LocalClientTestBase::LiteralToShapedBuffer( + const Literal& literal) { + return local_client_ + ->LiteralToShapedBuffer(literal, local_client_->default_device_ordinal()) + .ConsumeValueOrDie(); } void LocalClientTestBase::CopyShapedBufferToLiteral( @@ -174,33 +161,6 @@ std::unique_ptr LocalClientTestBase::ShapedBufferToLiteral( return literal; } -std::unique_ptr -LocalClientTestBase::ShapedBufferToScopedShapedBuffer( - std::unique_ptr shaped_buffer, - DeviceMemoryAllocator* allocator) { - std::unique_ptr scoped_buffer = - ScopedShapedBuffer::MakeScopedShapedBuffer( - shaped_buffer->shape(), allocator, shaped_buffer->device_ordinal()) - .ConsumeValueOrDie(); - // Deallocate the existing DeviceMemoryBase values in the newly created scoped - // buffer and replace them with the values from the shaped buffer. - for (perftools::gputools::DeviceMemoryBase& memory_base : - *scoped_buffer->mutable_buffers()) { - TF_CHECK_OK( - allocator->Deallocate(shaped_buffer->device_ordinal(), &memory_base)); - } - *scoped_buffer->mutable_buffers() = shaped_buffer->buffers(); - - scoped_buffer->mutable_shape_index_to_buffer_entry()->ForEachMutableElement( - [&shaped_buffer](const ShapeIndex& index, size_t* buffer_entry) { - if (ShapeUtil::IsLeafIndex(shaped_buffer->shape(), index)) { - *buffer_entry = - shaped_buffer->shape_index_to_buffer_entry().element(index); - } - }); - return scoped_buffer; -} - ExecutableBuildOptions LocalClientTestBase::DefaultExecutableBuildOptions() const { return ExecutableBuildOptions(); @@ -253,10 +213,7 @@ LocalClientTestBase::ExecuteLocally( TF_ASSIGN_OR_RETURN( std::unique_ptr executable, local_client_->Compile(computation, argument_layouts, build_options)); - TF_ASSIGN_OR_RETURN(std::unique_ptr buffer, - executable->Run(arguments, run_options)); - return ShapedBufferToScopedShapedBuffer(std::move(buffer), - run_options.allocator()); + return executable->Run(arguments, run_options); } } // namespace xla diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h index e3c3bb46cf26cc742b7abb39a3e457d823d829ec..3edfcb656ed8278d403103f0cfd820a10892476a 100644 --- a/tensorflow/compiler/xla/tests/local_client_test_base.h +++ b/tensorflow/compiler/xla/tests/local_client_test_base.h @@ -83,12 +83,10 @@ class LocalClientTestBase : public ::testing::Test { perftools::gputools::Platform* platform); // Copy the given literal onto the default device and return a - // ScopedShapedBuffer. - std::unique_ptr LiteralToScopedShapedBuffer( + // ScopedShapedBuffer. Convenience wrapper around + // LocalClient::LiteralToShapedBuffer. + std::unique_ptr LiteralToShapedBuffer( const Literal& literal); - // As above, but copy to a specific device. - std::unique_ptr LiteralToScopedShapedBuffer( - const Literal& literal, int device_ordinal); // Construct and return a literal containing the array represented by // shaped_buffer. @@ -126,18 +124,12 @@ class LocalClientTestBase : public ::testing::Test { // as the allocator. ExecutableRunOptions DefaultExecutableRunOptions() const; - // Convert a ShapedBuffer into a ScopedShaped buffer so that all buffers are - // deallocated when the object is destructed. - std::unique_ptr ShapedBufferToScopedShapedBuffer( - std::unique_ptr shaped_buffer, - DeviceMemoryAllocator* allocator); - string TestName() const { return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } - // The allocator must live as long as the service which lives until the end of - // the process, so make the allocator static. + // The allocator must live as long as the service, which lives until the end + // of the process. So make the allocator static. static TestAllocator* allocator_; perftools::gputools::StreamExecutor* stream_executor_; diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc index 01ee421baac3b17da962d9ddc7b15b8e6039200a..2ef392508d14cf6dc14b2c979f07a79bc60d7426 100644 --- a/tensorflow/compiler/xla/tests/map_test.cc +++ b/tensorflow/compiler/xla/tests/map_test.cc @@ -125,7 +125,7 @@ class MapTest : public ClientLibraryTestBase { Computation CreateMapPlusN(const Computation& embedded_computation, float n) { ComputationBuilder builder(client_, TestName()); auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - auto map = builder.Map({x}, embedded_computation); + auto map = builder.Map({x}, embedded_computation, {}); auto constant_n = builder.ConstantR0(n); auto add = builder.Add(map, constant_n); auto computation_status = builder.Build(); @@ -173,7 +173,7 @@ TEST_F(MapTest, MapEachElemPlusOneR0) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOne()); + auto map = builder.Map({param}, CreateAdderToOne(), {}); ComputeAndCompareR0(&builder, 43.0, {param0_data.get()}, ErrorSpec(0.01f)); @@ -187,7 +187,7 @@ XLA_TEST_F(MapTest, MapEachElemPlusOneR1S0) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOne()); + auto map = builder.Map({param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -202,7 +202,7 @@ TEST_F(MapTest, MapEachElemPlusOneR1S4) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOne()); + auto map = builder.Map({param}, CreateAdderToOne(), {0}); ComputeAndCompareR1(&builder, {3.2f, 4.3f, 5.4f, 6.5f}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -216,7 +216,7 @@ TEST_F(MapTest, MapEachF32ElementToS32Constant) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateScalarOne()); + auto map = builder.Map({param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); } @@ -229,7 +229,7 @@ TEST_F(MapTest, MapEachF32ElementToU32Constant) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateScalarOne()); + auto map = builder.Map({param}, CreateScalarOne(), {0}); ComputeAndCompareR1(&builder, {1, 1, 1, 1}, {param0_data.get()}); } @@ -243,7 +243,7 @@ TEST_F(MapTest, MapEachElemLongerChainR1) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOneTimesItself()); + auto map = builder.Map({param}, CreateAdderToOneTimesItself(), {0}); ComputeAndCompareR1( &builder, {9.36f, 20.91f, 0.11f, 0.24f, 999000.0f, 65535.75f}, @@ -259,8 +259,8 @@ XLA_TEST_F(MapTest, MapMultipleMapsR1S0) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map1 = builder.Map({param}, CreateAdderToOne()); - auto map2 = builder.Map({map1}, CreateMulByTwo()); + auto map1 = builder.Map({param}, CreateAdderToOne(), {0}); + auto map2 = builder.Map({map1}, CreateMulByTwo(), {0}); ComputeAndCompareR1(&builder, {}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -276,8 +276,8 @@ TEST_F(MapTest, MapMultipleMapsR1S4) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map1 = builder.Map({param}, CreateAdderToOne()); - auto map2 = builder.Map({map1}, CreateMulByTwo()); + auto map1 = builder.Map({param}, CreateAdderToOne(), {0}); + auto map2 = builder.Map({map1}, CreateMulByTwo(), {0}); ComputeAndCompareR1(&builder, {6.4f, 8.6f, 10.8f, 13.0f}, {param0_data.get()}, ErrorSpec(0.01f)); @@ -292,7 +292,7 @@ TEST_F(MapTest, MapEachElemPlusOneR2) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param = builder.Parameter(0, param0_literal->shape(), "param0"); - auto map = builder.Map({param}, CreateAdderToOne()); + auto map = builder.Map({param}, CreateAdderToOne(), {0, 1}); Array2D expected_array( {{14.25f, 15.0f}, {-6.1f, -6.2f}, {-7.8f, 9.8f}}); @@ -319,8 +319,8 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { ComputationBuilder embed4_builder(client_, "embed4"); auto embed4_param = embed4_builder.Parameter(0, scalar_shape, "x"); - auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2); - auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3); + auto embed4_map_lhs = embed4_builder.Map({embed4_param}, embed2, {}); + auto embed4_map_rhs = embed4_builder.Map({embed4_param}, embed3, {}); auto embed4_add = embed4_builder.Add(embed4_map_lhs, embed4_map_rhs); auto embed4_status = embed4_builder.Build(); ASSERT_IS_OK(embed4_status.status()); @@ -331,8 +331,8 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) { ComputationBuilder builder(client_, TestName()); auto constant_42 = builder.ConstantR0(42.0); auto constant_7 = builder.ConstantR0(7.0); - auto map_42 = builder.Map({constant_42}, embed5); - auto map_7 = builder.Map({constant_7}, embed4); + auto map_42 = builder.Map({constant_42}, embed5, {}); + auto map_7 = builder.Map({constant_7}, embed4, {}); builder.Add(map_42, map_7); ComputeAndCompareR0(&builder, 73.0, {}, ErrorSpec(0.01f)); @@ -355,7 +355,7 @@ TEST_F(MapTest, VersionedEmbeddedComputation) { ComputationBuilder builder(client_, TestName()); auto constant_vector = builder.ConstantR1({1.0, 2.0, 3.0, 4.0}); - auto map_plus_1 = builder.Map({constant_vector}, embedded_computation); + auto map_plus_1 = builder.Map({constant_vector}, embedded_computation, {0}); // Add another Add(1) operation to the existing embedded computation. This // requires using the stub interface because the ComputationBuilder does not @@ -371,7 +371,7 @@ TEST_F(MapTest, VersionedEmbeddedComputation) { tensorflow::Status s = client_->stub()->Op(&op_request, &response); ASSERT_TRUE(s.ok()); - auto map_plus_2 = builder.Map({map_plus_1}, embedded_computation); + auto map_plus_2 = builder.Map({map_plus_1}, embedded_computation, {0}); // The original vector has Add(1) applied to it with a map, followed by // Add(1+1) resulting in a net Add(3). @@ -393,8 +393,8 @@ TEST_F(MapTest, MapBinaryAdder) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto map = - builder.Map({param0, param1}, CreateScalarAddComputation(F32, &builder)); + auto map = builder.Map({param0, param1}, + CreateScalarAddComputation(F32, &builder), {0}); ComputeAndCompareR1(&builder, {7.3f, 7.7, 4.3f, 0}, {param0_data.get(), param1_data.get()}, @@ -417,8 +417,8 @@ XLA_TEST_F(MapTest, AddWithMixedLayouts) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto map = - builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder)); + auto map = builder.Map({param0, param1}, + CreateScalarAddComputation(S32, &builder), {0, 1}); Array2D expected(2, 2); expected(0, 0) = 11; @@ -443,8 +443,8 @@ XLA_TEST_F(MapTest, AddR3_3x0x2) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto map = - builder.Map({param0, param1}, CreateScalarAddComputation(S32, &builder)); + auto map = builder.Map({param0, param1}, + CreateScalarAddComputation(S32, &builder), {0, 1, 2}); ComputeAndCompareR3(&builder, Array3D(3, 0, 2), {param0_data.get(), param1_data.get()}); @@ -469,7 +469,7 @@ TEST_F(MapTest, MapTernaryAdder) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); auto param2 = builder.Parameter(2, param2_literal->shape(), "param2"); - auto map = builder.Map({param0, param1, param2}, CreateTernaryAdder()); + auto map = builder.Map({param0, param1, param2}, CreateTernaryAdder(), {0}); ComputeAndCompareR1( &builder, {-2.7f, -92.3f, -895.7f, -400.0f}, @@ -481,7 +481,7 @@ TEST_F(MapTest, MapGt) { // Maps (x,y) -> x > y onto two R1F32 vectors. ComputationBuilder b(client_, TestName()); auto gt = CreateGt(); - b.Map({b.ConstantR1({1, 20}), b.ConstantR1({10, 2})}, gt); + b.Map({b.ConstantR1({1, 20}), b.ConstantR1({10, 2})}, gt, {0}); ComputeAndCompareR1(&b, {false, true}, {}); } @@ -491,14 +491,14 @@ TEST_F(MapTest, NestedBinaryMap) { // max_with_square(x) = do max(x, x^2) via a map. ComputationBuilder b(client_, "max_with_square"); auto x = b.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x"); - b.Map({x, b.Mul(x, x)}, CreateMax()); + b.Map({x, b.Mul(x, x)}, CreateMax(), {}); auto computation_status = b.Build(); ASSERT_IS_OK(computation_status.status()); max_with_square = computation_status.ConsumeValueOrDie(); } ComputationBuilder b(client_, TestName()); auto input = b.ConstantR1({0.1f, 0.5f, -0.5f, 1.0f, 2.0f}); - b.Map({input}, max_with_square); + b.Map({input}, max_with_square, {0}); ComputeAndCompareR1(&b, {0.1f, 0.5f, 0.25f, 1.0f, 4.0f}, {}); } @@ -525,7 +525,7 @@ TEST_F(MapTest, MapOperantionWithBuildError) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - auto map = builder.Map({param0, param1}, error_add); + auto map = builder.Map({param0, param1}, error_add, {0}); StatusOr computation_status = builder.Build(); ASSERT_TRUE(!computation_status.ok()); @@ -562,7 +562,7 @@ TEST_F(MapTestWithFullOpt, MapScalarPower) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, power); + builder.Map({param0, param1}, power, {}); ComputeAndCompareR0(&builder, 32.0f, {param0_data.get(), param1_data.get()}, @@ -589,7 +589,7 @@ TEST_F(MapTestWithFullOpt, MapSubtractOppositeOrder) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto param1 = builder.Parameter(1, param1_literal->shape(), "param1"); - builder.Map({param0, param1}, sub_opposite); + builder.Map({param0, param1}, sub_opposite, {}); ComputeAndCompareR0( &builder, 3.0f, {param0_data.get(), param1_data.get()}, ErrorSpec(0.01f)); @@ -610,7 +610,7 @@ TEST_F(MapTestWithFullOpt, MapSquare) { client_->TransferToServer(*param0_literal).ConsumeValueOrDie(); auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); - builder.Map({param0}, square); + builder.Map({param0}, square, {}); ComputeAndCompareR0(&builder, 100.0f, {param0_data.get()}, ErrorSpec(0.01f)); diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc index 4c33bb2c3661f185c93f798cd4e989f0b39178c1..0fb87c3c2ccbad387d46016cfad4e7d3cc537dcc 100644 --- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc @@ -111,7 +111,7 @@ TEST_F(MatOpsSimpleTest, MapTwoByTwo) { {1.0, 0.0}, // row 0 {-1.0, 0.5}, // row 1 }); - auto map = builder.Map({data}, add_half); + auto map = builder.Map({data}, add_half, {0, 1}); std::unique_ptr expected = Literal::CreateR2({{1.5, 0.5}, // row 0 diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc index 3500e8dc28570fe216f53b746c3757e080aa689f..10e44b274a8a9f3ac28dc40d7b1938d24a9ee40c 100644 --- a/tensorflow/compiler/xla/tests/pred_test.cc +++ b/tensorflow/compiler/xla/tests/pred_test.cc @@ -90,7 +90,7 @@ TEST_F(PredTest, ConstantR2Pred) { builder.ConstantR2({{false, true, true}, {true, false, false}}); const string expected = R"(pred[2,3] { { 011 }, - { 100 }, + { 100 } })"; EXPECT_EQ(expected, ExecuteToString(&builder, {})); } @@ -119,7 +119,9 @@ TEST_F(PredTest, AnyR1VacuouslyFalse) { TEST_F(PredTest, AnyR2True) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2({ - {false, false, false}, {false, false, false}, {false, false, true}, + {false, false, false}, + {false, false, false}, + {false, false, true}, }); TF_ASSERT_OK(Any(a, &builder).status()); ComputeAndCompareR0(&builder, true, {}); @@ -128,7 +130,9 @@ TEST_F(PredTest, AnyR2True) { TEST_F(PredTest, AnyR2False) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2({ - {false, false, false}, {false, false, false}, {false, false, false}, + {false, false, false}, + {false, false, false}, + {false, false, false}, }); TF_ASSERT_OK(Any(a, &builder).status()); ComputeAndCompareR0(&builder, false, {}); diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 0f82291fea6559381b60a610222a869c999f64cf..209f063cc5a34648453d12deae79f261b95dc3b4 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -170,7 +170,7 @@ XLA_TEST_F(PrngTest, MapUsingRng) { auto param0 = builder.Parameter(0, param0_literal->shape(), "param0"); auto fn = build_sum_rng(builder); - builder.Map({param0}, fn); + builder.Map({param0}, fn, {0}); TF_ASSERT_OK_AND_ASSIGN(auto computation, builder.Build()); diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc index 2271f32c5946f3d3e7e6b43b089e68ab3101b61b..7bc3185c367f076c9a7d211c9799557e1a91d92f 100644 --- a/tensorflow/compiler/xla/tests/reduce_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_test.cc @@ -120,10 +120,10 @@ class ReduceTest : public ClientLibraryTestBase { Computation reduce; if (and_reduce) { init_value = builder.ConstantR0(true); - reduce = CreateScalarLogicalAndComputation(&builder); + reduce = CreateScalarAndComputation(&builder); } else { init_value = builder.ConstantR0(false); - reduce = CreateScalarLogicalOrComputation(&builder); + reduce = CreateScalarOrComputation(&builder); } builder.Reduce(pred_values, init_value, reduce, /*dimensions_to_reduce=*/{0}); @@ -457,7 +457,7 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, 2, cols / 2}); auto input = builder.Parameter(0, input_shape, "input"); auto zero = builder.ConstantR0(0.0); - auto log_ = builder.Log(input); + auto log_ = builder.Tanh(input); auto reshape = builder.Reshape(log_, {rows, cols}); builder.Reduce(reshape, zero, add_f32, /*dimensions_to_reduce=*/{0}); @@ -473,7 +473,7 @@ XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) { for (int64 colno = 0; colno < cols / 2; ++colno) { float column_sum = 0; for (int64 rowno = 0; rowno < rows; ++rowno) { - column_sum += log(input_data(rowno, major, colno)); + column_sum += tanh(input_data(rowno, major, colno)); } expected.push_back(column_sum); } @@ -502,8 +502,8 @@ XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) { ComputationBuilder builder(client_, TestName()); auto add = CreateScalarAddComputation(F32, &builder); auto scalar = builder.ConstantR0(42.0); - auto broacasted = builder.Broadcast(scalar, {500, 500}); - builder.Reduce(broacasted, builder.ConstantR0(0.0f), add, {0, 1}); + auto broadcasted = builder.Broadcast(scalar, {500, 500}); + builder.Reduce(broadcasted, builder.ConstantR0(0.0f), add, {0, 1}); float expected = 42.0f * static_cast(500 * 500); ComputeAndCompareR0(&builder, expected, {}, ErrorSpec(0.0001)); @@ -514,8 +514,8 @@ XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) { ComputationBuilder builder(client_, TestName()); auto max = CreateScalarMaxComputation(F32, &builder); auto scalar = builder.ConstantR0(42.0); - auto broacasted = builder.Broadcast(scalar, {500, 500}); - builder.Reduce(broacasted, builder.ConstantR0(0.0f), max, {0, 1}); + auto broadcasted = builder.Broadcast(scalar, {500, 500}); + builder.Reduce(broadcasted, builder.ConstantR0(0.0f), max, {0, 1}); float expected = 42.0f; ComputeAndCompareR0(&builder, expected, {}, ErrorSpec(0.0001)); @@ -729,16 +729,14 @@ XLA_TEST_F(ReduceTest, VectorizedReduce_Min) { std::numeric_limits::max()); } -XLA_TEST_F(ReduceTest, VectorizedReduce_LogicalAnd) { - RunVectorizedReduceTestForType(CreateScalarLogicalAndComputation, - [](bool a, bool b) { return a && b; }, - true); +XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanAnd) { + RunVectorizedReduceTestForType( + CreateScalarAndComputation, [](bool a, bool b) { return a && b; }, true); } -XLA_TEST_F(ReduceTest, VectorizedReduce_LogicalOr) { - RunVectorizedReduceTestForType(CreateScalarLogicalOrComputation, - [](bool a, bool b) { return a || b; }, - false); +XLA_TEST_F(ReduceTest, VectorizedReduce_BooleanOr) { + RunVectorizedReduceTestForType( + CreateScalarOrComputation, [](bool a, bool b) { return a || b; }, false); } class ReduceR3ToR2Test : public ReduceTest, diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc index 7b7f2687286916aff9c47a7c165619bbe84368e8..6c9b62b48d8bb2ad93b2ce98839e5e52d8eaa8cc 100644 --- a/tensorflow/compiler/xla/tests/reduce_window_test.cc +++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc @@ -76,6 +76,20 @@ class ReduceWindowTest : public ClientLibraryTestBase { ComputationBuilder builder_; }; +TEST_F(ReduceWindowTest, MismatchedRanksGivesErrorStatus) { + const auto input = builder_.ConstantR1({1, 1, 1, 1}); + const auto init_value = builder_.ConstantR0(0); + TF_ASSERT_OK(builder_.first_error()); + builder_.ReduceWindow(input, init_value, + CreateScalarAddComputation(F32, &builder_), + /*window_dimensions=*/{1, 2}, + /*window_strides=*/{1}, Padding::kValid); + ASSERT_EQ(builder_.first_error().code(), tensorflow::error::INVALID_ARGUMENT) + << builder_.first_error(); + ASSERT_THAT(builder_.first_error().error_message(), + ::testing::HasSubstr("Want input dimensions size")); +} + TEST_F(ReduceWindowTest, Min3In5Stride2) { const auto input = builder_.ConstantR1({10000, 1000, 100, 10, 1}); ReduceWindowMin(input, {3}, {2}, Padding::kValid); diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc index 92efd2947d6384d4ffaf6dc0134ddaf313ddedf7..6d063ffc363c092a1fbc40cbc22e87181d0c2502 100644 --- a/tensorflow/compiler/xla/tests/replay_test.cc +++ b/tensorflow/compiler/xla/tests/replay_test.cc @@ -117,7 +117,7 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) { ComputationBuilder mapper_builder(client_, TestName()); auto original = mapper_builder.ConstantR1({1, 2, 3}); - mapper_builder.Map({original}, plus_two); + mapper_builder.Map({original}, plus_two, {0}); Computation computation = mapper_builder.Build().ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tests/reshape_test.cc b/tensorflow/compiler/xla/tests/reshape_test.cc index bb7160e3a03053a4f3d8da712c1424e50f37dfeb..72c68f24a0a954deb0564e9a0e924edfaf5b5484 100644 --- a/tensorflow/compiler/xla/tests/reshape_test.cc +++ b/tensorflow/compiler/xla/tests/reshape_test.cc @@ -47,7 +47,7 @@ class ReshapeTest : public ClientLibraryTestBase { }; // Collapses 2-dimensional pseudo-scalar (single-element array) to 1 dimension. -XLA_TEST_F(ReshapeTest, Trivial1x1) { +XLA_TEST_F(ReshapeTest, CollapseTrivial1x1) { ComputationBuilder builder(client_, TestName()); auto a = builder.ConstantR2({{1.0}}); builder.Collapse(/*operand=*/a, /*dimensions=*/{0, 1}); @@ -55,6 +55,22 @@ XLA_TEST_F(ReshapeTest, Trivial1x1) { ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); } +XLA_TEST_F(ReshapeTest, CollapseTrivialR1EmptyDims) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({1.0}); + builder.Collapse(/*operand=*/a, /*dimensions=*/{}); + + ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); +} + +XLA_TEST_F(ReshapeTest, CollapseTrivialR1OnlyDim) { + ComputationBuilder builder(client_, TestName()); + auto a = builder.ConstantR1({1.0}); + builder.Collapse(/*operand=*/a, /*dimensions=*/{0}); + + ComputeAndCompareR1(&builder, {1.0f}, {}, zero_error_spec_); +} + // Collapses 2-dimensional pseudo-scalar (single-element array) to scalar. XLA_TEST_F(ReshapeTest, SingleElementArrayToScalar) { ComputationBuilder builder(client_, TestName()); diff --git a/tensorflow/compiler/xla/tests/scalar_computations_test.cc b/tensorflow/compiler/xla/tests/scalar_computations_test.cc index 77d1c019f3a23f79237e624dabf8972a6c1d3c72..b5e7570778ffeca66cc15d7cd2b153639637a647 100644 --- a/tensorflow/compiler/xla/tests/scalar_computations_test.cc +++ b/tensorflow/compiler/xla/tests/scalar_computations_test.cc @@ -459,39 +459,99 @@ XLA_TEST_F(ScalarComputationsTest, RemTwoScalarsU32) { ComputeAndCompareR0(&builder, 2, {}); } -XLA_TEST_F(ScalarComputationsTest, LogicalAnd) { +XLA_TEST_F(ScalarComputationsTest, AndBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { ComputationBuilder builder(client_, TestName()); - builder.LogicalAnd(builder.ConstantR0(x), - builder.ConstantR0(y)); + builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x && y, {}); } } } -XLA_TEST_F(ScalarComputationsTest, LogicalOr) { +XLA_TEST_F(ScalarComputationsTest, AndS32) { + for (int32 x : {0, 8}) { + for (int32 y : {1, -16}) { + ComputationBuilder builder(client_, TestName()); + builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); + + ComputeAndCompareR0(&builder, x & y, {}); + } + } +} + +XLA_TEST_F(ScalarComputationsTest, AndU32) { + for (uint32 x : {0, 8}) { + for (uint32 y : {1, 16}) { + ComputationBuilder builder(client_, TestName()); + builder.And(builder.ConstantR0(x), builder.ConstantR0(y)); + + ComputeAndCompareR0(&builder, x & y, {}); + } + } +} + +XLA_TEST_F(ScalarComputationsTest, OrBool) { for (bool x : {false, true}) { for (bool y : {false, true}) { ComputationBuilder builder(client_, TestName()); - builder.LogicalOr(builder.ConstantR0(x), - builder.ConstantR0(y)); + builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); ComputeAndCompareR0(&builder, x || y, {}); } } } -XLA_TEST_F(ScalarComputationsTest, LogicalNot) { +XLA_TEST_F(ScalarComputationsTest, OrS32) { + for (int32 x : {0, 8}) { + for (int32 y : {1, -16}) { + ComputationBuilder builder(client_, TestName()); + builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); + + ComputeAndCompareR0(&builder, x | y, {}); + } + } +} + +XLA_TEST_F(ScalarComputationsTest, OrU32) { + for (uint32 x : {0, 8}) { + for (uint32 y : {1, 16}) { + ComputationBuilder builder(client_, TestName()); + builder.Or(builder.ConstantR0(x), builder.ConstantR0(y)); + + ComputeAndCompareR0(&builder, x | y, {}); + } + } +} + +XLA_TEST_F(ScalarComputationsTest, NotBool) { for (bool x : {false, true}) { ComputationBuilder builder(client_, TestName()); - builder.LogicalNot(builder.ConstantR0(x)); + builder.Not(builder.ConstantR0(x)); ComputeAndCompareR0(&builder, !x, {}); } } +XLA_TEST_F(ScalarComputationsTest, NotS32) { + for (int32 x : {-1, 0, 1}) { + ComputationBuilder builder(client_, TestName()); + builder.Not(builder.ConstantR0(x)); + + ComputeAndCompareR0(&builder, ~x, {}); + } +} + +XLA_TEST_F(ScalarComputationsTest, NotU32) { + for (uint32 x : {0, 1, 2}) { + ComputationBuilder builder(client_, TestName()); + builder.Not(builder.ConstantR0(x)); + + ComputeAndCompareR0(&builder, ~x, {}); + } +} + XLA_TEST_F(ScalarComputationsTest, SelectScalarTrue) { ComputationBuilder builder(client_, TestName()); builder.Select(builder.ConstantR0(true), // The predicate. diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc index 553377894796130fb8c4dd9db878149243e3d711..4920f17a7ed21d587c15b8deac550d5e5bb566c9 100644 --- a/tensorflow/compiler/xla/tests/tuple_test.cc +++ b/tensorflow/compiler/xla/tests/tuple_test.cc @@ -293,7 +293,7 @@ XLA_TEST_F(TupleTest, TuplesInAMap) { ComputationBuilder b(client_, TestName()); auto input = b.ConstantR1({-1.0f, 1.0f, 2.1f}); - b.Map({input}, tuple_computation); + b.Map({input}, tuple_computation, {0}); ComputeAndCompareR1(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_); } diff --git a/tensorflow/compiler/xla/tests/unary_op_test.cc b/tensorflow/compiler/xla/tests/unary_op_test.cc index efae13a43a058b03a45174c8260bce2ed70cb75c..fa4192e9281784a4a3063601afe89fba6a9dac18 100644 --- a/tensorflow/compiler/xla/tests/unary_op_test.cc +++ b/tensorflow/compiler/xla/tests/unary_op_test.cc @@ -41,7 +41,11 @@ class UnaryOpTest : public ClientLibraryTestBase { auto arg = builder.ConstantR1({}); auto abs = builder.Abs(arg); - ComputeAndCompareR1(&builder, {}, {}); + if (primitive_util::NativeToPrimitiveType() == C64) { + ComputeAndCompareR1(&builder, {}, {}); + } else { + ComputeAndCompareR1(&builder, {}, {}); + } } template @@ -80,14 +84,58 @@ int UnaryOpTest::inf() { return 2147483647; } +template <> +void UnaryOpTest::AbsTestHelper() { + ComputationBuilder builder(client_, TestName()); + auto arg = builder.ConstantR1({{-2, 0}, + {0, 25}, + {0, 0}, + {-0.3f, 0.4f}, + {0, inf()}, + {-inf(), 0}}); + auto abs = builder.Abs(arg); + + std::unique_ptr expected = + Literal::CreateR1({2, 25, 0, 0.5, inf(), inf()}); + ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); +} + +template <> +void UnaryOpTest::SignTestHelper() { + ComputationBuilder builder(client_, TestName()); + auto arg = builder.ConstantR1( + {{-2, 0}, {0, 25}, {0, 0}, {static_cast(-0.0), 0}, {-1, 1}}); + auto sign = builder.Sign(arg); + + std::unique_ptr expected = Literal::CreateR1( + {{-1, 0}, {0, 1}, {0, 0}, {0, 0}, {-std::sqrt(0.5f), std::sqrt(0.5f)}}); + ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); +} + +template <> +void UnaryOpTest::SignAbsTestHelper() { + ComputationBuilder builder(client_, TestName()); + auto arg = + builder.ConstantR1({{-2, 0}, {0, 25}, {0, 0}, {-0.4, 0.3}}); + auto sign = builder.Sign(arg); + auto abs = builder.Abs(arg); + builder.Sub(builder.Mul(sign, builder.ConvertElementType(abs, C64)), arg); + + std::unique_ptr expected = + Literal::CreateR1({0, 0, 0, 0}); + ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); +} + XLA_TEST_F(UnaryOpTest, AbsTestR1Size0) { AbsSize0TestHelper(); AbsSize0TestHelper(); + AbsSize0TestHelper(); } XLA_TEST_F(UnaryOpTest, AbsTestR1) { AbsTestHelper(); AbsTestHelper(); + AbsTestHelper(); } XLA_TEST_F(UnaryOpTest, AbsTestR0) { @@ -98,34 +146,44 @@ XLA_TEST_F(UnaryOpTest, AbsTestR0) { auto absf = builder.Abs(argf); auto argf0 = builder.ConstantR0(-0.0f); auto absf0 = builder.Abs(argf0); - builder.Add(absf0, builder.Add(absf, builder.ConvertElementType( - absi, PrimitiveType::F32))); + auto argc = builder.ConstantR0({-0.3f, 0.4f}); + auto absc = builder.Abs(argc); + builder.Add(builder.Add(absc, absf0), + builder.Add(absf, builder.ConvertElementType(absi, F32))); - ComputeAndCompareR0(&builder, 8.0f, {}); + ComputeAndCompareR0(&builder, 8.5f, {}); } XLA_TEST_F(UnaryOpTest, SignTestR0) { ComputationBuilder builder(client_, TestName()); auto argi = builder.ConstantR0(-5); - auto absi = builder.Sign(argi); + auto sgni = builder.Sign(argi); // -1 auto argf = builder.ConstantR0(-4.0f); - auto absf = builder.Sign(argf); + auto sgnf = builder.Sign(argf); // -1 auto argf0 = builder.ConstantR0(-0.0f); - auto absf0 = builder.Sign(argf0); - builder.Add(absf0, builder.Add(absf, builder.ConvertElementType( - absi, PrimitiveType::F32))); - - ComputeAndCompareR0(&builder, -2.0f, {}); + auto sgnf0 = builder.Sign(argf0); // 0 + auto argc = builder.ConstantR0({-.3, .4}); + auto sgnc = builder.Sign(argc); // (-.6, .8) + builder.Add(sgnc, builder.ConvertElementType( + builder.Add(builder.Add(sgnf0, sgnf), + builder.ConvertElementType(sgni, F32)), + C64)); + + std::unique_ptr expected = + Literal::CreateR0({-2.6f, 0.8f}); + ComputeAndCompareLiteral(&builder, *expected, {}, ErrorSpec(1e-6f)); } XLA_TEST_F(UnaryOpTest, SignTestR1) { SignTestHelper(); SignTestHelper(); + SignTestHelper(); } XLA_TEST_F(UnaryOpTest, SignAbsTestR1) { SignAbsTestHelper(); SignAbsTestHelper(); + SignAbsTestHelper(); } XLA_TEST_F(UnaryOpTest, UnsignedAbsTestR1) { diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc index 48a85f16a22cd7536222b8c03c4ebad2bb77d240..b52c718814d4ffeff68c60588a6637a2159d57e5 100644 --- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc +++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc @@ -195,7 +195,7 @@ XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) { {2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); auto y = builder.ConstantR1( {-0.4, -0.6, -3.0, 0.2, 3.8, -2.2, -1.8, 4.9, 1.4, 0.6}); - auto max = builder.Map({x, y}, add); + auto max = builder.Map({x, y}, add, {0}); std::vector expected = {1.7, -3.2, -0.4, -3.8, 5.9, 0.1, -6.8, 4., -1., 2.2}; @@ -385,8 +385,8 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { auto two = builder.ConstantR0(2.0); auto max = builder.Max(z_value, zero); auto mult = builder.Mul(two, max); - auto inner = builder.Map({mult}, add_half); - builder.Map({inner}, clamp); + auto inner = builder.Map({mult}, add_half, {}); + builder.Map({inner}, clamp, {}); auto computation_status = builder.Build(); ASSERT_IS_OK(computation_status.status()); mult_relu_add = computation_status.ConsumeValueOrDie(); @@ -396,7 +396,7 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) { { auto x = builder.ConstantR1( {2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6}); - auto activations = builder.Map({x}, mult_relu_add); + auto activations = builder.Map({x}, mult_relu_add, {0}); } std::vector expected = {4.7, 0.5, 5.0, 0.5, 4.7, diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc index bb2d90fa94abbf52c340d366ddc55f7bdefb6543..3b29a2eb9e04cc8f5bd55be00bfc6e6ad0b985c2 100644 --- a/tensorflow/compiler/xla/tests/while_test.cc +++ b/tensorflow/compiler/xla/tests/while_test.cc @@ -169,7 +169,7 @@ TEST_F(WhileTest, WhileWithPredicateResult) { { ComputationBuilder builder(client_, "body"); auto prev = builder.Parameter(0, result_shape, "prev"); - auto result = builder.LogicalOr(prev, builder.ConstantR0(true)); + auto result = builder.Or(prev, builder.ConstantR0(true)); body = builder.Build().ConsumeValueOrDie(); } @@ -357,6 +357,111 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) { ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); } +// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result. +TEST_F(WhileTest, DISABLED_WhileWithPermutationAndTupleResult) { + std::vector shape_elements = { + ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), + ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for N iterations. + const int N = 2; + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0(N), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add 1 to the iteration variable and permute the weights. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto w1 = builder.GetTupleElement(prev, 1); + auto w2 = builder.GetTupleElement(prev, 2); + auto w3 = builder.GetTupleElement(prev, 3); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), + builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); + auto result = builder.While(condition, body, init); + VLOG(2) << "result = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + + auto expected_counter = Literal::CreateR0(N); + auto expected_w1 = Literal::CreateR1({1.0f, 1.0f, 1.0f}); + auto expected_w2 = Literal::CreateR1({2.0f, 2.0f, 2.0f}); + auto expected_w3 = Literal::CreateR1({3.0f, 3.0f, 3.0f}); + auto expected = Literal::MakeTuple({expected_counter.get(), expected_w2.get(), + expected_w3.get(), expected_w1.get()}); + VLOG(2) << "expected = " << ShapeUtil::HumanString(expected->shape()); + ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001)); +} + +// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result. +TEST_F(WhileTest, DISABLED_WhileWithPermutationAndVectorResult) { + std::vector shape_elements = { + ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}), + ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})}; + Shape result_shape = ShapeUtil::MakeTupleShape(shape_elements); + + // Create a computation for the condition. + // Repeat for N iterations. + const int N = 2; + Computation condition; + { + ComputationBuilder builder(client_, "condition"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + builder.Gt(builder.ConstantR0(N), iteration); + condition = builder.Build().ConsumeValueOrDie(); + } + + // Create a computation for the body. + // Add 1 to the iteration variable permute the weights. + Computation body; + { + ComputationBuilder builder(client_, "body"); + auto prev = builder.Parameter(0, result_shape, "prev"); + auto iteration = builder.GetTupleElement(prev, 0); + auto w1 = builder.GetTupleElement(prev, 1); + auto w2 = builder.GetTupleElement(prev, 2); + auto w3 = builder.GetTupleElement(prev, 3); + auto result = builder.Tuple( + {builder.Add(iteration, builder.ConstantR0(1)), w3, w1, w2}); + body = builder.Build().ConsumeValueOrDie(); + } + + // Create a While node with computations for the condition and the body. + ComputationBuilder builder(client_, "while"); + auto init = builder.Tuple( + {builder.ConstantR0(0), builder.ConstantR1(3, 1.f), + builder.ConstantR1(3, 2.f), builder.ConstantR1(3, 3.f)}); + auto xla_while = builder.While(condition, body, init); + + auto add12 = builder.Add(builder.GetTupleElement(xla_while, 1), + builder.GetTupleElement(xla_while, 2)); + auto result = builder.Add(add12, builder.GetTupleElement(xla_while, 3)); + VLOG(2) << "result = " + << ShapeUtil::HumanString( + *builder.GetShape(result).ConsumeValueOrDie()); + std::vector expected = {6.f, 6.f, 6.f}; + ComputeAndCompareR1(&builder, expected, {}, ErrorSpec(0.0001)); +} + // Tests a while node when the result type T is a Tuple. // // tuple> result(0, vector(10, 0.0f)); @@ -437,7 +542,7 @@ TEST_F(WhileTest, WhileWithPredicateTupleResult) { auto prev = builder.Parameter(0, result_shape, "prev"); auto iteration = builder.GetTupleElement(prev, 0); auto pred = builder.GetTupleElement(prev, 1); - auto new_pred = builder.LogicalOr(pred, builder.ConstantR0(true)); + auto new_pred = builder.Or(pred, builder.ConstantR0(true)); auto result = builder.Tuple( {builder.Add(iteration, builder.ConstantR0(1)), new_pred}); body = builder.Build().ConsumeValueOrDie(); diff --git a/tensorflow/compiler/xla/tools/BUILD b/tensorflow/compiler/xla/tools/BUILD index ff350b92e8b6573b09d046afed687f821878571a..759921dce5acf3cd23a121776f3ab0731c9bb623 100644 --- a/tensorflow/compiler/xla/tools/BUILD +++ b/tensorflow/compiler/xla/tools/BUILD @@ -49,6 +49,7 @@ tf_cc_binary( name = "dumped_computation_to_graphviz", deps = [ ":dumped_computation_to_graphviz_library", + "//tensorflow/compiler/xla/service:interpreter_plugin", ], ) @@ -64,6 +65,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/client:client_library", "//tensorflow/compiler/xla/client:computation", "//tensorflow/compiler/xla/client:local_client", + "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], @@ -164,6 +166,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:computation_tracker", + "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], @@ -182,6 +185,7 @@ tf_cc_binary( "//tensorflow/compiler/xla/client:local_client", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], @@ -200,11 +204,24 @@ tf_cc_binary( "//tensorflow/compiler/xla/legacy_flags:debug_options_flags", "//tensorflow/compiler/xla/service", "//tensorflow/compiler/xla/service:hlo_graph_dumper", + "//tensorflow/compiler/xla/service:interpreter_plugin", "//tensorflow/compiler/xla/service:session_proto", "//tensorflow/core:lib", ], ) +tf_cc_binary( + name = "hlo_proto_to_json", + srcs = ["hlo_proto_to_json.cc"], + deps = [ + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla/service:hlo_proto", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + ], +) + # ----------------------------------------------------------------------------- filegroup( diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc index 6c952b29e28a43619e57f8f5d03237ed64d7e8ac..5ede37b8737bd4fa6235464ddeb6382af17c8a80 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_operation_list.cc @@ -86,14 +86,14 @@ void RealMain(tensorflow::gtl::ArraySlice args) { layouts.push_back(&program_shape->parameters(i)); } StatusOr> executable = - local_service->CompileExecutable( - computation.handle(), layouts, &program_shape->result(), - /*device_ordinal=*/0, /*has_hybrid_result=*/true); + local_service->CompileExecutable(computation.handle(), layouts, + &program_shape->result(), + /*device_ordinal=*/0); const HloModule& module = executable.ValueOrDie()->module(); OperationDumper dumper(arg); - for (auto& computation : module.computations()) { + for (auto* computation : module.computations()) { TF_CHECK_OK(computation->Accept(&dumper)); } } diff --git a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc index 2a3a8803283c62d12d8e2d213aa1730e8bd33244..78d8fb1f4330aed899ca917e66fae819a002b3a9 100644 --- a/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc +++ b/tensorflow/compiler/xla/tools/dumped_computation_to_text.cc @@ -61,9 +61,9 @@ void RealMain(tensorflow::gtl::ArraySlice args, bool compile) { layouts.push_back(&program_shape->parameters(i)); } StatusOr> executable = - local_service->CompileExecutable( - computation.handle(), layouts, &program_shape->result(), - /*device_ordinal=*/0, /*has_hybrid_result=*/true); + local_service->CompileExecutable(computation.handle(), layouts, + &program_shape->result(), + /*device_ordinal=*/0); const HloModule& module = executable.ValueOrDie()->module(); diff --git a/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc new file mode 100644 index 0000000000000000000000000000000000000000..4e02e17db65c0a4220672733be8319e1a0cc4f0f --- /dev/null +++ b/tensorflow/compiler/xla/tools/hlo_proto_to_json.cc @@ -0,0 +1,91 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Usage: +// hlo_proto_to_json --input_file=some_binary_proto +// --output_file=path_to_dump_output +// +// Reads one serilized Hlo module, convert it into JSON format and dump into +// some output directory. some_binaray_proto is obtained by serializing Hlo +// module to disk using --xla_dump_hlo_proto_to debug optoin. + +#include +#include +#include + +#include "tensorflow/compiler/xla/service/hlo.pb.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/command_line_flags.h" + +using tensorflow::Env; +using xla::string; + +namespace xla { +namespace tools { + +StatusOr ToJson(const tensorflow::protobuf::Message& message) { + string json_output; + tensorflow::protobuf::util::JsonPrintOptions json_options; + json_options.add_whitespace = true; + json_options.always_print_primitive_fields = true; + auto status = tensorflow::protobuf::util::MessageToJsonString( + message, &json_output, json_options); + if (!status.ok()) { + return InternalError("MessageToJsonString failed: %s", + status.error_message().data()); + } + return json_output; +} + +void RealMain(const string& input, const string& output) { + HloProto hlo_proto; + TF_CHECK_OK(tensorflow::ReadBinaryProto(tensorflow::Env::Default(), input, + &hlo_proto)) + << "Can't open, read, or parse input file " << input; + + auto statusor = ToJson(hlo_proto); + QCHECK(statusor.ok()) << "Error converting " << input << " to JSON." + << statusor.status(); + + TF_CHECK_OK(tensorflow::WriteStringToFile(tensorflow::Env::Default(), output, + statusor.ValueOrDie())); +} + +} // namespace tools +} // namespace xla + +int main(int argc, char** argv) { + string input_file, output_file; + const std::vector flag_list = { + tensorflow::Flag("input_file", &input_file, "file to convert."), + tensorflow::Flag("output_file", &output_file, "converted file"), + }; + const string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parse_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + tensorflow::port::InitMain(usage.c_str(), &argc, &argv); + QCHECK(parse_ok && argc == 1) << "\n" << usage; + + QCHECK(!input_file.empty()) << "--input_file is required"; + QCHECK(!output_file.empty()) << "--output_file is required"; + + xla::tools::RealMain(input_file, output_file); + + return 0; +} diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..ce936af6c3376387c1ed9fa48da23b8af537f6e5 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/BUILD @@ -0,0 +1,86 @@ +# Build file for the Hlo parser. + +licenses(["notice"]) # Apache 2.0 + +package( + default_visibility = [":friends"], +) + +package_group( + name = "friends", + includes = [ + "//tensorflow/compiler/xla:friends", + ], +) + +# Filegroup used to collect source files for dependency checking. +filegroup( + name = "c_srcs", + data = glob([ + "**/*.cc", + "**/*.h", + ]), +) + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "hlo_lexer", + srcs = ["hlo_lexer.cc"], + hdrs = [ + "hlo_lexer.h", + "hlo_token.h", + ], + deps = [ + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + "//tensorflow/core:regexp_internal", + ], +) + +cc_library( + name = "hlo_parser", + srcs = ["hlo_parser.cc"], + hdrs = ["hlo_parser.h"], + deps = [ + ":hlo_lexer", + "//tensorflow/compiler/xla:literal_util", + "//tensorflow/compiler/xla:shape_util", + "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:util", + "//tensorflow/compiler/xla:xla_data_proto", + "//tensorflow/compiler/xla/service:hlo", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + ], +) + +tf_cc_test( + name = "hlo_parser_test", + size = "small", + srcs = ["hlo_parser_test.cc"], + deps = [ + ":hlo_parser", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + +# ----------------------------------------------------------------------------- + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/xla/tools/parser/README.md b/tensorflow/compiler/xla/tools/parser/README.md new file mode 100644 index 0000000000000000000000000000000000000000..2c864d77a20207bab7c72b207b31c9b886441e9b --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/README.md @@ -0,0 +1,106 @@ +# HloModule string syntax + +TODO: Support all subcomputations (for fusion, reduce, ...). + +TODO: Support all extra attributes, e.g. dimensions, strides. + +```yacc +hlo_module + : 'HloModule' name computations + ; + +computations + : computation + | computation computations + ; + +computation + : 'ENTRY' name param_list '->' shape instruction_list + | name param_list '->' shape instruction_list + ; + +instruction_list + : '{' instruction_list1 '}' + ; +instruction_list1 + : instruction + | instruction_list1 instruction + ; +instruction + : 'ROOT' name '=' shape opcode operands extra_attributes + | name '=' shape opcode operands extra_attributes + ; + +operands + : '(' operands1 ')' + ; +operands1 + : /*empty*/ + | operand + | operands1 ',' operand + ; +operand + : shape name + ; + +extra_attributes + : /*empty*/ + | ',' extra_attribute + | ',' extra_attribute extra_attributes + ; +extra_attribute + : attribute_name attribute_value + ; + +param_list + : '(' param_list1 ')' + ; +param_list1 + : /*empty*/ + | param + | param_list1 ',' param + ; +param + : name shape + ; + +shape + : shape_val_ + | '(' tuple_elements ')' + ; +tuple_elements + : /*empty*/ + | shape (',' shape)* + ; + +name + : identifier ':' + | '%' identifier + ; + +identifier + : [a-zA-Z_][a-zA-Z0-9_.-]* + ; + +/* literal is in the right hand side of a constant instruction. */ +literal + : tuple + | non_tuple + ; +tuple + : shape '(' literal_list ')' + ; +literal_list + : /*empty*/ + : literal + | literal_list ',' literal + ; +non_tuple + : rank01 + | rank2345 + ; +rank2345 + : shape nested_array + ; + +``` diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc new file mode 100644 index 0000000000000000000000000000000000000000..d104ff34601216bbaf5d5c068e00a7191a9b3b17 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -0,0 +1,365 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" + +#include + +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/optional.h" +#include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/platform/regexp.h" + +namespace xla { +namespace tools { + +using tensorflow::StringPiece; + +namespace { + +constexpr int kEOF = -1; +constexpr int kError = -2; + +// [a-zA-Z0-9_.-] +bool IsIdentifierChar(char c) { + return isalnum(static_cast(c)) || c == '-' || c == '.' || + c == '_'; +} + +} // namespace + +int HloLexer::GetNextChar() { + int current_char = PeekCurrentChar(); + if (current_char != kEOF && current_char != kError) { + current_ptr_++; + } + return current_char; +} + +int HloLexer::PeekCurrentChar() const { + if (current_ptr_ == buf_.end()) { + return kEOF; + } + char current_char = *current_ptr_; + if (current_char == 0) { + // '\0' should not appear in the middle of the string. + return kError; + } + return static_cast(current_char); +} + +bool HloLexer::CanDereference(const char* ptr) const { + return ptr < buf_.end() && ptr >= buf_.begin(); +} + +StringPiece HloLexer::StringPieceFromPointers(const char* begin, + const char* end) const { + CHECK(begin <= end); + CHECK(begin == buf_.end() || CanDereference(begin)); + CHECK(end == buf_.end() || CanDereference(end)); + return StringPiece(begin, end - begin); +} + +tensorflow::RegexpStringPiece HloLexer::RegexpStringPieceFromPointers( + const char* begin, const char* end) const { + CHECK(begin <= end); + CHECK(begin == buf_.end() || CanDereference(begin)); + CHECK(end == buf_.end() || CanDereference(end)); + return tensorflow::RegexpStringPiece(begin, end - begin); +} + +TokKind HloLexer::LexToken() { + while (true) { + token_start_ = current_ptr_; + + int current_char = GetNextChar(); + switch (current_char) { + default: + // [a-zA-Z_] + if (isalpha(static_cast(current_char)) || + current_char == '_') { + return LexIdentifier(); + } + return TokKind::kError; + case kEOF: + // Hit the end of the input buffer. + return TokKind::kEof; + case kError: + // Hit an invalid character in the input buffer. + return TokKind::kError; + case ' ': + case '\t': + case '\n': + case '\r': + // Ignore whitespace. + continue; + case '0': + case '1': + case '2': + case '3': + case '4': + case '5': + case '6': + case '7': + case '8': + case '9': + case '-': + if (current_char == '-' && PeekCurrentChar() == '>') { + current_ptr_++; + return TokKind::kArrow; + } + return LexDigitOrNegative(); + case '=': + return TokKind::kEqual; + case ',': + return TokKind::kComma; + case '%': + return LexPercent(); + case ':': + return TokKind::kColon; + case '[': + return TokKind::kLsquare; + case ']': + return TokKind::kRsquare; + case '{': + return TokKind::kLbrace; + case '}': + return TokKind::kRbrace; + case '(': + return TokKind::kLparen; + case ')': + return TokKind::kRparen; + case '/': + return LexComment(); + } + } +} + +// Lex a shape, name, keyword, or opcode. +// shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})? +// name ::= [a-zA-Z_][a-zA-Z0-9_.-]*: +// keyword ::= HloModule, ENTRY, ... +// opcode ::= add, greater-than, ... +// attribute_name ::= condition, body, dimensions, ... +TokKind HloLexer::LexIdentifier() { + { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + // 'consumable' will be advanced iff its prefix matches the pattern. + static LazyRE2 shape_pattern = { + R"(^(\w*\d*)\[([\d,]*)\](?:{([\d,]*)})?)"}; + if (RE2::Consume(&consumable, *shape_pattern)) { + auto status_or_shape = ShapeUtil::ParseShapeString( + StringPieceFromPointers(token_start_, consumable.begin())); + if (status_or_shape.ok()) { + // This is a shape string. + shape_val_ = status_or_shape.ValueOrDie(); + current_ptr_ = consumable.begin(); + return TokKind::kShape; + } + } + } + + while (IsIdentifierChar(PeekCurrentChar())) { + current_ptr_++; + } + + // If followed by ':', it's a name. + if (PeekCurrentChar() == ':') { + str_val_.assign(token_start_, current_ptr_); + current_ptr_++; // skip ':' + return TokKind::kName; + } + + // If followed by '=', it's a attribute name. + if (PeekCurrentChar() == '=') { + str_val_.assign(token_start_, current_ptr_); + current_ptr_++; // skip '=' + return TokKind::kAttributeName; + } + + StringPiece identifier = StringPieceFromPointers(token_start_, current_ptr_); + + // See if this is a keyword. +#define KEYWORD(STR) \ + do { \ + if (identifier == #STR) { \ + return TokKind::kw_##STR; \ + } \ + } while (false) + + KEYWORD(true); + KEYWORD(false); + KEYWORD(inf); + KEYWORD(nan); + KEYWORD(HloModule); + KEYWORD(ENTRY); + KEYWORD(ROOT); + KEYWORD(maximal); + KEYWORD(replicated); + +#undef KEYWORD + + // See if this is an opcode. + auto opcode = StringToHloOpcode(identifier.ToString()); + if (opcode.ok()) { + opcode_val_ = opcode.ValueOrDie(); + return TokKind::kOpcode; + } + + current_ptr_ = token_start_ + 1; + return TokKind::kError; +} + +// Lex names after a % character. +// name ::= [a-zA-Z_][a-zA-Z0-9_.-]* +TokKind HloLexer::LexPercent() { + const char* name_start = current_ptr_; + if (isalpha(static_cast(PeekCurrentChar())) || + PeekCurrentChar() == '_') { + current_ptr_++; + while (IsIdentifierChar(PeekCurrentChar())) { + current_ptr_++; + } + str_val_.assign(name_start, current_ptr_); + return TokKind::kName; + } + return TokKind::kError; +} + +// Lex integer and floating-point values, and -inf. +// int [-]?[0-9]+ +// fp with exp [-]?([0-9]+|[0-9]+[.][0-9]*|[0-9]*[.][0-9]+)([eE][+-]?[0-9]+) +// fp without exp [-]?([0-9]+[.][0-9]*|[0-9]*[.][0-9]+) +// negative inf -inf +TokKind HloLexer::LexDigitOrNegative() { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + static LazyRE2 float_pattern = { + R"([-]?((\d+|\d+[.]\d*|\d*[.]\d+)([eE][+-]?\d+))|(\d+[.]\d*|\d*[.]\d+))"}; + if (RE2::Consume(&consumable, *float_pattern)) { + current_ptr_ = consumable.begin(); + tensorflow::strings::safe_strtod(string(token_start_, current_ptr_).c_str(), + &decimal_val_); + return TokKind::kDecimal; + } + + static LazyRE2 int_pattern = {R"([-]?\d+)"}; + if (RE2::Consume(&consumable, *int_pattern)) { + current_ptr_ = consumable.begin(); + tensorflow::strings::safe_strto64( + StringPieceFromPointers(token_start_, current_ptr_), &int64_val_); + return TokKind::kInt; + } + + static LazyRE2 neg_inf = {"-inf"}; + if (RE2::Consume(&consumable, *neg_inf)) { + current_ptr_ = consumable.begin(); + return TokKind::kNegInf; + } + + return TokKind::kError; +} + +StringPiece HloLexer::GetCurrentLine() const { + const char* start = token_start_; + const char* end = current_ptr_; + if (!CanDereference(start) || !CanDereference(end)) { + return "LINE OUT OF RANGE"; + } + while (start > buf_.begin() && *start != '\n') { + start--; + } + while (end < buf_.end() && *end != '\n') { + end++; + } + return StringPieceFromPointers(start, end); +} + +TokKind HloLexer::LexComment() { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + static LazyRE2 comment_pattern = {R"(\/\*.*?\*\/)"}; + if (RE2::Consume(&consumable, *comment_pattern)) { + current_ptr_ = consumable.begin(); + return TokKind::kComment; + } + return TokKind::kError; +} + +string TokKindToString(TokKind kind) { + switch (kind) { + case TokKind::kEof: + return "kEof"; + case TokKind::kError: + return "kError"; + case TokKind::kEqual: + return "kEqaul"; + case TokKind::kComma: + return "kComma"; + case TokKind::kColon: + return "kColon"; + case TokKind::kLsquare: + return "kLsquare"; + case TokKind::kRsquare: + return "kRsquare"; + case TokKind::kLbrace: + return "kLbrace"; + case TokKind::kRbrace: + return "kRbrace"; + case TokKind::kLparen: + return "kLparen"; + case TokKind::kRparen: + return "kRparen"; + case TokKind::kArrow: + return "kArrow"; + case TokKind::kComment: + return "kComment"; + case TokKind::kw_HloModule: + return "kw_HloModule"; + case TokKind::kw_ENTRY: + return "kw_ENTRY"; + case TokKind::kw_ROOT: + return "kw_ROOT"; + case TokKind::kw_true: + return "kw_true"; + case TokKind::kw_false: + return "kw_false"; + case TokKind::kw_maximal: + return "kw_maximal"; + case TokKind::kw_replicated: + return "kw_replicated"; + case TokKind::kw_nan: + return "kw_nan"; + case TokKind::kw_inf: + return "kw_inf"; + case TokKind::kNegInf: + return "kNegInf"; + case TokKind::kName: + return "kName"; + case TokKind::kAttributeName: + return "kAttributeName"; + case TokKind::kShape: + return "kShape"; + case TokKind::kOpcode: + return "kOpcode"; + case TokKind::kInt: + return "kInt"; + case TokKind::kDecimal: + return "kDecimal"; + } +} + +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h new file mode 100644 index 0000000000000000000000000000000000000000..3b9efcb92d074a234868a12b8f4dc5db867ea1ec --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h @@ -0,0 +1,114 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ + +#include + +#include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/types.h" + +namespace xla { +namespace tools { + +// Lexer for the HloModule::ToString() format text. +class HloLexer { + public: + explicit HloLexer(tensorflow::StringPiece buf) : buf_(buf) { + current_ptr_ = buf_.begin(); + } + + TokKind Lex() { return current_kind_ = LexToken(); } + TokKind GetKind() const { return current_kind_; } + string GetStrVal() const { + switch (GetKind()) { + case TokKind::kName: + case TokKind::kAttributeName: + return str_val_; + default: + LOG(FATAL) << "This token does not have string value"; + } + } + Shape GetShapeVal() const { + CHECK(GetKind() == TokKind::kShape); + return shape_val_; + } + HloOpcode GetOpcodeVal() const { + CHECK(GetKind() == TokKind::kOpcode); + return opcode_val_; + } + int64 GetInt64Val() const { + CHECK(GetKind() == TokKind::kInt); + return int64_val_; + } + double GetDecimalVal() const { + CHECK(GetKind() == TokKind::kDecimal); + return decimal_val_; + } + + // Returns the line of text that is currently being lexed. + tensorflow::StringPiece GetCurrentLine() const; + + private: + // Returns the current character. If it's neither the end of input buffer nor + // an invalid character, moves the pointer forward. + int GetNextChar(); + + // Returns the current character. + int PeekCurrentChar() const; + + // Creates StringPiece with the given begin and end. Exits if the begin > end, + // or it's out of the range of the current buffer. + tensorflow::StringPiece StringPieceFromPointers(const char* begin, + const char* end) const; + tensorflow::RegexpStringPiece RegexpStringPieceFromPointers( + const char* begin, const char* end) const; + + // Returns true if the given ptr is dereferenceable within the range of the + // current buffer. + bool CanDereference(const char* ptr) const; + + TokKind LexToken(); + + TokKind LexIdentifier(); + TokKind LexPercent(); + TokKind LexShape(); + TokKind LexConstant(); + TokKind LexDigitOrNegative(); + TokKind LexComment(); + + const tensorflow::StringPiece buf_; + const char* current_ptr_; + + // Information about the current token. + const char* token_start_; + TokKind current_kind_; + string str_val_; + Shape shape_val_; + HloOpcode opcode_val_; + int64 int64_val_; + double decimal_val_; +}; + +} // namespace tools +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_LEXER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c2e37e3b5cdd73157279fb171d3332aa9854184 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -0,0 +1,1179 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/shape_util.h" +#include "tensorflow/compiler/xla/util.h" +#include "tensorflow/core/lib/gtl/map_util.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/lib/strings/stringprintf.h" + +namespace xla { +namespace tools { + +namespace { + +using tensorflow::StringPiece; +using tensorflow::strings::Printf; +using tensorflow::strings::StrAppend; +using tensorflow::strings::StrCat; + +const double kF16max = 65504; + +// Parser for the HloModule::ToString() format text. +class HloParser { + public: + explicit HloParser(StringPiece str, const HloModuleConfig& config) + : lexer_(str), config_(config) {} + + // Runs the parser. Returns false if an error occurred. + bool Run(); + + // Returns the parsed HloModule. + std::unique_ptr ConsumeHloModule() { return std::move(module_); } + + // Returns the error information. + string GetError() const { return tensorflow::str_util::Join(error_, "\n"); } + + private: + // ParseXXX returns false if an error occurred. + bool ParseHloModule(); + bool ParseComputations(); + bool ParseComputation(); + bool ParseInstructionList(HloComputation::Builder* builder, + string* root_name); + bool ParseInstruction(HloComputation::Builder* builder, string* root_name); + bool ParseSharding(HloInstruction* instruction); + bool ParseControlPredecessors(HloInstruction* instruction); + bool ParseLiteral(std::unique_ptr* literal, const Shape& shape); + bool ParseTupleLiteral(std::unique_ptr* literal, const Shape& shape); + bool ParseNonTupleLiteral(std::unique_ptr* literal, + const Shape& shape); + // Sets the sub-value of literal at the given index to the given value. The + // literal's shape must have the default layout. + bool SetValueInLiteral(int64 value, int64 linear_index, Literal* literal); + bool SetValueInLiteral(double value, int64 linear_index, Literal* literal); + bool SetValueInLiteral(bool value, int64 linear_index, Literal* literal); + template + bool SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, + Literal* literal); + + bool ParseOperands(std::vector* operands); + // Fills parsed operands into 'operands' and expects a certain number of + // operands. + bool ParseOperands(std::vector* operands, + const int expected_size); + + template + bool ParseExtraAttribute(T* value, const string& expected_attribute); + template + bool ParseAttributeValue(T* value); + + bool ParseParamList(); + bool ParseName(string* result); + bool ParseAttributeName(string* result); + bool ParseShape(Shape* result); + bool ParseOpcode(HloOpcode* result); + bool ParseInt64(int64* result); + bool ParseDouble(double* result); + bool ParseBool(bool* result); + bool ParseToken(TokKind kind, const string& msg); + + // Logs the current parsing line and the given message. Always returns false. + bool TokenError(StringPiece msg); + + // If the current token is 'kind', eats it (i.e. lexes the next token) and + // returns true. + bool EatIfPresent(TokKind kind); + // Parses a shape, and returns true if the result is compatible with the given + // shape. + bool EatShapeAndCheckCompatible(const Shape& shape); + + // Adds the instruction to the pool. Returns false and emits an error if the + // instruction already exists. + bool AddInstruction(const string& name, HloInstruction* instruction); + // Adds the computation to the pool. Returns false and emits an error if the + // computation already exists. + bool AddComputation(const string& name, HloComputation* computation); + + // The map from the instruction name to the instruction. This does not own the + // instructions. + std::unordered_map instruction_pool_; + std::unordered_map computation_pool_; + + HloLexer lexer_; + std::unique_ptr module_; + const HloModuleConfig config_; + std::vector error_; +}; + +bool HloParser::TokenError(StringPiece msg) { + const string error = + StrCat("was parsing \"", lexer_.GetCurrentLine(), "\"; token ", + TokKindToString(lexer_.GetKind()), "; ", msg); + VLOG(1) << "TokenError: " << error; + error_.push_back(error); + return false; +} + +bool HloParser::Run() { + lexer_.Lex(); + return ParseHloModule(); +} + +// ::= 'HloModule' name computations +bool HloParser::ParseHloModule() { + if (lexer_.GetKind() != TokKind::kw_HloModule) { + return TokenError("expects HloModule"); + } + // Eat 'HloModule' + lexer_.Lex(); + + string name; + if (!ParseName(&name)) { + return false; + } + + module_ = MakeUnique(name, config_); + + return ParseComputations(); +} + +// computations ::= (computation)+ +bool HloParser::ParseComputations() { + do { + if (!ParseComputation()) { + return false; + } + } while (lexer_.GetKind() != TokKind::kEof); + return true; +} + +// computation ::= ('ENTRY')? name param_list '->' shape instruction_list +bool HloParser::ParseComputation() { + const bool is_entry_computation = EatIfPresent(TokKind::kw_ENTRY); + string name; + if (!ParseName(&name)) { + return false; + } + auto builder = MakeUnique(name); + + Shape shape; + string root_name; + if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'") || + !ParseShape(&shape) || !ParseInstructionList(builder.get(), &root_name)) { + return false; + } + + HloInstruction* root = + tensorflow::gtl::FindPtrOrNull(instruction_pool_, root_name); + // This means some instruction was marked as ROOT but we didn't find it in the + // pool, which should not happen. + if (!root_name.empty() && root == nullptr) { + LOG(FATAL) << "instruction " << root_name + << " was marked as ROOT but the parser has not seen it before"; + } + // Now root can be either an existing instruction or a nullptr. If it's a + // nullptr, the implementation of Builder will set the last instruction as + // root instruction. + HloComputation* computation = + is_entry_computation + ? module_->AddEntryComputation(builder->Build(root)) + : module_->AddEmbeddedComputation(builder->Build(root)); + return AddComputation(name, computation); +} + +// instruction_list ::= '{' instruction_list1 '}' +// instruction_list1 ::= (instruction)+ +bool HloParser::ParseInstructionList(HloComputation::Builder* builder, + string* root_name) { + if (!ParseToken(TokKind::kLbrace, + "expects '{' at the beginning of instruction list.")) { + return false; + } + do { + if (!ParseInstruction(builder, root_name)) { + return false; + } + } while (lexer_.GetKind() != TokKind::kRbrace); + return ParseToken(TokKind::kRbrace, + "expects '}' at the end of instruction list."); +} + +// instruction ::= ('ROOT')? name '=' shape opcode operands (extra_attribute)* +bool HloParser::ParseInstruction(HloComputation::Builder* builder, + string* root_name) { + string name; + Shape shape; + HloOpcode opcode; + std::vector operands; + bool is_root = EatIfPresent(TokKind::kw_ROOT); + if (!ParseName(&name) || + !ParseToken(TokKind::kEqual, "expects '=' in instruction") || + !ParseShape(&shape) || !ParseOpcode(&opcode)) { + return false; + } + if (is_root) { + *root_name = name; + } + HloInstruction* instruction; + switch (opcode) { + case HloOpcode::kParameter: { + int64 parameter_number; + if (!ParseToken(TokKind::kLparen, + "expects '(' before parameter number") || + !ParseInt64(¶meter_number) || + !ParseToken(TokKind::kRparen, "expects ')' after parameter number")) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateParameter(parameter_number, shape, name)); + break; + } + case HloOpcode::kConstant: { + std::unique_ptr literal; + if (!ParseToken(TokKind::kLparen, + "expects '(' before constant literal") || + !ParseLiteral(&literal, shape) || + !ParseToken(TokKind::kRparen, "expects ')' after constant literal")) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateConstant(std::move(literal))); + break; + } + // Unary ops. + case HloOpcode::kAbs: + case HloOpcode::kRoundNearestAfz: + case HloOpcode::kBitcast: + case HloOpcode::kCeil: + case HloOpcode::kCopy: + case HloOpcode::kCos: + case HloOpcode::kExp: + case HloOpcode::kImag: + case HloOpcode::kIsFinite: + case HloOpcode::kFloor: + case HloOpcode::kLog: + case HloOpcode::kNot: + case HloOpcode::kNegate: + case HloOpcode::kReal: + case HloOpcode::kSign: + case HloOpcode::kSin: + case HloOpcode::kSort: + case HloOpcode::kTanh: { + if (!ParseOperands(&operands, /*expected_size=*/1)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateUnary(shape, opcode, operands[0])); + break; + } + // Binary ops. + case HloOpcode::kAdd: + case HloOpcode::kDivide: + case HloOpcode::kMultiply: + case HloOpcode::kSubtract: + case HloOpcode::kAtan2: + case HloOpcode::kComplex: + case HloOpcode::kEq: + case HloOpcode::kGe: + case HloOpcode::kGt: + case HloOpcode::kLe: + case HloOpcode::kLt: + case HloOpcode::kNe: + case HloOpcode::kDot: + case HloOpcode::kMaximum: + case HloOpcode::kMinimum: + case HloOpcode::kPower: + case HloOpcode::kRemainder: + case HloOpcode::kAnd: + case HloOpcode::kOr: + case HloOpcode::kShiftLeft: + case HloOpcode::kShiftRightArithmetic: + case HloOpcode::kShiftRightLogical: { + if (!ParseOperands(&operands, /*expected_size=*/2)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateBinary( + shape, opcode, operands[0], operands[1])); + break; + } + // Ternary ops. + case HloOpcode::kClamp: + case HloOpcode::kSelect: { + if (!ParseOperands(&operands, /*expected_size=*/3)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateTernary( + shape, opcode, operands[0], operands[1], operands[2])); + break; + } + // Other supported ops. + case HloOpcode::kConvert: { + if (!ParseOperands(&operands, /*expected_size=*/1)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateConvert(shape, operands[0])); + break; + } + case HloOpcode::kCrossReplicaSum: { + if (!ParseOperands(&operands, /*expected_size=*/1)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateCrossReplicaSum(shape, operands[0])); + break; + } + case HloOpcode::kReshape: { + if (!ParseOperands(&operands, /*expected_size=*/1)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateReshape(shape, operands[0])); + break; + } + case HloOpcode::kTuple: { + if (!ParseOperands(&operands)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateTuple(operands)); + break; + } + case HloOpcode::kWhile: { + HloComputation* condition; + HloComputation* body; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseExtraAttribute(&condition, + /*expected_attribute=*/"condition") || + !ParseExtraAttribute(&body, /*expected_attribute=*/"body")) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateWhile( + shape, condition, body, /*init=*/operands[0])); + break; + } + case HloOpcode::kRecv: { + int64 channel_id; + if (!ParseOperands(&operands, /*expected_size=*/0) || + !ParseExtraAttribute(&channel_id, + /*expected_attribute=*/"channel_id")) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateRecv(shape, channel_id)); + break; + } + case HloOpcode::kSend: { + int64 channel_id; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseExtraAttribute(&channel_id, + /*expected_attribute=*/"channel_id")) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateSend(operands[0], channel_id)); + break; + } + case HloOpcode::kGetTupleElement: { + int64 index; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseExtraAttribute(&index, /*expected_attribute=*/"index")) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateGetTupleElement(shape, operands[0], index)); + break; + } + case HloOpcode::kCall: { + HloComputation* to_apply; + if (!ParseOperands(&operands) || + !ParseExtraAttribute(&to_apply, + /*expected_attribute=*/"to_apply")) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateCall(shape, operands, to_apply)); + break; + } + case HloOpcode::kBroadcast: + case HloOpcode::kCustomCall: + case HloOpcode::kConcatenate: + case HloOpcode::kReducePrecision: + case HloOpcode::kConvolution: + case HloOpcode::kMap: + case HloOpcode::kPad: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kReverse: + case HloOpcode::kRng: + case HloOpcode::kSlice: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kTranspose: + case HloOpcode::kFusion: + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: + case HloOpcode::kInfeed: + case HloOpcode::kOutfeed: + case HloOpcode::kBatchNormGrad: + case HloOpcode::kTrace: + return TokenError(StrCat("parsing not yet implemented for op: ", + HloOpcodeString(opcode))); + } + + bool has_sharding = false; + bool has_control = false; + while (EatIfPresent(TokKind::kComma)) { + string attribute_name; + if (!ParseAttributeName(&attribute_name)) { + return TokenError("expects ', sharding=' or ', control-predecessors='"); + } + + if (attribute_name == "sharding") { + // Parse "sharding=". + if (has_sharding) { + return TokenError("expects at most 1 'sharding='"); + } + has_sharding = true; + if (!ParseSharding(instruction)) { + return false; + } + } else if (attribute_name == "control-predecessors") { + // Parse "control-predecessors" + if (has_control) { + return TokenError("expects at most 1 'control-predecessors='"); + } + has_control = true; + if (!ParseControlPredecessors(instruction)) { + return false; + } + } else { + return TokenError(StrCat("unexpected attribute: ", attribute_name)); + } + } + + return AddInstruction(name, instruction); +} + +// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('[' +// dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list +bool HloParser::ParseSharding(HloInstruction* instruction) { + if (!ParseToken(TokKind::kLbrace, + "expected '{' to start sharding attribute")) { + return false; + } + + bool maximal = false; + bool replicated = false; + std::vector devices; + std::vector tile_assignment_dimensions; + Shape tile_shape; + while (lexer_.GetKind() != TokKind::kRbrace) { + switch (lexer_.GetKind()) { + case TokKind::kw_maximal: + maximal = true; + lexer_.Lex(); + break; + case TokKind::kw_replicated: + replicated = true; + lexer_.Lex(); + break; + case TokKind::kAttributeName: { + if (lexer_.GetStrVal() == "device") { + if (lexer_.Lex() != TokKind::kInt) { + return TokenError("device= attribute must be an integer"); + } + devices = {lexer_.GetInt64Val()}; + lexer_.Lex(); + } else if (lexer_.GetStrVal() == "devices") { + lexer_.Lex(); + if (!ParseToken(TokKind::kLsquare, + "expected '[' to start sharding devices shape")) { + return false; + } + + do { + int64 dim; + if (!ParseInt64(&dim)) { + return false; + } + tile_assignment_dimensions.push_back(dim); + } while (EatIfPresent(TokKind::kComma)); + + if (!ParseToken(TokKind::kRsquare, + "expected ']' to start sharding devices shape")) { + return false; + } + do { + int64 device; + if (!ParseInt64(&device)) { + return false; + } + devices.push_back(device); + } while (EatIfPresent(TokKind::kComma)); + } else { + return TokenError( + "unknown attribute in sharding: expected device= or devices="); + } + break; + } + case TokKind::kShape: + tile_shape = lexer_.GetShapeVal(); + lexer_.Lex(); + break; + case TokKind::kRbrace: + break; + default: + return TokenError("unexpected token"); + } + } + + OpSharding sharding; + if (replicated) { + if (!devices.empty()) { + return TokenError( + "replicated shardings should not have any devices assigned"); + } + if (!ShapeUtil::Equal(tile_shape, Shape())) { + return TokenError( + "replicated shardings should not have any tile shape set"); + } + sharding.set_type(OpSharding::Type::OpSharding_Type_REPLICATED); + } else if (maximal) { + if (devices.size() != 1) { + return TokenError( + "maximal shardings should have exactly one device assigned"); + } + if (!ShapeUtil::Equal(tile_shape, Shape())) { + return TokenError("maximal shardings should not have any tile shape set"); + } + sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); + sharding.add_tile_assignment_devices(devices[0]); + } else { + if (devices.size() <= 1) { + return TokenError( + "non-maximal shardings must have more than one device assigned"); + } + if (ShapeUtil::Equal(tile_shape, Shape())) { + return TokenError("non-maximal shardings should have a tile shape set"); + } + if (tile_assignment_dimensions.empty()) { + return TokenError( + "non-maximal shardings must have a tile assignment list including " + "dimensions"); + } + sharding.set_type(OpSharding::Type::OpSharding_Type_OTHER); + *sharding.mutable_tile_shape() = tile_shape; + for (int64 dim : tile_assignment_dimensions) { + sharding.add_tile_assignment_dimensions(dim); + } + for (int64 device : devices) { + sharding.add_tile_assignment_devices(device); + } + } + + instruction->set_sharding(HloSharding::FromProto(sharding).ValueOrDie()); + lexer_.Lex(); + return true; +} + +// '{' name+ '}' +bool HloParser::ParseControlPredecessors(HloInstruction* instruction) { + if (!ParseToken(TokKind::kLbrace, + "expects '{' at the beginning of control predecessors")) { + return false; + } + do { + string name; + if (!ParseName(&name)) { + return TokenError("expects a control predecessor"); + } + HloInstruction* pre = + tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); + if (!pre) { + return TokenError( + StrCat("control predecessor ", name, " is not defined: ")); + } + Status status = pre->AddControlDependencyTo(instruction); + if (!status.ok()) { + return TokenError(StrCat("error adding control dependency for: ", name, + " status: ", status.ToString())); + } + } while (EatIfPresent(TokKind::kComma)); + + return ParseToken(TokKind::kRbrace, + "expects '}' at the end of control predecessors"); +} + +bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, + Literal* literal) { + const Shape& shape = literal->shape(); + switch (shape.element_type()) { + case S8: + return SetValueInLiteralHelper(value, linear_index, literal); + case S16: + return SetValueInLiteralHelper(value, linear_index, literal); + case S32: + return SetValueInLiteralHelper(value, linear_index, literal); + case S64: + return SetValueInLiteralHelper(value, linear_index, literal); + case U8: + return SetValueInLiteralHelper(value, linear_index, literal); + case U16: + return SetValueInLiteralHelper(value, linear_index, literal); + case U32: + return SetValueInLiteralHelper(value, linear_index, literal); + case U64: + return SetValueInLiteralHelper(value, linear_index, literal); + default: + LOG(FATAL) << "unknown integral primitive type " + << PrimitiveType_Name(shape.element_type()); + } +} + +bool HloParser::SetValueInLiteral(double value, int64 linear_index, + Literal* literal) { + const Shape& shape = literal->shape(); + switch (shape.element_type()) { + case F16: + return SetValueInLiteralHelper(value, linear_index, literal); + case F32: + return SetValueInLiteralHelper(value, linear_index, literal); + case F64: + return SetValueInLiteralHelper(value, linear_index, literal); + default: + LOG(FATAL) << "unknown floating point primitive type " + << PrimitiveType_Name(shape.element_type()); + } +} + +bool HloParser::SetValueInLiteral(bool value, int64 linear_index, + Literal* literal) { + const Shape& shape = literal->shape(); + switch (shape.element_type()) { + case PRED: + return SetValueInLiteralHelper(value, linear_index, literal); + default: + LOG(FATAL) << PrimitiveType_Name(shape.element_type()) + << " is not PRED type"; + } +} + +template +bool HloParser::SetValueInLiteralHelper(ParsedElemT value, int64 linear_index, + Literal* literal) { + // Check that linear_index is in range. + if (linear_index >= ShapeUtil::ElementsIn(literal->shape())) { + return TokenError( + StrCat("trys to set value ", value, " to a literal in shape ", + ShapeUtil::HumanString(literal->shape()), " at linear index ", + linear_index, ", but the index is out of range")); + } + + if (std::isnan(value) || + (std::numeric_limits::has_infinity && + (std::numeric_limits::infinity() == value || + -std::numeric_limits::infinity() == value))) { + // Skip range checking for non-finite value. + } else if (literal->shape().element_type() == F16) { + if (value > kF16max || value < -kF16max) { + return TokenError(StrCat( + "value ", value, " is out of range for literal's primitive type ", + PrimitiveType_Name(literal->shape().element_type()))); + } + } else if (value > static_cast( + std::numeric_limits::max()) || + value < static_cast( + std::numeric_limits::lowest())) { + // Value is out of range for LiteralNativeT. + return TokenError(StrCat( + "value ", value, " is out of range for literal's primitive type ", + PrimitiveType_Name(literal->shape().element_type()))); + } + + literal->GetMutableArraySlice().at(linear_index) = + static_cast(value); + return true; +} + +bool HloParser::EatShapeAndCheckCompatible(const Shape& shape) { + Shape new_shape; + if (!ParseShape(&new_shape)) { + return TokenError(StrCat("expects shape ", ShapeUtil::HumanString(shape))); + } + if (!ShapeUtil::Compatible(shape, new_shape)) { + return TokenError(StrCat( + "expects shape ", ShapeUtil::HumanString(shape), + ", but sees a different shape: ", ShapeUtil::HumanString(new_shape))); + } + return true; +} + +// literal +// ::= tuple +// ::= non_tuple +bool HloParser::ParseLiteral(std::unique_ptr* literal, + const Shape& shape) { + return ShapeUtil::IsTuple(shape) ? ParseTupleLiteral(literal, shape) + : ParseNonTupleLiteral(literal, shape); +} + +// tuple +// ::= shape '(' literal_list ')' +// literal_list +// ::= /*empty*/ +// ::= literal (',' literal)* +bool HloParser::ParseTupleLiteral(std::unique_ptr* literal, + const Shape& shape) { + if (!EatShapeAndCheckCompatible(shape)) { + return TokenError(StrCat("expects tuple constant in shape ", + ShapeUtil::HumanString(shape))); + } + if (!ParseToken(TokKind::kLparen, "expects '(' in front of tuple elements")) { + return false; + } + std::vector> elements( + ShapeUtil::TupleElementCount(shape)); + + if (lexer_.GetKind() == TokKind::kRparen) { + // empty + } else { + // literal, (',' literal)* + for (int i = 0; i < elements.size(); i++) { + if (i > 0) { + ParseToken(TokKind::kComma, "exepcts ',' to separate tuple elements"); + } + if (!ParseLiteral(&elements[i], + ShapeUtil::GetTupleElementShape(shape, i))) { + return TokenError(StrCat("expects the ", i, "th element")); + } + } + } + *literal = Literal::MakeTupleOwned(std::move(elements)); + return ParseToken(TokKind::kRparen, + StrCat("expects ')' at the end of the tuple with ", + ShapeUtil::TupleElementCount(shape), "elements")); +} + +// non_tuple +// ::= rank01 +// ::= rank2345 +// rank2345 ::= shape nested_array +bool HloParser::ParseNonTupleLiteral(std::unique_ptr* literal, + const Shape& shape) { + const int64 size = ShapeUtil::ElementsIn(shape); + if (size == 0) { + *literal = Literal::CreateFromShape(shape); + return true; + } + + const int64 rank = ShapeUtil::Rank(shape); + if (rank > 1 && !EatShapeAndCheckCompatible(shape)) { + return false; + } + + // Create a literal with the given shape in default layout. + *literal = Literal::CreateFromDimensions(shape.element_type(), + AsInt64Slice(shape.dimensions())); + int64 nest_level = 0; + int64 linear_index = 0; + // elems_seen_per_dim[i] is how many elements or sub-arrays we have seen for + // the dimension i. For example, to parse f32[2,3] {{1, 2, 3}, {4, 5, 6}}, + // when we are parsing the 2nd '{' (right before '1'), we are seeing a + // sub-array of the dimension 0, so elems_seen_per_dim[0]++. When we are at + // the first '}' (right after '3'), it means the sub-array ends, and the + // sub-array is supposed to contain exactly 3 elements, so check if + // elems_seen_per_dim[1] is 3. + std::vector elems_seen_per_dim(rank); + auto get_index_str = [&elems_seen_per_dim](int dim) -> string { + std::vector elems_seen_until_dim(elems_seen_per_dim.begin(), + elems_seen_per_dim.begin() + dim); + return StrCat("[", + tensorflow::str_util::Join( + elems_seen_until_dim, ",", + [](string* out, const int64& num_elems) { + tensorflow::strings::StrAppend(out, num_elems - 1); + }), + "]"); + }; + do { + switch (lexer_.GetKind()) { + default: + return TokenError("unexpected token type in a literal"); + case TokKind::kLbrace: { + nest_level++; + if (nest_level > rank) { + return TokenError(Printf( + "expects nested array in rank %lld, but sees larger", rank)); + } + if (nest_level > 1) { + elems_seen_per_dim[nest_level - 2]++; + if (elems_seen_per_dim[nest_level - 2] > + shape.dimensions(nest_level - 2)) { + return TokenError(Printf( + "expects %lld elements in the %sth element, but sees more", + shape.dimensions(nest_level - 2), + get_index_str(nest_level - 2).c_str())); + } + } + lexer_.Lex(); + break; + } + case TokKind::kRbrace: { + nest_level--; + if (elems_seen_per_dim[nest_level] != shape.dimensions(nest_level)) { + return TokenError(Printf( + "expects %lld elements in the %sth element, but sees %lld", + shape.dimensions(nest_level), get_index_str(nest_level).c_str(), + elems_seen_per_dim[nest_level])); + } + elems_seen_per_dim[nest_level] = 0; + lexer_.Lex(); + break; + } + case TokKind::kComma: + case TokKind::kComment: + // Skip. + lexer_.Lex(); + break; + case TokKind::kw_true: + case TokKind::kw_false: + case TokKind::kInt: + case TokKind::kDecimal: + case TokKind::kw_nan: + case TokKind::kw_inf: + case TokKind::kNegInf: { + if (rank > 0) { + if (nest_level != rank) { + return TokenError( + Printf("expects nested array in rank %lld, but sees %lld", rank, + nest_level)); + } + elems_seen_per_dim[rank - 1]++; + if (elems_seen_per_dim[rank - 1] > shape.dimensions(rank - 1)) { + return TokenError( + Printf("expects %lld elements on the minor-most dimension, but " + "sees more", + shape.dimensions(rank - 1))); + } + } + if (lexer_.GetKind() == TokKind::kw_true || + lexer_.GetKind() == TokKind::kw_false) { + // TODO(congliu): bool type literals with rank >= 1 are actually + // printed in a compact form instead of "true" or "false". Fix that. + if (!SetValueInLiteral(lexer_.GetKind() == TokKind::kw_true, + linear_index++, literal->get())) { + return false; + } + lexer_.Lex(); + } else if (primitive_util::IsIntegralType(shape.element_type())) { + int64 value; + if (!ParseInt64(&value)) { + return TokenError(StrCat("expects integer for primitive type: ", + PrimitiveType_Name(shape.element_type()))); + } + if (!SetValueInLiteral(value, linear_index++, literal->get())) { + return false; + } + } else if (primitive_util::IsFloatingPointType(shape.element_type())) { + double value; + if (!ParseDouble(&value)) { + return TokenError( + StrCat("expect floating point value for primitive type: ", + PrimitiveType_Name(shape.element_type()))); + } + if (!SetValueInLiteral(value, linear_index++, literal->get())) { + return false; + } + } else { + return TokenError(StrCat("unsupported premitive type ", + PrimitiveType_Name(shape.element_type()))); + } + break; + } + } // end of switch + } while (nest_level > 0); + + *literal = (*literal)->Relayout(shape.layout()); + return true; +} + +// operands ::= '(' operands1 ')' +// operands1 +// ::= /*empty*/ +// ::= operand (, operand)* +// operand ::= shape name +bool HloParser::ParseOperands(std::vector* operands) { + if (!ParseToken(TokKind::kLparen, + "expects '(' at the beginning of operands")) { + return false; + } + if (lexer_.GetKind() == TokKind::kRparen) { + // empty + } else { + do { + Shape shape; + string name; + if (!ParseShape(&shape) || !ParseName(&name)) { + return false; + } + HloInstruction* instruction = + tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); + if (!instruction) { + return TokenError(StrCat("instruction does not exist: ", name)); + } + operands->push_back(instruction); + } while (EatIfPresent(TokKind::kComma)); + } + return ParseToken(TokKind::kRparen, "expects ')' at the end of operands"); +} + +bool HloParser::ParseOperands(std::vector* operands, + const int expected_size) { + if (!ParseOperands(operands)) { + return false; + } + if (expected_size != operands->size()) { + return TokenError(StrCat("expects ", expected_size, " operands, but has ", + operands->size(), " operands")); + } + return true; +} + +// extra_attribute ::= ',' attribute_name value +template +bool HloParser::ParseExtraAttribute(T* value, + const string& expected_attribute) { + if (!ParseToken(TokKind::kComma, + "expects ',' in front of an extra attribute")) { + return false; + } + string attribute_name; + if (!ParseAttributeName(&attribute_name) && + attribute_name != expected_attribute) { + return TokenError(StrCat("expects attribute name: ", expected_attribute)); + } + if (!ParseAttributeValue(value)) { + return TokenError( + StrCat("expects value for attribute: ", expected_attribute)); + } + return true; +} + +template <> +bool HloParser::ParseAttributeValue(HloComputation** value) { + string name; + if (!ParseName(&name)) { + return TokenError("expects computation name"); + } + *value = tensorflow::gtl::FindPtrOrNull(computation_pool_, name); + if (*value == nullptr) { + return TokenError(StrCat("computation does not exist: ", name)); + } + return true; +} + +template <> +bool HloParser::ParseAttributeValue(int64* value) { + return ParseInt64(value); +} + +// param_list ::= '(' param_list1 ')' +// param_list1 +// ::= /*empty*/ +// ::= param (',' param)* +// param ::= name shape +bool HloParser::ParseParamList() { + if (!ParseToken(TokKind::kLparen, + "expects '(' at the beginning of param list")) { + return false; + } + + if (lexer_.GetKind() == TokKind::kRparen) { + // empty + } else { + do { + Shape shape; + if (!ParseToken(TokKind::kName, "expects name in parameter") || + !ParseShape(&shape)) { + return false; + } + } while (EatIfPresent(TokKind::kComma)); + } + return ParseToken(TokKind::kRparen, "expects ')' at the end of param list"); +} + +// shape ::= shape_val_ +// shape ::= '(' tuple_elements ')' +// tuple_elements +// ::= /*empty*/ +// ::= shape (',' shape)* +bool HloParser::ParseShape(Shape* result) { + if (EatIfPresent(TokKind::kLparen)) { // Tuple + std::vector shapes; + if (lexer_.GetKind() == TokKind::kRparen) { + /*empty*/ + } else { + // shape (',' shape)* + do { + shapes.emplace_back(); + if (!ParseShape(&shapes.back())) { + return false; + } + } while (EatIfPresent(TokKind::kComma)); + } + *result = ShapeUtil::MakeTupleShape(shapes); + return ParseToken(TokKind::kRparen, "expects ')' at the end of tuple."); + } + + if (lexer_.GetKind() != TokKind::kShape) { + return TokenError("expects shape"); + } + *result = lexer_.GetShapeVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseName(string* result) { + VLOG(1) << "ParseName"; + if (lexer_.GetKind() != TokKind::kName) { + return TokenError("expects name"); + } + *result = lexer_.GetStrVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseAttributeName(string* result) { + if (lexer_.GetKind() != TokKind::kAttributeName) { + return TokenError("expects attribute name"); + } + *result = lexer_.GetStrVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseOpcode(HloOpcode* result) { + VLOG(1) << "ParseOpcode"; + if (lexer_.GetKind() != TokKind::kOpcode) { + return TokenError("expects opcode"); + } + *result = lexer_.GetOpcodeVal(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseInt64(int64* result) { + VLOG(1) << "ParseInt64"; + if (lexer_.GetKind() != TokKind::kInt) { + return TokenError("expects integer"); + } + *result = lexer_.GetInt64Val(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseDouble(double* result) { + switch (lexer_.GetKind()) { + case TokKind::kDecimal: + *result = lexer_.GetDecimalVal(); + break; + case TokKind::kInt: + *result = static_cast(lexer_.GetInt64Val()); + break; + case TokKind::kw_nan: + *result = std::numeric_limits::quiet_NaN(); + break; + case TokKind::kw_inf: + *result = std::numeric_limits::infinity(); + break; + case TokKind::kNegInf: + *result = -std::numeric_limits::infinity(); + break; + default: + return TokenError("expects decimal or integer"); + } + lexer_.Lex(); + return true; +} + +bool HloParser::ParseBool(bool* result) { + if (lexer_.GetKind() != TokKind::kw_true && + lexer_.GetKind() != TokKind::kw_false) { + return TokenError("expects true or false"); + } + *result = lexer_.GetKind() == TokKind::kw_true; + lexer_.Lex(); + return true; +} + +bool HloParser::ParseToken(TokKind kind, const string& msg) { + VLOG(1) << "ParseToken " << TokKindToString(kind) << " " << msg; + if (lexer_.GetKind() != kind) { + return TokenError(msg); + } + lexer_.Lex(); + return true; +} + +bool HloParser::EatIfPresent(TokKind kind) { + if (lexer_.GetKind() != kind) { + return false; + } + lexer_.Lex(); + return true; +} + +bool HloParser::AddInstruction(const string& name, + HloInstruction* instruction) { + auto result = instruction_pool_.insert({name, instruction}); + if (!result.second) { + return TokenError(StrCat("instruction already exists: ", name)); + } + return true; +} + +bool HloParser::AddComputation(const string& name, + HloComputation* computation) { + auto result = computation_pool_.insert({name, computation}); + if (!result.second) { + return TokenError(StrCat("computation already exists: ", name)); + } + return true; +} + +} // namespace + +StatusOr> Parse(StringPiece str, + const HloModuleConfig& config) { + HloParser parser(str, config); + if (!parser.Run()) { + return InvalidArgument("Syntax error: %s", parser.GetError().c_str()); + } + return parser.ConsumeHloModule(); +} + +StatusOr> Parse(StringPiece str) { + HloModuleConfig config; + return Parse(str, config); +} + +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.h b/tensorflow/compiler/xla/tools/parser/hlo_parser.h new file mode 100644 index 0000000000000000000000000000000000000000..2f97a2b9b19d0cdb64a2869913da62c55e14c1d5 --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ + +#include "tensorflow/compiler/xla/ptr_util.h" +#include "tensorflow/compiler/xla/service/hlo_computation.h" +#include "tensorflow/compiler/xla/service/hlo_instruction.h" +#include "tensorflow/compiler/xla/service/hlo_module.h" +#include "tensorflow/compiler/xla/statusor.h" +#include "tensorflow/compiler/xla/tools/parser/hlo_lexer.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" + +namespace xla { +namespace tools { + +// The api of the hlo parser. Given a string in the HloModule::ToString() +// format, parses the string and creates a HloModule with the given config. +StatusOr> Parse(tensorflow::StringPiece str, + const HloModuleConfig& config); + +// The api of the hlo parser. Given a string in the HloModule::ToString() +// format, parses the string and creates a HloModule with default config. +StatusOr> Parse(tensorflow::StringPiece str); + +} // namespace tools +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_PARSER_H_ diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..359256f0646367f8af13439b30067624defcd44c --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -0,0 +1,432 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" + +#include +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/test.h" + +namespace xla { +namespace tools { +namespace { + +using tensorflow::StringPiece; + +struct TestData { + string test_name; + string module_string; +}; + +string TestDataToString(const ::testing::TestParamInfo& data) { + return data.param.test_name; +} + +std::vector CreateTestCases() { + // clang-format off + return std::vector({ +// ax + y +{ +"AxpyParam", +R"(HloModule axpy_module: + +ENTRY %axpy.v5 (alpha: f32[2,4], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { + %alpha = f32[2,4]{1,0} parameter(0) + %x = f32[2,4]{1,0} parameter(1) + %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %alpha, f32[2,4]{1,0} %x) + %y = f32[2,4]{1,0} parameter(2) + ROOT %add = f32[2,4]{1,0} add(f32[2,4]{1,0} %multiply, f32[2,4]{1,0} %y) +} + +)" +}, +// pred constant +{ +"ConstantPred", +R"(HloModule constant_pred_module: + +ENTRY %constant_pred () -> pred[] { + ROOT %constant = pred[] constant(true) +} + +)" +}, +// s32 constant +{ +"ConstantS32", +R"(HloModule constant_s32_module: + +ENTRY %constant_s32 () -> s32[] { + ROOT %constant = s32[] constant(-42) +} + +)" +}, +// f32 constant, but the value is not a decimal +{ +"ConstantF32", R"(HloModule ConstantF32_module: + +ENTRY %ConstantF32.v4 () -> f32[] { + ROOT %constant = f32[] constant(42) +} + +)" +}, +// constant 4D +{ +"Constant4D", +R"(HloModule Small_3x2x1x1_module: + +ENTRY %Small_3x2x1x1.v1 () -> f32[3,2,1,1] { + ROOT %constant = f32[3,2,1,1]{3,2,1,0} constant(f32[3,2,1,1] { { /*i0=0*/ { /*i1=0*/ {-1} }, { /*i1=1*/ {4.1} } }, { /*i0=1*/ { /*i1=0*/ {2} }, { /*i1=1*/ {4.1} } }, { /*i0=2*/ { /*i1=0*/ {5} }, { /*i1=1*/ {4.4} } } }) +} + +)" +}, +// non-finite constants: nan, inf, -inf +{ +"ConstantNonFinite", +R"(HloModule IsFiniteR1F32s_module: + +ENTRY %IsFiniteR1F32s.v2 () -> pred[6] { + %constant = f32[6]{0} constant({nan, 7, nan, -1, inf, -inf}) + ROOT %is-finite = pred[6]{0} is-finite(f32[6]{0} %constant) +} + +)" +}, +// constant f16 +{ +"ConstantF16", +R"(HloModule ConstantF16_module: + +ENTRY %ConstantF16.v4 () -> f16[] { + ROOT %constant = f16[] constant(500) +} + +)" +}, +// constant + constant +{ +"AddConstants", +R"(HloModule add_constants_module: + +ENTRY %add_constants () -> f32[] { + %constant = f32[] constant(3.14) + ROOT %add = f32[] add(f32[] %constant, f32[] %constant) +} + +)" +}, +// tuple constant +{ +"TupleConstant", +R"(HloModule TupleConstant_module: + +ENTRY %TupleConstant.v1 () -> (f32[2,1], f32[2]) { + ROOT %constant = (f32[2,1]{1,0}, f32[2]{0}) constant((f32[2,1], f32[2]) ( f32[2,1] { { 1 }, { 2 } }, {2, 42} )) +} + +)" +}, +// v1 > v2 ? v1 : v2 +{ +"SelectR1F32", +R"(HloModule SelectR1F32WithCmpR1F32sFromParamsSmall_module: + +ENTRY %SelectR1F32WithCmpR1F32sFromParamsSmall.v4 (v1: f32[4], v2: f32[4]) -> f32[4] { + %v1 = f32[4]{0} parameter(0), sharding={maximal device=1} + %v2 = f32[4]{0} parameter(1), sharding={maximal device=1} + %greater-than = pred[4]{0} greater-than(f32[4]{0} %v1, f32[4]{0} %v2), sharding={replicated} + ROOT %select = f32[4]{0} select(pred[4]{0} %greater-than, f32[4]{0} %v1, f32[4]{0} %v2) +} + +)" +}, +// empty tuple +{ +"EmptyTupleCreate", +R"(HloModule EmptyTupleCreate_module: + +ENTRY %EmptyTupleCreate.v1 () -> () { + ROOT %tuple = () tuple() +} + +)" +}, +// tuple +{ +"TupleCreate", +R"(HloModule TupleCreate_module: + +ENTRY %TupleCreate.v4 (v1: f32[], v2: f32[3], v3: f32[2,3]) -> (f32[], f32[3], f32[2,3]) { + %v1 = f32[] parameter(0) + %v2 = f32[3]{0} parameter(1) + %v3 = f32[2,3]{1,0} parameter(2) + ROOT %tuple = (f32[], f32[3]{0}, f32[2,3]{1,0}) tuple(f32[] %v1, f32[3]{0} %v2, f32[2,3]{1,0} %v3) +} + +)" +}, +// int32 result = 0; +// while (result < 5) { result = result + 1; } +{ +"WhileWithScalarS32Result", +R"(HloModule WhileWithScalarS32Result_module: + +%body.v3 (prev.1: s32[]) -> s32[] { + %constant = s32[] constant(1) + %prev.1 = s32[] parameter(0) + ROOT %add = s32[] add(s32[] %constant, s32[] %prev.1) +} + +%condition.v3 (prev.2: s32[]) -> pred[] { + %constant.1 = s32[] constant(5) + %prev.2 = s32[] parameter(0) + ROOT %greater-than = pred[] greater-than(s32[] %constant.1, s32[] %prev.2) +} + +ENTRY %WhileWithScalarS32Result.v2 () -> s32[] { + %constant.2 = s32[] constant(0) + ROOT %while = s32[] while(s32[] %constant.2), condition=%condition.v3, body=%body.v3 +} + +)" +}, +// send and recv +{ +"SendRecv", +R"(HloModule TwoSendRecvBothWayRecvFist_module: + +ENTRY %TwoSendRecvBothWayRecvFist.v3 () -> f32[] { + %recv = f32[] recv(), channel_id=15, sharding={maximal device=1} + ROOT %constant = f32[] constant(2.1), sharding={maximal device=0} + %send = () send(f32[] %constant), channel_id=16, sharding={maximal device=0}, control-predecessors={%recv} +} + +)" +}, +// get-tuple-element +{ +"GetTupleElement", +R"(HloModule GetTupleElement_module: + +ENTRY %GetTupleElement.v4 () -> s32[2,3] { + %constant = f32[3]{0} constant({1, 2, 3}) + %constant.1 = s32[2,3]{1,0} constant(s32[2,3] { { 1, 2, 3 }, { 4, 5, 6 } }) + %tuple = (f32[3]{0}, s32[2,3]{1,0}) tuple(f32[3]{0} %constant, s32[2,3]{1,0} %constant.1) + ROOT %get-tuple-element = s32[2,3]{1,0} get-tuple-element((f32[3]{0}, s32[2,3]{1,0}) %tuple), index=1, sharding={maximal device=0} +} + +)" +}, +// call +{ +"Call", +R"(HloModule CallR0F32IdentityScalar_module: + +%Identity.v1 (x: f32[]) -> f32[] { + ROOT %x = f32[] parameter(0) +} + +ENTRY %CallR0F32IdentityScalar.v2 () -> f32[] { + %constant = f32[] constant(42) + ROOT %call = f32[] call(f32[] %constant), to_apply=%Identity.v1 +} + +)" +} + }); + // clang-format on +} + +class HloParserTest : public ::testing::Test, + public ::testing::WithParamInterface { + protected: + static void ExpectHasSubstr(StringPiece s, StringPiece expected) { + EXPECT_TRUE(StringPiece(s).contains(expected)) + << "'" << s << "' does not contain '" << expected << "'"; + } + + void ExpectSuccess() { + const string& original = GetParam().module_string; + auto result = Parse(original); + TF_EXPECT_OK(result.status()); + EXPECT_EQ(original, + result.ValueOrDie()->ToString(/*include_large_constants=*/true)); + } +}; + +TEST_P(HloParserTest, Run) { ExpectSuccess(); } + +INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest, + ::testing::ValuesIn(CreateTestCases()), + TestDataToString); + +TEST_F(HloParserTest, Empty) { + const string original = ""; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, Garbage) { + const string original = "HloModule thi$ str1ng makes# N0 sen$e @all!*&^%$"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, WrongOpcode) { + const string original = R"(HloModule wrong_opcode: + +ENTRY %blabla (x: f32[], y: f32[]) -> f32[] { + %x = f32[]{} parameter(0) + %y = f32[]{} parameter(1) + %le = pred[]{} le(f32[]{} %x, f32[]{} %y) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, WrongShape) { + const string original = R"(HloModule wrong_opcode: + +ENTRY %blabla (x: g32[]) -> g32[] { + %x = g32[]{} parameter(0) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, WrongOperandsSize) { + const string original = R"(HloModule wrong_opcode: + +ENTRY %blabla (x: f32[]) -> pred[] { + %x = f32[]{} parameter(0) + %eq = pred[]{} equal-to(f32[]{} %x) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, OperandNotFound) { + const string original = R"(HloModule operand_not_found: +ENTRY %blabla (x: f32[]) -> pred[] { + %x = f32[]{} parameter(0) + %eq = pred[]{} equal-to(f32[]{} %x, f32[]{} %y) +} +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); +} + +TEST_F(HloParserTest, MoreConstants) { + const string original = R"(HloModule SelectScalarS32True_module: + +ENTRY %SelectScalarS32True.v4 () -> s32[] { + %constant.2 = pred[] constant(true) + %constant.1 = s32[] constant(-42), sharding={s32[5,6] devices=[2,3]1,2,3,4} + %constant = s32[] constant(42) + %select = s32[] select(pred[] %constant.2, s32[] %constant.1, s32[] %constant) +} + +)"; + auto result = Parse(original); + TF_EXPECT_OK(result.status()); + // Constant instructions have no name. The string will be parsed successfully + // but the constant names will not be exactly the same. +} + +TEST_F(HloParserTest, LiteralDimensionsMismatch_1) { + const string original = R"(HloModule some_2_module: + +ENTRY %some_2 () -> f32[2] { + ROOT %constant = f32[2]{0} constant({1,{2}}) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "expects nested array in rank 1, but sees larger"); +} + +TEST_F(HloParserTest, LiteralDimensionsMismatch_2) { + const string original = R"(HloModule some_2x3_module: + +ENTRY %some_2x3 () -> f32[2,3] { + ROOT %constant = f32[2,3]{1,0} constant(f32[2,3] {1, 2, 3, 4, 5, 6}) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "expects nested array in rank 2, but sees 1"); +} + +TEST_F(HloParserTest, LiteralDimensionsMismatch_3) { + const string original = R"(HloModule some_2x3x2_module: + +ENTRY %some_2x3x2 () -> f32[2,3,2] { + ROOT %constant = f32[2,3,2]{2,1,0} constant(f32[2,3,2] {{{1, 2}, {3, 4}, {5, 6}, {7, 8}, {9, 10}, {11, 12}}}) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "expects 3 elements in the [0]th element"); +} + +TEST_F(HloParserTest, ConstantF16Overflow) { + const string original = + R"(HloModule ConstantF16Overflow_module: + +ENTRY %ConstantF16Overflow.v4 () -> f16[] { + ROOT %constant = f16[] constant(-65505) +} + +)"; + auto result = Parse(original); + EXPECT_NE(tensorflow::Status::OK(), result.status()); + ExpectHasSubstr(result.status().error_message(), + "is out of range for literal's primitive type F16"); +} + +TEST_F(HloParserTest, ConstantWithExp) { + const string original = R"(HloModule ConstantWithExp_module: + +ENTRY %ConstantWithExp.v4 () -> f32[] { + %constant.1 = f32[] constant(3e+2) +} + +)"; + auto result = Parse(original); + TF_EXPECT_OK(result.status()); + // The string will be parsed successfully but the output strings are not + // exactly the same, because "3e2" is parsed into value 300 and will be + // printed as "300". +} + +} // namespace +} // namespace tools +} // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h new file mode 100644 index 0000000000000000000000000000000000000000..9c2069e7568e46e89afc0fd43d0ff3d8492991fb --- /dev/null +++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h @@ -0,0 +1,71 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ +#define TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ + +#include + +namespace xla { +namespace tools { + +// Defines different kinds of tokens in a hlo module string. +enum class TokKind { + // Markers + kEof, + kError, + + // Tokens with no info. + kEqual, // = + kComma, // , + kColon, // : + kLsquare, + kRsquare, // [ ] + kLbrace, + kRbrace, // { } + kLparen, + kRparen, // ( ) + + kArrow, // -> + kComment, // /*xxx*/ + + // Keywords + kw_HloModule, + kw_ENTRY, + kw_ROOT, + kw_true, + kw_false, + kw_maximal, + kw_replicated, + kw_nan, + kw_inf, + + kNegInf, // -inf + + // Typed tokens. + kName, // %foo + kAttributeName, // dimensions= + kShape, // f32[2,3]{1,0} + kOpcode, // add + kInt, // 42 + kDecimal, // 4.2 +}; + +string TokKindToString(TokKind kind); + +} // namespace tools +} // namespace xla + +#endif // TENSORFLOW_COMPILER_XLA_TOOLS_PARSER_HLO_TOKEN_H_ diff --git a/tensorflow/compiler/xla/types.h b/tensorflow/compiler/xla/types.h index ea8b4b7b989b72034f33920a7d8c1a75e15a7dd1..3b19ca321cad35aad18f7f498e08fd744ffbc371 100644 --- a/tensorflow/compiler/xla/types.h +++ b/tensorflow/compiler/xla/types.h @@ -16,6 +16,8 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_XLA_TYPES_H_ #define TENSORFLOW_COMPILER_XLA_TYPES_H_ +#include + #include "third_party/eigen3/Eigen/Core" #include "tensorflow/core/platform/types.h" @@ -35,6 +37,8 @@ using ::tensorflow::uint16; using ::tensorflow::uint32; using ::tensorflow::uint64; +using complex64 = std::complex; + using ::Eigen::half; } // namespace xla diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc index 1c7361105595ca25cd130dbb890b9e2cb694a7ac..2624ef0252fd9482a600fe3aec07f7f328a86d69 100644 --- a/tensorflow/compiler/xla/util.cc +++ b/tensorflow/compiler/xla/util.cc @@ -336,4 +336,13 @@ std::vector> CommonFactors( return bounds; } +string SanitizeFileName(string file_name) { + for (char& c : file_name) { + if (c == '/' || c == '\\' || c == '[' || c == ']') { + c = '_'; + } + } + return file_name; +} + } // namespace xla diff --git a/tensorflow/compiler/xla/util.h b/tensorflow/compiler/xla/util.h index 1a54c4029c8586099f26fa3cdd7fdcaf1d083dfa..f58f57b44396c90a3820835a3d0ecc792aaa7cd0 100644 --- a/tensorflow/compiler/xla/util.h +++ b/tensorflow/compiler/xla/util.h @@ -24,6 +24,7 @@ limitations under the License. #include #include "tensorflow/compiler/xla/status.h" +#include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/status.h" @@ -361,6 +362,9 @@ int64 Product(tensorflow::gtl::ArraySlice xs); std::vector> CommonFactors( tensorflow::gtl::ArraySlice a, tensorflow::gtl::ArraySlice b); +// Removes illegal characters from filenames. +string SanitizeFileName(string file_name); + } // namespace xla #define XLA_LOG_LINES(SEV, STRING) \ diff --git a/tensorflow/compiler/xla/util_test.cc b/tensorflow/compiler/xla/util_test.cc index 547b924180bf59091ebd552618bf6bd5be9cd6a7..288479c893855742f7aa76fab532c5ca8f942e3c 100644 --- a/tensorflow/compiler/xla/util_test.cc +++ b/tensorflow/compiler/xla/util_test.cc @@ -122,5 +122,12 @@ TEST(UtilTest, CommonFactors) { } } +TEST(UtilTest, SanitizeFileName) { + EXPECT_EQ(SanitizeFileName(""), ""); + EXPECT_EQ(SanitizeFileName("abc"), "abc"); + EXPECT_EQ(SanitizeFileName("/\\[]"), "____"); + EXPECT_EQ(SanitizeFileName("/A\\B[C]"), "_A_B_C_"); +} + } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/xla.bzl b/tensorflow/compiler/xla/xla.bzl index 22e70ec97adf9297ceb3f98f57feb17ae9dafc3d..3fa5bcc1df4f0294582b6c74735fef08c87433eb 100644 --- a/tensorflow/compiler/xla/xla.bzl +++ b/tensorflow/compiler/xla/xla.bzl @@ -17,11 +17,3 @@ def xla_proto_library(name, srcs=[], deps=[], visibility=None, testonly=0): protoc="@protobuf_archive//:protoc", testonly=testonly, visibility=visibility,) - -# Flags required for modules that export symbols that are to be called by the -# XLA CustomCall operator. CustomCall must be able to find symbols with dlsym(), -# which on Linux requires we link with --export-dynamic. -export_dynamic_linkopts = select({ - "//tensorflow:darwin": [], - "//conditions:default": ["-Wl,--export-dynamic"], -}) diff --git a/tensorflow/compiler/xla/xla.proto b/tensorflow/compiler/xla/xla.proto index 4840ddb8817a37c7dabcfb27e24a2a5472f4b6a2..710bb6ff25bf649693165c5e9fb6bc50e81db4ca 100644 --- a/tensorflow/compiler/xla/xla.proto +++ b/tensorflow/compiler/xla/xla.proto @@ -82,8 +82,8 @@ message DebugOptions { // Dump all HLO modules as text into the provided directory path. string xla_generate_hlo_text_to = 7; - // Dump compilation artifacts as JSON into this directory. - string xla_dump_debug_json_to = 8; + // Dump compilation artifacts in binary proto into this directory. + string xla_dump_hlo_proto_to = 8; // Instrument the computation to collect per-HLO cycle counts. bool xla_hlo_profile = 9; @@ -191,6 +191,11 @@ message ExecutionOptions { uint64 seed = 3; DebugOptions debug_options = 4; + + // This optional field specifies a particular set of devices to run the + // computation on. The computation will be partitioned across these devices. + // If not provided, the default device will be chosen. + repeated DeviceHandle device_handles = 5; } message SnapshotComputationRequest { @@ -312,12 +317,8 @@ message ExecuteRequest { ComputationHandle computation = 1; repeated GlobalDataHandle arguments = 2; - // This optional field specifies a particular device to run the computation. - // If not provided, the default device will be chosen. - DeviceHandle device_handle = 5; - // Options that affect how XLA compiles and runs code to service this request. - ExecutionOptions execution_options = 6; + ExecutionOptions execution_options = 5; } message ExecuteParallelRequest { @@ -360,6 +361,7 @@ message WaitForExecutionResponse { message IsConstantRequest { ComputationHandle computation = 1; ComputationDataHandle operand = 2; + int64 num_parameters = 3; } message IsConstantResponse { @@ -370,6 +372,7 @@ message ComputeConstantRequest { ComputationHandle computation = 1; ComputationDataHandle operand = 2; Layout output_layout = 3; + repeated LiteralProto parameters = 4; } message ComputeConstantResponse { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 3327e06ed8c7602903b501ff0dbaf16a6c97a82b..06987e0044d7f69637c9ca0e1a2b40d91cd74713 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -48,6 +48,9 @@ enum PrimitiveType { F32 = 11; F64 = 12; + // Complex values of fixed width. + C64 = 15; // Paired F32 (real, imag), as in std::complex. + // A tuple is a polymorphic sequence; e.g. a shape that holds different // sub-shapes. They are used for things like returning multiple values from a // computation; e.g. a computation that returns weights and biases may have a @@ -305,6 +308,7 @@ message LiteralProto { repeated uint64 u64s = 7; repeated float f32s = 8; repeated double f64s = 9; + repeated float c64s = 12; // Stored as interleaved real, imag floats. repeated LiteralProto tuple_literals = 10; bytes f16s = 11; // Note: the F16s are encoded in little endian byte order } @@ -392,13 +396,17 @@ message DynamicUpdateSliceRequest { } message ConvolutionDimensionNumbers { - // The number of the dimension that represents batch in the input - // (lhs) and output. - int64 batch_dimension = 1; + // The number of the dimension that represents batch in the input. + int64 input_batch_dimension = 7; + + // The number of the dimension that represents features in the input. + int64 input_feature_dimension = 8; + + // The number of the dimension that represents batch in the output. + int64 output_batch_dimension = 9; - // The number of the dimension that represents features in the input - // (lhs) and output. - int64 feature_dimension = 2; + // The number of the dimension that represents features in the output. + int64 output_feature_dimension = 10; // The dimension numbers for the spatial dimensions that the window // moves through in the input (lhs) and output. @@ -425,6 +433,20 @@ message ConvolveRequest { ConvolutionDimensionNumbers dimension_numbers = 5; } +enum FftType { + FFT = 0; // Forward FFT; complex in, complex out. + IFFT = 1; // Inverse FFT; complex in, complex out. + RFFT = 2; // Forward real FFT; real in, fft_length / 2 + 1 complex out + IRFFT = 3; // Inverse real FFT; fft_length / 2 + 1 complex in, + // fft_length real out +} + +message FftRequest { + FftType fft_type = 1; + repeated int64 fft_length = 2; // Multivalent for higher-order FFT. + ComputationDataHandle operand = 3; +} + message InfeedRequest { // The shape of the data returned by reading the device's infeed buffer. Shape shape = 2; @@ -459,6 +481,11 @@ message MapRequest { repeated ComputationDataHandle operands = 2; ComputationHandle to_apply = 3; repeated ComputationDataHandle static_operands = 4; + // The dimensions over which to map. + // Example mapping a Dot operation along the batch dimension 0: + // operand0.shape = [2, 2, 2], operand1.shape = [2,2,3] + // Map({operand0, operand1}, Dot, {0}) + repeated int64 dimensions = 5; } message ReduceRequest { @@ -612,8 +639,8 @@ message WhileRequest { enum UnaryOperation { UNOP_INVALID = 0; - // Elementwise, logical negation - UNOP_LOGICAL_NOT = 1; + // Elementwise, logical negation on booleans and bitwise negation on ints. + UNOP_NOT = 1; // Elementwise, computes e^x. UNOP_EXP = 2; @@ -654,6 +681,12 @@ enum UnaryOperation { // Elementwise, rounds x to nearest integral value, rounding half-way cases // away from zero. UNOP_ROUND_NEAREST_AFZ = 14; + + // Elementwise, extract real component of complex x. + UNOP_REAL = 15; + + // Elementwise, extract real component of complex x. + UNOP_IMAG = 16; } message UnaryOpRequest { @@ -681,14 +714,6 @@ enum BinaryOperation { // Dot product, matrix multiply. BINOP_DOT = 12; - // Indexes into the LHS with the RHS. - // - // If the RHS is higher-rank, this is a gather operation. - // - // Note: currently out of bounds indices may crash the underlying XLA - // machine. - BINOP_INDEX = 13; - // Element-wise maximum. BINOP_MAX = 14; @@ -701,9 +726,19 @@ enum BinaryOperation { // Remainder operation. BINOP_REM = 17; - // Logical operators - BINOP_LOGICAL_AND = 18; - BINOP_LOGICAL_OR = 19; + // Element-wise, logical operators on booleans and bitwise operators on ints. + BINOP_AND = 18; + BINOP_OR = 19; + + BINOP_SHIFT_LEFT = 20; + BINOP_SHIFT_RIGHT_ARITHMETIC = 21; + BINOP_SHIFT_RIGHT_LOGICAL = 22; + + // Complex from real, imag. + BINOP_COMPLEX = 23; + + // Computes the 4-quadrant arctangent of the y, x input arguments. + BINOP_ATAN2 = 24; } message BinaryOpRequest { @@ -742,10 +777,6 @@ enum TernaryOperation { // true and operand1 if the predicate is false. TRIOP_SELECT = 1; - // Updates operand0 at index operand1 with value operand2 and outputs the - // updated value. - TRIOP_UPDATE = 2; - // Given a min, max and an operand returns the operand if between min and max, // else returns min if operand is less than min or max if operand is greater // than max. @@ -787,18 +818,32 @@ message RecvRequest { ChannelHandle channel_handle = 2; } -message OpDeviceAssignment { - bool has_device = 1; - - // Number of the device to which this operator is assigned. Ignored if - // 'has_device' is false. - int32 device = 2; +message OpSharding { + enum Type { + // This sharding is replicated across all devices (implies maximal, + // all other fields are unused). + REPLICATED = 0; + // This sharding is maximal - one device runs the entire operation. + MAXIMAL = 1; + // Neither of the above; tile_shape and tile_assignment are both used. + OTHER = 2; + } + Type type = 1; + // The shape of the sharded tile. + Shape tile_shape = 2; + // The shape of the tile assignment tensor - this must be the same rank as + // tile_shape and the product of its dimensions must equal + // tile_assignment_devices.size(). + repeated int64 tile_assignment_dimensions = 3; + // Flattened list of device IDs. The order of flattening is the same as used + // by IndexUtil::MultiToLinearIndex(tile_assignment_shape). + repeated int64 tile_assignment_devices = 4; } message OpRequest { ComputationHandle computation = 1; OpMetadata metadata = 33; - OpDeviceAssignment device_assignment = 39; + OpSharding sharding = 40; oneof op { BinaryOpRequest binary_op_request = 2; @@ -837,7 +882,8 @@ message OpRequest { BatchNormTrainingRequest batch_norm_training_request = 35; BatchNormGradRequest batch_norm_grad_request = 37; BatchNormInferenceRequest batch_norm_inference_request = 38; - // Next: 40 + FftRequest fft_request = 41; + // Next: 42 } } diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index 14fa6ea7cdd544ed9b40a16eef9e89cc9b305eff..3d53cbba5652c902855972f6e4e3ee78a3e1bcc7 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -43,6 +43,7 @@ py_library( "//tensorflow/contrib/integrate:integrate_py", "//tensorflow/contrib/keras", "//tensorflow/contrib/kernel_methods", + "//tensorflow/contrib/kfac", "//tensorflow/contrib/labeled_tensor", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn", @@ -52,9 +53,11 @@ py_library( "//tensorflow/contrib/linear_optimizer:sdca_ops_py", "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/contrib/losses:losses_py", + "//tensorflow/contrib/losses:metric_learning_py", "//tensorflow/contrib/memory_stats:memory_stats_py", "//tensorflow/contrib/meta_graph_transform", "//tensorflow/contrib/metrics:metrics_py", + "//tensorflow/contrib/model_pruning", "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/contrib/ndlstm", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_py", @@ -62,6 +65,7 @@ py_library( "//tensorflow/contrib/opt:opt_py", "//tensorflow/contrib/predictor", "//tensorflow/contrib/quantization:quantization_py", + "//tensorflow/contrib/quantize:quantize_graph", "//tensorflow/contrib/reduce_slice_ops:reduce_slice_ops_py", "//tensorflow/contrib/remote_fused_graph/pylib:remote_fused_graph_ops_py", "//tensorflow/contrib/resampler:resampler_py", @@ -77,7 +81,7 @@ py_library( "//tensorflow/contrib/staging", "//tensorflow/contrib/stat_summarizer:stat_summarizer_py", "//tensorflow/contrib/stateless", - "//tensorflow/contrib/summary:summary_ops", + "//tensorflow/contrib/summary:summary", "//tensorflow/contrib/tensor_forest:init_py", "//tensorflow/contrib/tensorboard", "//tensorflow/contrib/testing:testing_py", @@ -85,8 +89,10 @@ py_library( "//tensorflow/contrib/tfprof", "//tensorflow/contrib/timeseries", "//tensorflow/contrib/tpu", + "//tensorflow/contrib/tpu:tpu_py", "//tensorflow/contrib/training:training_py", "//tensorflow/contrib/util:util_py", + "//tensorflow/python:util", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_ops_py"]), ) @@ -102,6 +108,7 @@ cc_library( "//tensorflow/contrib/layers:sparse_feature_cross_op_kernel", "//tensorflow/contrib/nccl:nccl_kernels", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_kernels", + "//tensorflow/contrib/rnn:all_kernels", "//tensorflow/contrib/seq2seq:beam_search_ops_kernels", "//tensorflow/contrib/tensor_forest:model_ops_kernels", "//tensorflow/contrib/tensor_forest:stats_ops_kernels", @@ -123,6 +130,7 @@ cc_library( "//tensorflow/contrib/layers:sparse_feature_cross_op_op_lib", "//tensorflow/contrib/nccl:nccl_ops_op_lib", "//tensorflow/contrib/nearest_neighbor:nearest_neighbor_ops_op_lib", + "//tensorflow/contrib/rnn:all_ops", "//tensorflow/contrib/seq2seq:beam_search_ops_op_lib", "//tensorflow/contrib/tensor_forest:model_ops_op_lib", "//tensorflow/contrib/tensor_forest:stats_ops_op_lib", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 5b3f0b3f6eee6c49a85ff6e3654e390da64ab762..3068e9ed8f53e3e0f7cbf2d0222121a5752a2a56 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -40,6 +40,7 @@ from tensorflow.contrib import input_pipeline from tensorflow.contrib import integrate from tensorflow.contrib import keras from tensorflow.contrib import kernel_methods +from tensorflow.contrib import kfac from tensorflow.contrib import labeled_tensor from tensorflow.contrib import layers from tensorflow.contrib import learn @@ -50,11 +51,13 @@ from tensorflow.contrib import lookup from tensorflow.contrib import losses from tensorflow.contrib import memory_stats from tensorflow.contrib import metrics +from tensorflow.contrib import model_pruning from tensorflow.contrib import nccl from tensorflow.contrib import nn from tensorflow.contrib import opt from tensorflow.contrib import predictor from tensorflow.contrib import quantization +from tensorflow.contrib import quantize from tensorflow.contrib import reduce_slice_ops from tensorflow.contrib import resampler from tensorflow.contrib import rnn @@ -75,9 +78,11 @@ from tensorflow.contrib import timeseries from tensorflow.contrib import tpu from tensorflow.contrib import training from tensorflow.contrib import util +from tensorflow.contrib.eager.python import tfe as eager from tensorflow.contrib.ndlstm import python as ndlstm from tensorflow.contrib.remote_fused_graph import pylib as remote_fused_graph from tensorflow.contrib.specs import python as specs +from tensorflow.contrib.summary import summary from tensorflow.python.util.lazy_loader import LazyLoader ffmpeg = LazyLoader("ffmpeg", diff --git a/tensorflow/contrib/all_reduce/BUILD b/tensorflow/contrib/all_reduce/BUILD index 744ae4c1f413bc1854a07ead9a3fa6bc90ed2fc1..8dff93b4f825277dcf0a64aa3b96bd809d36e1e9 100644 --- a/tensorflow/contrib/all_reduce/BUILD +++ b/tensorflow/contrib/all_reduce/BUILD @@ -19,9 +19,10 @@ py_library( srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - "//tensorflow/contrib/nccl:nccl_ops", + "//tensorflow/contrib/nccl:nccl_py", "//tensorflow/python:array_ops", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", ], ) @@ -31,12 +32,17 @@ tf_py_test( additional_deps = [ ":all_reduce", "//third_party/py/numpy", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", "//tensorflow/python:constant_op", "//tensorflow/python:client_testlib", + "//tensorflow/python:platform", "//tensorflow/python:platform_test", + "//tensorflow/python:state_ops", ], ) diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py index 8e7f1791b864bb30e4592a86e637d1603be6618b..a5057da9fd43a88575813613d6ac9d17fd2b2e28 100644 --- a/tensorflow/contrib/all_reduce/python/all_reduce.py +++ b/tensorflow/contrib/all_reduce/python/all_reduce.py @@ -191,7 +191,7 @@ def _ragged_split(tensor, pieces): def _ring_permutations(num_workers, num_subchunks, gpu_perm): - """"Generate an array of device index arrays, one for for each subchunk. + """"Generate an array of device index arrays, one for each subchunk. In the basic ring reduction algorithm there are size(T)/num_devices data chunks and each device process one chunk per tick, i.e. sending @@ -762,6 +762,8 @@ def _reduce_non_singleton(input_tensors, red_f, un_op): if len(input_tensors) > 1: return red_f(input_tensors) else: + if not un_op: + return input_tensors output_tensors = [] for t in input_tensors: with ops.colocate_with(t): @@ -835,7 +837,7 @@ def _build_shuffle_hybrid(input_tensors, gather_devices, red_op, upper_level_f): def build_shuffle_then_ring(input_tensors, gather_devices, subdiv, - red_n_op, red_op, un_op): + red_n_op, red_op, un_op=None): """Construct hybrid of Shuffle within workers, Ring across workers.""" def upper_builder(tensors): return build_ring_all_reduce(tensors, len(tensors), subdiv, [0], diff --git a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java index 395dd6c5d2ffdea058dc95e878685895c37f0b9c..1f423a7a5bf6a115dc627ddd6f5e98c074282585 100644 --- a/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java +++ b/tensorflow/contrib/android/java/org/tensorflow/contrib/android/TensorFlowInferenceInterface.java @@ -31,12 +31,13 @@ import java.nio.IntBuffer; import java.nio.LongBuffer; import java.util.ArrayList; import java.util.List; -import org.tensorflow.DataType; import org.tensorflow.Graph; import org.tensorflow.Operation; import org.tensorflow.Session; import org.tensorflow.Tensor; import org.tensorflow.TensorFlow; +import org.tensorflow.Tensors; +import org.tensorflow.types.UInt8; /** * Wrapper over the TensorFlow API ({@link Graph}, {@link Session}) providing a smaller API surface @@ -281,6 +282,22 @@ public class TensorFlowInferenceInterface { // Methods for taking a native Tensor and filling it with values from Java arrays. + /** + * Given a source array with shape {@link dims} and content {@link src}, copy the contents into + * the input Tensor with name {@link inputName}. The source array {@link src} must have at least + * as many elements as that of the destination Tensor. If {@link src} has more elements than the + * destination has capacity, the copy is truncated. + */ + public void feed(String inputName, boolean[] src, long... dims) { + byte[] b = new byte[src.length]; + + for (int i = 0; i < src.length; i++) { + b[i] = src[i] ? (byte) 1 : (byte) 0; + } + + addFeed(inputName, Tensor.create(Boolean.class, dims, ByteBuffer.wrap(b))); + } + /** * Given a source array with shape {@link dims} and content {@link src}, copy the contents into * the input Tensor with name {@link inputName}. The source array {@link src} must have at least @@ -328,7 +345,7 @@ public class TensorFlowInferenceInterface { * destination has capacity, the copy is truncated. */ public void feed(String inputName, byte[] src, long... dims) { - addFeed(inputName, Tensor.create(DataType.UINT8, dims, ByteBuffer.wrap(src))); + addFeed(inputName, Tensor.create(UInt8.class, dims, ByteBuffer.wrap(src))); } /** @@ -337,7 +354,7 @@ public class TensorFlowInferenceInterface { * a Java {@code String} (which is a sequence of characters). */ public void feedString(String inputName, byte[] src) { - addFeed(inputName, Tensor.create(src)); + addFeed(inputName, Tensors.create(src)); } /** @@ -346,7 +363,7 @@ public class TensorFlowInferenceInterface { * arbitrary sequence of bytes, not a Java {@code String} (which is a sequence of characters). */ public void feedString(String inputName, byte[][] src) { - addFeed(inputName, Tensor.create(src)); + addFeed(inputName, Tensors.create(src)); } // Methods for taking a native Tensor and filling it with src from Java native IO buffers. @@ -403,7 +420,7 @@ public class TensorFlowInferenceInterface { * destination has capacity, the copy is truncated. */ public void feed(String inputName, ByteBuffer src, long... dims) { - addFeed(inputName, Tensor.create(DataType.UINT8, dims, src)); + addFeed(inputName, Tensor.create(UInt8.class, dims, src)); } /** @@ -544,7 +561,7 @@ public class TensorFlowInferenceInterface { "Model load took " + (endMs - startMs) + "ms, TensorFlow version: " + TensorFlow.version()); } - private void addFeed(String inputName, Tensor t) { + private void addFeed(String inputName, Tensor t) { // The string format accepted by TensorFlowInferenceInterface is node_name[:output_index]. TensorId tid = TensorId.parse(inputName); runner.feed(tid.name, tid.outputIndex, t); @@ -578,7 +595,7 @@ public class TensorFlowInferenceInterface { } } - private Tensor getTensor(String outputName) { + private Tensor getTensor(String outputName) { int i = 0; for (String n : fetchNames) { if (n.equals(outputName)) { @@ -591,7 +608,7 @@ public class TensorFlowInferenceInterface { } private void closeFeeds() { - for (Tensor t : feedTensors) { + for (Tensor t : feedTensors) { t.close(); } feedTensors.clear(); @@ -599,7 +616,7 @@ public class TensorFlowInferenceInterface { } private void closeFetches() { - for (Tensor t : fetchTensors) { + for (Tensor t : fetchTensors) { t.close(); } fetchTensors.clear(); @@ -614,9 +631,9 @@ public class TensorFlowInferenceInterface { // State reset on every call to run. private Session.Runner runner; private List feedNames = new ArrayList(); - private List feedTensors = new ArrayList(); + private List> feedTensors = new ArrayList>(); private List fetchNames = new ArrayList(); - private List fetchTensors = new ArrayList(); + private List> fetchTensors = new ArrayList>(); // Mutable state. private RunStats runStats; diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD index 1555a3427fd5e40ca54c134a2c80f9d2c5feca36..8b7df4a84c558f662405a28a42426583d5ab39cd 100644 --- a/tensorflow/contrib/batching/BUILD +++ b/tensorflow/contrib/batching/BUILD @@ -69,6 +69,28 @@ tf_cc_test( ], ) +cc_library( + name = "adaptive_shared_batch_scheduler", + hdrs = ["adaptive_shared_batch_scheduler.h"], + deps = [ + ":batch_scheduler", + "//tensorflow/contrib/batching/util:periodic_function_dynamic", + "//tensorflow/core:lib", + ], +) + +tf_cc_test( + name = "adaptive_shared_batch_scheduler_test", + srcs = ["adaptive_shared_batch_scheduler_test.cc"], + deps = [ + ":adaptive_shared_batch_scheduler", + "//tensorflow/contrib/batching/test_util:fake_clock_env", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + cc_library( name = "basic_batch_scheduler", hdrs = ["basic_batch_scheduler.h"], @@ -155,14 +177,13 @@ tf_custom_op_py_library( deps = [ ":batch_ops", "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", + "//tensorflow/python:gradients", "//tensorflow/python:platform", - "//tensorflow/python:state_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", + "//tensorflow/python:script_ops", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h new file mode 100644 index 0000000000000000000000000000000000000000..6ed177e001758ad8c566c7965e1ec10ae5235fc8 --- /dev/null +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h @@ -0,0 +1,462 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/contrib/batching/batch_scheduler.h" +#include "tensorflow/contrib/batching/util/periodic_function.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/cpu_info.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/thread_annotations.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { +namespace serving { +namespace internal { +template +class ASBSBatch; + +template +class ASBSQueue; +} // namespace internal + +// Shared batch scheduler designed to minimize latency. The scheduler keeps +// track of a number of queues (one per model or model version) which are +// continuously enqueuing requests. The scheduler groups the requests into +// batches which it periodically sends off for processing (see +// shared_batch_scheduler.h for more details). The AdaptiveSharedBatchScheduler +// prioritizes batches by age (i.e. the batch's oldest request) irrespective of +// queue. The scheduler will process the oldest batch at an adjustable rate, +// regardless of batch size. The user can provide feedback to help set this rate +// to achieve some goal (i.e. minimize overall latency, limit cpu usage, etc). +// +// The rate (or rather, the corresponding period) is adjusted each time a batch +// is processed, using an exponentially weighted moving average to smooth +// potentially noisy feedback: +// ewma_feedback = ((N - 1) * ewma_feedback + feedback()) / N +// period *= (1 + K * emwa_feedback) +// +// Some potential use cases: +// Hardware Accelerators (GPUs & TPUs) - If some phase of batch processing +// involves serial processing by a device, from a latency perspective it is +// desirable to keep the device evenly loaded, avoiding the need to wait for +// the device to process prior batches. +// feedback = num_pending_on_device() - desired_pending. +// CPU utilization - If the batch processing is cpu dominated, you can reap +// latency gains when underutilized by increasing the processing rate, but +// back the rate off when the load increases to avoid overload. +// feedback = cpu_rate() - desired_cpu_rate. + +template +class AdaptiveSharedBatchScheduler + : public std::enable_shared_from_this< + AdaptiveSharedBatchScheduler> { + public: + struct Options { + // The name to use for the pool of batch threads. + string thread_pool_name = {"batch_threads"}; + // Number of batch processing threads; equivalently the maximum number of + // concurrently running batches. + int64 num_batch_threads = port::NumSchedulableCPUs(); + // The environment to use (typically only overridden by test code). + Env* env = Env::Default(); + // Initial batch scheduling period in microseconds. Will be altered for + // non-zero rate_feedback. + double initial_scheduling_period_micros = 500; + // Minimum batch scheduling period in microseconds. Recommend setting this + // value greater than 0, otherwise it may take a while to recover from a + // sustained time of negative scheduling_period_feedback (which may occur + // under low load). + double min_scheduling_period_micros = 100; + // Maximum batch scheduling period in microseconds. + double max_scheduling_period_micros = 10000; + // Feedback function used to modify the scheduling period each time a batch + // is scheduled. Should return values roughly O(1), with positive values + // resulting in an increased period. + std::function scheduling_period_feedback{[] { return 0.; }}; + // To handle potentially noisy scheduling_period_feedback, the period is + // adjusted using an exponentially weighted moving average over the previous + // feedback_smoothing_batches batches. Must be greater than 0. + int64 feedback_smoothing_batches = 10; + }; + + // Ownership is shared between the caller of Create() and any queues created + // via AddQueue(). + static Status Create( + const Options& options, + std::shared_ptr>* scheduler); + + struct QueueOptions { + // Maximum size of each batch. + int max_batch_size = 1000; + // Maximum number of enqueued (i.e. non-scheduled) batches. + int max_enqueued_batches = 10; + }; + + using BatchProcessor = std::function>)>; + + // Adds queue (and its callback) to be managed by this scheduler. + Status AddQueue(const QueueOptions& options, + BatchProcessor process_batch_callback, + std::unique_ptr>* queue); + + private: + // access to AddBatch, RemoveQueue, GetEnv. + friend class internal::ASBSQueue; + + explicit AdaptiveSharedBatchScheduler(const Options& options); + + // Batch scheduling function which runs every scheduling_period_ microseconds. + void ProcessOneBatch(); + + // Notifies scheduler of non-empty batch which is eligible for processing. + void AddBatch(internal::ASBSBatch*); + + // Removes queue from scheduler. + void RemoveQueue(const internal::ASBSQueue* queue); + + Env* GetEnv() const { return options_.env; } + + const Options options_; + + struct BatchCompare { + bool operator()(const internal::ASBSBatch* a, + const internal::ASBSBatch* b); + }; + + // Collection of batches added by AddBatch, ordered by age. Owned by scheduler + // until they are released for processing. + std::priority_queue*, + std::vector*>, BatchCompare> + batches_ GUARDED_BY(mu_); + + // Unowned queues and callbacks added by AddQueue. + std::unordered_map*, BatchProcessor> + queues_and_callbacks_ GUARDED_BY(mu_); + + mutex mu_; + + // Responsible for running ProcessOneBatch. PeriodicFunction was used in order + // to check for deletion so that the thread can be shut down. + std::unique_ptr scheduling_thread_; + + // Responsible for running the batch processing callbacks. + std::unique_ptr batch_thread_pool_; + + // Time interval in microseconds between successive ProcessOneBatch calls. + double scheduling_period_; + + // Exponentially weighted moving average of + // options_.scheduling_period_feedback() evaluated in each ProcessOneBatch + // call. + double ewma_feedback_ = 0; + + TF_DISALLOW_COPY_AND_ASSIGN(AdaptiveSharedBatchScheduler); +}; + +////////////////////////////////////////////////////////// +// Implementation details follow. API users need not read. + +namespace internal { +// Consolidates tasks into batches, passing them off to the +// AdaptiveSharedBatchScheduler for processing. +template +class ASBSQueue : public BatchScheduler { + public: + using QueueOptions = + typename AdaptiveSharedBatchScheduler::QueueOptions; + + ASBSQueue(std::shared_ptr> scheduler, + const QueueOptions& options); + + ~ASBSQueue() override; + + // Adds task to current batch. Fails if the task size is larger than the batch + // size or if the current batch is full and this queue's number of outstanding + // batches is at its maximum. + Status Schedule(std::unique_ptr* task) override; + + // Number of tasks waiting to be scheduled. + size_t NumEnqueuedTasks() const override; + + // Number of size 1 tasks which could currently be scheduled without failing. + size_t SchedulingCapacity() const override; + + // Notifies queue that a batch is about to be scheduled; the queue should not + // place any more tasks in this batch. + void ReleaseBatch(const ASBSBatch* batch); + + private: + std::shared_ptr> scheduler_; + const QueueOptions options_; + // Owned by scheduler_. + ASBSBatch* current_batch_ GUARDED_BY(mu_) = nullptr; + int64 num_enqueued_batches_ GUARDED_BY(mu_) = 0; + int64 num_enqueued_tasks_ GUARDED_BY(mu_) = 0; + mutable mutex mu_; + TF_DISALLOW_COPY_AND_ASSIGN(ASBSQueue); +}; + +// Batch which remembers when and by whom it was created. +template +class ASBSBatch : public Batch { + public: + ASBSBatch(ASBSQueue* queue, int64 creation_time_micros) + : queue_(queue), creation_time_micros_(creation_time_micros) {} + + ~ASBSBatch() override {} + + ASBSQueue* queue() const { return queue_; } + + int64 creation_time_micros() const { return creation_time_micros_; } + + private: + ASBSQueue* queue_; + const int64 creation_time_micros_; + TF_DISALLOW_COPY_AND_ASSIGN(ASBSBatch); +}; +} // namespace internal + +// ---------------- AdaptiveSharedBatchScheduler ---------------- + +template +Status AdaptiveSharedBatchScheduler::Create( + const Options& options, + std::shared_ptr>* scheduler) { + if (options.num_batch_threads < 1) { + return errors::InvalidArgument("num_batch_threads must be positive; was ", + options.num_batch_threads); + } + if (options.min_scheduling_period_micros < 0) { + return errors::InvalidArgument( + "min_scheduling_period_micros must be >= 0; was ", + options.min_scheduling_period_micros); + } + if (options.min_scheduling_period_micros > + options.initial_scheduling_period_micros) { + return errors::InvalidArgument( + "initial_scheduling_period_micros (", + options.initial_scheduling_period_micros, + ") must be >= min_scheduling_period_micros (", + options.min_scheduling_period_micros, ")"); + } + if (options.initial_scheduling_period_micros > + options.max_scheduling_period_micros) { + return errors::InvalidArgument( + "initial_scheduling_period_micros (", + options.initial_scheduling_period_micros, + ") must be <= max_scheduling_period_micros (", + options.max_scheduling_period_micros, ")"); + } + if (options.feedback_smoothing_batches < 1) { + return errors::InvalidArgument( + "feedback_smoothing_batches must be positive; was ", + options.feedback_smoothing_batches); + } + scheduler->reset(new AdaptiveSharedBatchScheduler(options)); + return Status::OK(); +} + +template +AdaptiveSharedBatchScheduler::AdaptiveSharedBatchScheduler( + const Options& options) + : options_(options), + scheduling_period_(options.initial_scheduling_period_micros) { + PeriodicFunction::Options opts; + opts.thread_name_prefix = "scheduling_thread"; + opts.env = GetEnv(); + scheduling_thread_.reset( + new PeriodicFunction([this] { ProcessOneBatch(); }, 0, opts)); + batch_thread_pool_.reset(new thread::ThreadPool( + GetEnv(), options.thread_pool_name, options.num_batch_threads)); +} + +template +Status AdaptiveSharedBatchScheduler::AddQueue( + const QueueOptions& options, BatchProcessor process_batch_callback, + std::unique_ptr>* queue) { + if (options.max_batch_size <= 0) { + return errors::InvalidArgument("max_batch_size must be positive; was ", + options.max_batch_size); + } + if (options.max_enqueued_batches <= 0) { + return errors::InvalidArgument( + "max_enqueued_batches must be positive; was ", + options.max_enqueued_batches); + } + internal::ASBSQueue* asbs_queue_raw; + queue->reset(asbs_queue_raw = new internal::ASBSQueue( + this->shared_from_this(), options)); + mutex_lock l(mu_); + queues_and_callbacks_[asbs_queue_raw] = process_batch_callback; + return Status::OK(); +} + +template +void AdaptiveSharedBatchScheduler::AddBatch( + internal::ASBSBatch* batch) { + mutex_lock l(mu_); + batches_.push(batch); +} + +template +void AdaptiveSharedBatchScheduler::RemoveQueue( + const internal::ASBSQueue* queue) { + mutex_lock l(mu_); + queues_and_callbacks_.erase(queue); +} + +template +void AdaptiveSharedBatchScheduler::ProcessOneBatch() { + static const double kFeedbackMultiplier = .001; + internal::ASBSBatch* batch = nullptr; + BatchProcessor callback; + const int64 start_time_micros = GetEnv()->NowMicros(); + { + mutex_lock l(mu_); + if (!batches_.empty()) { + batch = batches_.top(); + batches_.pop(); + callback = queues_and_callbacks_[batch->queue()]; + } + } + if (batch != nullptr) { + double feedback = options_.scheduling_period_feedback(); + const int64 N = options_.feedback_smoothing_batches; + ewma_feedback_ = ((N - 1) * ewma_feedback_ + feedback) / N; + scheduling_period_ *= (1 + kFeedbackMultiplier * ewma_feedback_); + if (scheduling_period_ < options_.min_scheduling_period_micros) { + scheduling_period_ = options_.min_scheduling_period_micros; + } else if (scheduling_period_ > options_.max_scheduling_period_micros) { + scheduling_period_ = options_.max_scheduling_period_micros; + } + // Queue may destroy itself after ReleaseBatch is called. + batch->queue()->ReleaseBatch(batch); + batch_thread_pool_->Schedule([callback, batch] { + callback(std::unique_ptr>(batch)); + }); + } + const int64 sleep_time = + scheduling_period_ - (GetEnv()->NowMicros() - start_time_micros); + if (sleep_time > 0) { + GetEnv()->SleepForMicroseconds(sleep_time); + } +} + +template +bool AdaptiveSharedBatchScheduler::BatchCompare::operator()( + const internal::ASBSBatch* a, + const internal::ASBSBatch* b) { + return a->creation_time_micros() > b->creation_time_micros(); +} + +// ---------------- ASBSQueue ---------------- + +namespace internal { +template +ASBSQueue::ASBSQueue( + std::shared_ptr> scheduler, + const QueueOptions& options) + : scheduler_(scheduler), options_(options) {} + +template +ASBSQueue::~ASBSQueue() { + // Wait until last batch has been scheduled. + const int kSleepMicros = 1000; + for (;;) { + { + mutex_lock l(mu_); + if (num_enqueued_batches_ == 0) { + break; + } + } + scheduler_->GetEnv()->SleepForMicroseconds(kSleepMicros); + } + scheduler_->RemoveQueue(this); +} + +template +Status ASBSQueue::Schedule(std::unique_ptr* task) { + ASBSBatch* new_batch = nullptr; + size_t size = (*task)->size(); + if (size > options_.max_batch_size) { + return errors::InvalidArgument("Task size ", size, + " is larger than maximum batch size ", + options_.max_batch_size); + } + { + mutex_lock l(mu_); + // Current batch is full, create another if allowed. + if (current_batch_ && + current_batch_->size() + size > options_.max_batch_size) { + if (num_enqueued_batches_ >= options_.max_enqueued_batches) { + return errors::Unavailable("The batch scheduling queue is full"); + } + current_batch_->Close(); + current_batch_ = nullptr; + } + if (!current_batch_) { + num_enqueued_batches_++; + current_batch_ = new_batch = + new ASBSBatch(this, scheduler_->GetEnv()->NowMicros()); + } + current_batch_->AddTask(std::move(*task)); + num_enqueued_tasks_++; + } + if (new_batch != nullptr) scheduler_->AddBatch(new_batch); + return Status::OK(); +} + +template +void ASBSQueue::ReleaseBatch(const ASBSBatch* batch) { + mutex_lock l(mu_); + num_enqueued_batches_--; + num_enqueued_tasks_ -= batch->num_tasks(); + if (batch == current_batch_) { + current_batch_->Close(); + current_batch_ = nullptr; + } +} + +template +size_t ASBSQueue::NumEnqueuedTasks() const { + mutex_lock l(mu_); + return num_enqueued_tasks_; +} + +template +size_t ASBSQueue::SchedulingCapacity() const { + mutex_lock l(mu_); + const int current_batch_capacity = + current_batch_ ? options_.max_batch_size - current_batch_->size() : 0; + const int spare_batches = + options_.max_enqueued_batches - num_enqueued_batches_; + return spare_batches * options_.max_batch_size + current_batch_capacity; +} +} // namespace internal +} // namespace serving +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_ diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..a07cd6d834fa28904bf7748b16972cca217503c1 --- /dev/null +++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc @@ -0,0 +1,438 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h" + +#include "tensorflow/contrib/batching/test_util/fake_clock_env.h" +#include "tensorflow/core/lib/core/notification.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/macros.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace serving { +namespace anonymous { + +class FakeTask : public BatchTask { + public: + explicit FakeTask(size_t size) : size_(size) {} + + ~FakeTask() override = default; + + size_t size() const override { return size_; } + + private: + const size_t size_; + + TF_DISALLOW_COPY_AND_ASSIGN(FakeTask); +}; + +// Creates a FakeTask of size 'task_size', and calls 'scheduler->Schedule()' on +// that task. Returns the resulting status. +Status ScheduleTask(size_t task_size, BatchScheduler* scheduler) { + std::unique_ptr task(new FakeTask(task_size)); + Status status = scheduler->Schedule(&task); + // Schedule() should have consumed 'task' iff it returned Status::OK. + CHECK_EQ(status.ok(), task == nullptr); + return status; +} + +// Creates a thread that waits on 'start' and then advances the fake clock in +// 'env' in a loop until 'stop' is notified. Useful for allowing objects that +// use the clock to be destroyed. +std::unique_ptr CreateFakeClockAdvancerThread( + test_util::FakeClockEnv* env, Notification* start, Notification* stop) { + return std::unique_ptr(Env::Default()->StartThread( + {}, "FakeClockAdvancerThread", [env, start, stop] { + start->WaitForNotification(); + while (!stop->HasBeenNotified()) { + env->AdvanceByMicroseconds(10); + Env::Default()->SleepForMicroseconds(10); + } + })); +} + +TEST(AdaptiveSharedBatchSchedulerTest, Basic) { + for (const bool delete_scheduler_early : {false, true}) { + for (const bool delete_queue_1_early : {false, true}) { + int queue_0_tasks = 0; + auto queue_0_callback = + [&queue_0_tasks](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + for (int i = 0; i < batch->num_tasks(); i++) { + queue_0_tasks += batch->task(i).size(); + } + }; + int queue_1_tasks = 0; + auto queue_1_callback = + [&queue_1_tasks](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + for (int i = 0; i < batch->num_tasks(); i++) { + queue_1_tasks += batch->task(i).size(); + } + }; + { + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create({}, &scheduler)); + + // Create two queues. + std::unique_ptr> queue_0; + TF_ASSERT_OK(scheduler->AddQueue({}, queue_0_callback, &queue_0)); + std::unique_ptr> queue_1; + TF_ASSERT_OK(scheduler->AddQueue({}, queue_1_callback, &queue_1)); + + if (delete_scheduler_early) { + // Delete our copy of the scheduler. The queues should keep it alive + // under the covers. + scheduler = nullptr; + } + // Submit tasks to the two queues, and (optionally) remove the queues. + TF_ASSERT_OK(ScheduleTask(1, queue_0.get())); + TF_ASSERT_OK(ScheduleTask(2, queue_1.get())); + TF_ASSERT_OK(ScheduleTask(3, queue_0.get())); + TF_ASSERT_OK(ScheduleTask(4, queue_1.get())); + if (delete_queue_1_early) { + queue_1 = nullptr; + } + TF_ASSERT_OK(ScheduleTask(5, queue_0.get())); + } + EXPECT_EQ(queue_0_tasks, 9); + EXPECT_EQ(queue_1_tasks, 6); + } + } +} + +TEST(AdaptiveSharedBatchSchedulerTest, BadOptions) { + using Scheduler = AdaptiveSharedBatchScheduler; + std::shared_ptr scheduler; + Scheduler::Options options; + options.num_batch_threads = 0; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.min_scheduling_period_micros = 50; + options.max_scheduling_period_micros = 100; + options.initial_scheduling_period_micros = 1; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.min_scheduling_period_micros = 50; + options.max_scheduling_period_micros = 100; + options.initial_scheduling_period_micros = 1000; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.min_scheduling_period_micros = 100; + options.max_scheduling_period_micros = 50; + options.initial_scheduling_period_micros = 75; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); + options = Scheduler::Options(); + options.feedback_smoothing_batches = 0; + EXPECT_FALSE(Scheduler::Create(options, &scheduler).ok()); +} + +TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) { + test_util::FakeClockEnv env(Env::Default()); + Notification start_teardown, stop_teardown; + std::unique_ptr teardown_thread = + CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); + { + AdaptiveSharedBatchScheduler::Options options; + options.initial_scheduling_period_micros = 1000; + options.env = &env; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue_0; + std::unique_ptr> queue_1; + int queue_0_tasks = 0; + int queue_1_tasks = 0; + auto queue_0_callback = [&queue_0_tasks, + &env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + for (int i = 0; i < batch->num_tasks(); i++) { + queue_0_tasks += batch->task(i).size(); + } + env.SleepForMicroseconds(1); + }; + auto queue_1_callback = [&queue_1_tasks, + &env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + for (int i = 0; i < batch->num_tasks(); i++) { + queue_1_tasks += batch->task(i).size(); + } + env.SleepForMicroseconds(1); + }; + AdaptiveSharedBatchScheduler::QueueOptions queue_options; + queue_options.max_batch_size = 10; + queue_options.max_enqueued_batches = 0; + // Queue must have max_enqueued_batchs > 1. + EXPECT_FALSE( + scheduler->AddQueue(queue_options, queue_0_callback, &queue_0).ok()); + queue_options.max_enqueued_batches = 2; + TF_ASSERT_OK( + scheduler->AddQueue(queue_options, queue_0_callback, &queue_0)); + queue_options.max_batch_size = 0; + // Queue must have max_batch_size > 0. + EXPECT_FALSE( + scheduler->AddQueue(queue_options, queue_1_callback, &queue_1).ok()); + queue_options.max_batch_size = 2; + queue_options.max_enqueued_batches = 1; + TF_ASSERT_OK( + scheduler->AddQueue(queue_options, queue_1_callback, &queue_1)); + + // Wait for scheduling_thread to sleep. + env.BlockUntilThreadsAsleep(1); + // Task larger than max_batch_size shouldn't schedule. + EXPECT_FALSE(ScheduleTask(15, queue_0.get()).ok()); + TF_ASSERT_OK(ScheduleTask(5, queue_0.get())); + TF_ASSERT_OK(ScheduleTask(5, queue_0.get())); + env.AdvanceByMicroseconds(1); + + // Task larger than max_batch_size shouldn't schedule. + EXPECT_FALSE(ScheduleTask(3, queue_1.get()).ok()); + TF_ASSERT_OK(ScheduleTask(1, queue_1.get())); + TF_ASSERT_OK(ScheduleTask(1, queue_1.get())); + env.AdvanceByMicroseconds(1); + // Exceeds max_enqueued_batches, shouldn't schedule. + EXPECT_FALSE(ScheduleTask(1, queue_1.get()).ok()); + + TF_ASSERT_OK(ScheduleTask(5, queue_0.get())); + // Exceeds max_enqueued_batches, shouldn't schedule. + EXPECT_FALSE(ScheduleTask(6, queue_0.get()).ok()); + TF_ASSERT_OK(ScheduleTask(4, queue_0.get())); + + // Batches should be processed in order from oldest to newest. + env.AdvanceByMicroseconds(1000); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(queue_0_tasks, 10); + EXPECT_EQ(queue_1_tasks, 0); + + env.AdvanceByMicroseconds(1000); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(queue_0_tasks, 10); + EXPECT_EQ(queue_1_tasks, 2); + + env.AdvanceByMicroseconds(1000); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(queue_0_tasks, 19); + EXPECT_EQ(queue_1_tasks, 2); + start_teardown.Notify(); + } + stop_teardown.Notify(); +} + +TEST(AdaptiveSharedBatchSchedulerTest, RateFeedback) { + test_util::FakeClockEnv env(Env::Default()); + Notification start_teardown, stop_teardown; + std::unique_ptr teardown_thread = + CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); + { + double feedback = 0; + AdaptiveSharedBatchScheduler::Options options; + options.initial_scheduling_period_micros = 1000; + options.min_scheduling_period_micros = 200; + options.max_scheduling_period_micros = 2000; + options.env = &env; + options.scheduling_period_feedback = [&feedback] { return feedback; }; + options.feedback_smoothing_batches = 1; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + int scheduled_items = 0; + auto queue_callback = [&scheduled_items, + &env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + scheduled_items = 0; + for (int i = 0; i < batch->num_tasks(); i++) { + scheduled_items += batch->task(i).size(); + } + env.SleepForMicroseconds(1); + }; + + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); + + // Wait for scheduling_thread to sleep. + env.BlockUntilThreadsAsleep(1); + // Enqueue 6 batches. + for (int i = 0; i < 6; i++) { + TF_ASSERT_OK(ScheduleTask(900 + i, queue.get())); + env.AdvanceByMicroseconds(1); + } + feedback = -500; + env.AdvanceByMicroseconds(994); + env.BlockUntilThreadsAsleep(2); // scheduling period = 500 usec. + EXPECT_EQ(scheduled_items, 900); + env.AdvanceByMicroseconds(500); + env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec. + EXPECT_EQ(scheduled_items, 901); + feedback = 0; + env.AdvanceByMicroseconds(250); + env.BlockUntilThreadsAsleep(2); // scheduling period = 250 usec. + EXPECT_EQ(scheduled_items, 902); + feedback = 10000; // large feedback should hit max_scheduling_period. + env.AdvanceByMicroseconds(250); + env.BlockUntilThreadsAsleep(2); // scheduling period = 2000 usec. + EXPECT_EQ(scheduled_items, 903); + feedback = -10000; // large feedback should hit min_scheduling_period. + env.AdvanceByMicroseconds(1999); + // No callback scheduled, only scheduling thread sleeping. + env.BlockUntilThreadsAsleep(1); + EXPECT_EQ(scheduled_items, 903); + env.AdvanceByMicroseconds(1); + env.BlockUntilThreadsAsleep(2); // scheduling period = 200 usec. + EXPECT_EQ(scheduled_items, 904); + env.AdvanceByMicroseconds(200); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(scheduled_items, 905); + start_teardown.Notify(); + } + stop_teardown.Notify(); +} + +TEST(AdaptiveSharedBatchSchedulerTest, FeedbackSmoothing) { + test_util::FakeClockEnv env(Env::Default()); + Notification start_teardown, stop_teardown; + std::unique_ptr teardown_thread = + CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); + { + double feedback = 0; + AdaptiveSharedBatchScheduler::Options options; + options.initial_scheduling_period_micros = 1000; + options.env = &env; + options.scheduling_period_feedback = [&feedback] { return feedback; }; + options.feedback_smoothing_batches = 3; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + int scheduled_items = 0; + auto queue_callback = [&scheduled_items, + &env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + scheduled_items = 0; + for (int i = 0; i < batch->num_tasks(); i++) { + scheduled_items += batch->task(i).size(); + } + env.SleepForMicroseconds(1); + }; + + TF_ASSERT_OK(scheduler->AddQueue({}, queue_callback, &queue)); + + // Wait for scheduling_thread to sleep. + env.BlockUntilThreadsAsleep(1); + // Enqueue 4 batches. + for (int i = 0; i < 4; i++) { + TF_ASSERT_OK(ScheduleTask(900 + i, queue.get())); + env.AdvanceByMicroseconds(1); + } + feedback = -300; + env.AdvanceByMicroseconds(996); + env.BlockUntilThreadsAsleep(2); + // ewma_feedback = 100, scheduling_period = 900. + EXPECT_EQ(scheduled_items, 900); + env.AdvanceByMicroseconds(899); + // No callback scheduled, only scheduling thread sleeping. + env.BlockUntilThreadsAsleep(1); + EXPECT_EQ(scheduled_items, 900); + env.AdvanceByMicroseconds(1); + env.BlockUntilThreadsAsleep(2); + // ewma_feedback = 167, scheduling_period = 750. + EXPECT_EQ(scheduled_items, 901); + env.AdvanceByMicroseconds(749); + // No callback scheduled, only scheduling thread sleeping. + env.BlockUntilThreadsAsleep(1); + EXPECT_EQ(scheduled_items, 901); + feedback = 1000 / 3.; + env.AdvanceByMicroseconds(1); + env.BlockUntilThreadsAsleep(2); + // emwa_feedback = 0, scheduling_period = 750. + EXPECT_EQ(scheduled_items, 902); + env.AdvanceByMicroseconds(749); + // No callback scheduled, only scheduling thread sleeping. + env.BlockUntilThreadsAsleep(1); + EXPECT_EQ(scheduled_items, 902); + env.AdvanceByMicroseconds(1); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(scheduled_items, 903); + start_teardown.Notify(); + } + stop_teardown.Notify(); +} + +TEST(AdaptiveSharedBatchSchedulerTest, QueueCapacityInfo) { + test_util::FakeClockEnv env(Env::Default()); + Notification start_teardown, stop_teardown; + std::unique_ptr teardown_thread = + CreateFakeClockAdvancerThread(&env, &start_teardown, &stop_teardown); + { + AdaptiveSharedBatchScheduler::Options options; + options.initial_scheduling_period_micros = 1000; + options.env = &env; + std::shared_ptr> scheduler; + TF_ASSERT_OK( + AdaptiveSharedBatchScheduler::Create(options, &scheduler)); + std::unique_ptr> queue; + int scheduled_items = 0; + auto queue_callback = [&scheduled_items, + &env](std::unique_ptr> batch) { + ASSERT_TRUE(batch->IsClosed()); + EXPECT_GT(batch->num_tasks(), 0); + scheduled_items = 0; + for (int i = 0; i < batch->num_tasks(); i++) { + scheduled_items += batch->task(i).size(); + } + env.SleepForMicroseconds(1); + }; + AdaptiveSharedBatchScheduler::QueueOptions queue_options; + queue_options.max_batch_size = 10; + queue_options.max_enqueued_batches = 10; + TF_ASSERT_OK(scheduler->AddQueue(queue_options, queue_callback, &queue)); + + // Wait for scheduling_thread to sleep. + env.BlockUntilThreadsAsleep(1); + // Enqueue 3 tasks. + EXPECT_EQ(queue->NumEnqueuedTasks(), 0); + EXPECT_EQ(queue->SchedulingCapacity(), 100); + TF_ASSERT_OK(ScheduleTask(5, queue.get())); + EXPECT_EQ(queue->NumEnqueuedTasks(), 1); + EXPECT_EQ(queue->SchedulingCapacity(), 95); + env.AdvanceByMicroseconds(1); + TF_ASSERT_OK(ScheduleTask(6, queue.get())); + EXPECT_EQ(queue->NumEnqueuedTasks(), 2); + EXPECT_EQ(queue->SchedulingCapacity(), 84); + env.AdvanceByMicroseconds(1); + TF_ASSERT_OK(ScheduleTask(1, queue.get())); + EXPECT_EQ(queue->NumEnqueuedTasks(), 3); + EXPECT_EQ(queue->SchedulingCapacity(), 83); + + env.AdvanceByMicroseconds(998); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(scheduled_items, 5); + env.AdvanceByMicroseconds(1000); + env.BlockUntilThreadsAsleep(2); + EXPECT_EQ(scheduled_items, 7); + start_teardown.Notify(); + } + stop_teardown.Notify(); +} +} // namespace anonymous +} // namespace serving +} // namespace tensorflow diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h index 7c41ad88180badd37398f5bae057dcd0006922c3..a5072f439abad3c5db79a514a7f2baff0b021b39 100644 --- a/tensorflow/contrib/batching/batch_scheduler.h +++ b/tensorflow/contrib/batching/batch_scheduler.h @@ -78,7 +78,7 @@ template class Batch { public: Batch() = default; - ~Batch(); // Blocks until the batch is closed. + virtual ~Batch(); // Blocks until the batch is closed. // Appends 'task' to the batch. After calling AddTask(), the newly-added task // can be accessed via task(num_tasks()-1) or mutable_task(num_tasks()-1). diff --git a/tensorflow/contrib/bayesflow/BUILD b/tensorflow/contrib/bayesflow/BUILD index df3f93d3f0ec92e0ca2227bc33033c4cdc030a77..213ae01c3bf69adf7514ade560fd055b0bb3fe7d 100644 --- a/tensorflow/contrib/bayesflow/BUILD +++ b/tensorflow/contrib/bayesflow/BUILD @@ -3,12 +3,15 @@ # particularly useful for Bayesian inference. # APIs here are meant to evolve over time. +package(default_visibility = [ + "//learning/brain/contrib/bayesflow:__subpackages__", + "//tensorflow:__subpackages__", +]) + licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( @@ -19,11 +22,16 @@ py_library( "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:functional_ops", + "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:platform", + "//tensorflow/python:random_ops", + "//tensorflow/python:state_ops", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", @@ -34,81 +42,62 @@ py_library( ) cuda_py_test( - name = "csiszar_divergence_test", + name = "metropolis_hastings_test", size = "medium", - srcs = ["python/kernel_tests/csiszar_divergence_test.py"], + srcs = ["python/kernel_tests/metropolis_hastings_test.py"], additional_deps = [ ":bayesflow_py", "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python/ops/distributions", "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:gradients", - "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", - ], - tags = [ - "manual", # b/64490288 - "notap", - ], -) - -cuda_py_test( - name = "custom_grad_test", - size = "small", - srcs = ["python/kernel_tests/custom_grad_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:init_ops", "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], ) cuda_py_test( - name = "entropy_test", + name = "csiszar_divergence_test", size = "medium", - srcs = ["python/kernel_tests/entropy_test.py"], + srcs = ["python/kernel_tests/csiszar_divergence_test.py"], additional_deps = [ ":bayesflow_py", "//third_party/py/numpy", "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/contrib/layers:layers_py", "//tensorflow/python/ops/distributions", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:gradients", + "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", - "//tensorflow/python:variables", + ], + tags = [ + "manual", # b/64490288 + "notap", ], ) cuda_py_test( - name = "stochastic_variables_test", - size = "medium", - srcs = ["python/kernel_tests/stochastic_variables_test.py"], + name = "custom_grad_test", + size = "small", + srcs = ["python/kernel_tests/custom_grad_test.py"], additional_deps = [ ":bayesflow_py", "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:init_ops", "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], @@ -136,84 +125,23 @@ cuda_py_test( ) cuda_py_test( - name = "stochastic_graph_test", - size = "small", - srcs = ["python/kernel_tests/stochastic_graph_test.py"], - additional_deps = [ - ":bayesflow_py", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:gradients", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_test( - name = "variational_inference_test", - size = "small", - srcs = ["python/kernel_tests/variational_inference_test.py"], - additional_deps = [ - ":bayesflow_py", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:variables", - ], -) - -cuda_py_test( - name = "stochastic_tensor_test", - size = "small", - srcs = ["python/kernel_tests/stochastic_tensor_test.py"], - additional_deps = [ - ":bayesflow_py", - "//third_party/py/numpy", - "//tensorflow/contrib/distributions:distributions_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_test( - name = "stochastic_gradient_estimators_test", + name = "hmc_test", size = "medium", - srcs = ["python/kernel_tests/stochastic_gradient_estimators_test.py"], + srcs = ["python/kernel_tests/hmc_test.py"], additional_deps = [ ":bayesflow_py", "//third_party/py/numpy", "//tensorflow/contrib/distributions:distributions_py", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python/ops/distributions", "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", - "//tensorflow/python:variables", - ], -) - -cuda_py_test( - name = "reinforce_simple_example", - size = "small", - srcs = ["examples/reinforce_simple/reinforce_simple_example.py"], - additional_deps = [ - ":bayesflow_py", - "//tensorflow:tensorflow_py", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", + "//tensorflow/python:random_seed", ], ) diff --git a/tensorflow/contrib/bayesflow/__init__.py b/tensorflow/contrib/bayesflow/__init__.py index 15c1614a671894f75831f822f6880df1e277ccbc..b98bc369542679b05169db092aee86e884ca1625 100644 --- a/tensorflow/contrib/bayesflow/__init__.py +++ b/tensorflow/contrib/bayesflow/__init__.py @@ -23,22 +23,16 @@ from __future__ import print_function # pylint: disable=unused-import,line-too-long from tensorflow.contrib.bayesflow.python.ops import csiszar_divergence from tensorflow.contrib.bayesflow.python.ops import custom_grad -from tensorflow.contrib.bayesflow.python.ops import entropy +from tensorflow.contrib.bayesflow.python.ops import hmc +from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings from tensorflow.contrib.bayesflow.python.ops import monte_carlo -from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators -from tensorflow.contrib.bayesflow.python.ops import stochastic_graph -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor -from tensorflow.contrib.bayesflow.python.ops import stochastic_variables -from tensorflow.contrib.bayesflow.python.ops import variational_inference # pylint: enable=unused-import,line-too-long from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = ['csiszar_divergence', 'custom_grad', 'entropy', - 'monte_carlo', 'special_math', - 'stochastic_gradient_estimators', 'stochastic_graph', - 'stochastic_tensor', 'stochastic_variables', - 'variational_inference'] + 'metropolis_hastings', 'monte_carlo', 'hmc', 'special_math', + 'stochastic_variables', 'variational_inference'] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/examples/reinforce_simple/reinforce_simple_example.py b/tensorflow/contrib/bayesflow/examples/reinforce_simple/reinforce_simple_example.py deleted file mode 100644 index 2eb625487f4cd18bdec10ddbc0cf64cb8c8499b8..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/examples/reinforce_simple/reinforce_simple_example.py +++ /dev/null @@ -1,140 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Simple examples of the REINFORCE algorithm.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import tensorflow as tf - - -distributions = tf.contrib.distributions -sg = tf.contrib.bayesflow.stochastic_graph -st = tf.contrib.bayesflow.stochastic_tensor - - -def split_apply_merge(inp, partitions, fns): - """Split input according to partitions. Pass results through fns and merge. - - Args: - inp: the input vector - partitions: tensor of same length as input vector, having values 0, 1 - fns: the two functions. - - Returns: - the vector routed, where routed[i] = fns[partitions[i]](inp[i]) - """ - new_inputs = tf.dynamic_partition(inp, partitions, len(fns)) - new_outputs = [fns[i](x) for i, x in enumerate(new_inputs)] - new_indices = tf.dynamic_partition( - tf.range(0, inp.get_shape()[0]), partitions, len(fns)) - return tf.dynamic_stitch(new_indices, new_outputs) - - -def plus_1(inputs): - return inputs + 1.0 - - -def minus_1(inputs): - return inputs - 1.0 - - -def build_split_apply_merge_model(): - """Build the Split-Apply-Merge Model. - - Route each value of input [-1, -1, 1, 1] through one of the - functions, plus_1, minus_1. The decision for routing is made by - 4 Bernoulli R.V.s whose parameters are determined by a neural network - applied to the input. REINFORCE is used to update the NN parameters. - - Returns: - The 3-tuple (route_selection, routing_loss, final_loss), where: - - - route_selection is an int 4-vector - - routing_loss is a float 4-vector - - final_loss is a float scalar. - """ - inputs = tf.constant([[-1.0], [-1.0], [1.0], [1.0]]) - targets = tf.constant([[0.0], [0.0], [0.0], [0.0]]) - paths = [plus_1, minus_1] - weights = tf.get_variable("w", [1, 2]) - bias = tf.get_variable("b", [1, 1]) - logits = tf.matmul(inputs, weights) + bias - - # REINFORCE forward step - route_selection = st.StochasticTensor( - distributions.Categorical(logits=logits)) - - # Accessing route_selection as a Tensor below forces a sample of - # the Categorical distribution based on its logits. - # This is equivalent to calling route_selection.value(). - # - # route_selection.value() returns an int32 4-vector with random - # values in {0, 1} - # COPY+ROUTE+PASTE - outputs = split_apply_merge(inputs, route_selection, paths) - - # flatten routing_loss to a row vector (from a column vector) - routing_loss = tf.reshape(tf.square(outputs - targets), shape=[-1]) - - # Total loss: score function loss + routing loss. - # The score function loss (through `route_selection.loss(routing_loss)`) - # returns: - # [stop_gradient(routing_loss) * - # route_selection.log_pmf(stop_gradient(route_selection.value()))], - # where log_pmf has gradients going all the way back to weights and bias. - # In this case, the routing_loss depends on the variables only through - # "route_selection", which has a stop_gradient on it. So the - # gradient of the loss really come through the score function - surrogate_loss = sg.surrogate_loss([routing_loss]) - final_loss = tf.reduce_sum(surrogate_loss) - - return (route_selection, routing_loss, final_loss) - - -class REINFORCESimpleExample(tf.test.TestCase): - - def testSplitApplyMerge(self): - # Repeatability. SGD has a tendency to jump around, even here. - tf.set_random_seed(1) - - with self.test_session() as sess: - # Use sampling to train REINFORCE - with st.value_type(st.SampleValue()): - (route_selection, - routing_loss, - final_loss) = build_split_apply_merge_model() - - sgd = tf.train.GradientDescentOptimizer(1.0).minimize(final_loss) - - tf.global_variables_initializer().run() - - for i in range(10): - # Run loss and inference step. This toy problem converges VERY quickly. - (routing_loss_v, final_loss_v, route_selection_v, _) = sess.run( - [routing_loss, final_loss, tf.identity(route_selection), sgd]) - print( - "Iteration %d, routing loss: %s, final_loss: %s, " - "route selection: %s" - % (i, routing_loss_v, final_loss_v, route_selection_v)) - - self.assertAllEqual([0, 0, 1, 1], route_selection_v) - self.assertAllClose([0.0, 0.0, 0.0, 0.0], routing_loss_v) - self.assertAllClose(0.0, final_loss_v) - - -if __name__ == "__main__": - tf.test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py deleted file mode 100644 index 0bd12b84d12a9c3219f6b24830b1b82db9716043..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/entropy_test.py +++ /dev/null @@ -1,352 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for Monte Carlo Ops.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib import layers as layers_lib -from tensorflow.contrib.bayesflow.python.ops import entropy_impl as entropy -from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib -from tensorflow.contrib.distributions.python.ops import mvn_tril as mvn_tril_lib -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import variables -from tensorflow.python.ops.distributions import kullback_leibler as kullback_leibler_lib -from tensorflow.python.ops.distributions import normal as normal_lib -from tensorflow.python.ops.distributions import util as distribution_util -from tensorflow.python.platform import test - -layers = layers_lib - - -class NormalNoEntropy(normal_lib.Normal): # pylint: disable=no-init - """Normal distribution without a `.entropy` method.""" - - def entropy(self): - return NotImplementedError('Entropy removed by gremlins') - - -def get_train_op(scalar_loss, optimizer='SGD', learning_rate=1.0, decay=0.0): - global_step = variables.Variable(0) - - def decay_fn(rate, t): - return rate * (1 + math_ops.to_float(t))**(-decay) - - train_op = layers.optimize_loss( - scalar_loss, - global_step, - learning_rate, - optimizer, - learning_rate_decay_fn=decay_fn) - return train_op - - -def _assert_monotonic_decreasing(array, atol=1e-5): - array = np.asarray(array) - _assert_monotonic_increasing(-array, atol=atol) - - -def _assert_monotonic_increasing(array, atol=1e-5): - array = np.asarray(array) - diff = np.diff(array.ravel()) - np.testing.assert_array_less(-1 * atol, diff) - - -class ElboRatioTest(test.TestCase): - """Show sampling converges to true KL values.""" - - def setUp(self): - self._rng = np.random.RandomState(0) - - def test_convergence_to_kl_using_sample_form_on_3dim_normal(self): - # Test that the sample mean KL is the same as analytic when we use samples - # to estimate every part of the KL divergence ratio. - vector_shape = (2, 3) - n_samples = 5000 - - with self.test_session(): - q = mvn_diag_lib.MultivariateNormalDiag( - loc=self._rng.rand(*vector_shape), - scale_diag=self._rng.rand(*vector_shape)) - p = mvn_diag_lib.MultivariateNormalDiag( - loc=self._rng.rand(*vector_shape), - scale_diag=self._rng.rand(*vector_shape)) - - # In this case, the log_ratio is the KL. - sample_kl = -1 * entropy.elbo_ratio( - log_p=p.log_prob, - q=q, - n=n_samples, - form=entropy.ELBOForms.sample, - seed=42) - actual_kl = kullback_leibler_lib.kl_divergence(q, p) - - # Relative tolerance (rtol) chosen 2 times as large as minimim needed to - # pass. - self.assertEqual((2,), sample_kl.get_shape()) - self.assertAllClose(actual_kl.eval(), sample_kl.eval(), rtol=0.05) - - def test_convergence_to_kl_using_analytic_entropy_form_on_3dim_normal(self): - # Test that the sample mean KL is the same as analytic when we use an - # analytic entropy combined with sampled cross-entropy. - n_samples = 5000 - - vector_shape = (2, 3) - with self.test_session(): - q = mvn_diag_lib.MultivariateNormalDiag( - loc=self._rng.rand(*vector_shape), - scale_diag=self._rng.rand(*vector_shape)) - p = mvn_diag_lib.MultivariateNormalDiag( - loc=self._rng.rand(*vector_shape), - scale_diag=self._rng.rand(*vector_shape)) - - # In this case, the log_ratio is the KL. - sample_kl = -1 * entropy.elbo_ratio( - log_p=p.log_prob, - q=q, - n=n_samples, - form=entropy.ELBOForms.analytic_entropy, - seed=42) - actual_kl = kullback_leibler_lib.kl_divergence(q, p) - - # Relative tolerance (rtol) chosen 2 times as large as minimim needed to - # pass. - self.assertEqual((2,), sample_kl.get_shape()) - self.assertAllClose(actual_kl.eval(), sample_kl.eval(), rtol=0.1) - - def test_sample_kl_zero_when_p_and_q_are_the_same_distribution(self): - n_samples = 50 - - vector_shape = (2, 3) - with self.test_session(): - q = mvn_diag_lib.MultivariateNormalDiag( - loc=self._rng.rand(*vector_shape), - scale_diag=self._rng.rand(*vector_shape)) - - # In this case, the log_ratio is the KL. - sample_kl = -1 * entropy.elbo_ratio( - log_p=q.log_prob, - q=q, - n=n_samples, - form=entropy.ELBOForms.sample, - seed=42) - - self.assertEqual((2,), sample_kl.get_shape()) - self.assertAllClose(np.zeros(2), sample_kl.eval()) - - -class EntropyShannonTest(test.TestCase): - - def test_normal_entropy_default_form_uses_exact_entropy(self): - with self.test_session(): - dist = normal_lib.Normal(loc=1.11, scale=2.22) - mc_entropy = entropy.entropy_shannon(dist, n=11) - exact_entropy = dist.entropy() - self.assertEqual(exact_entropy.get_shape(), mc_entropy.get_shape()) - self.assertAllClose(exact_entropy.eval(), mc_entropy.eval()) - - def test_normal_entropy_analytic_form_uses_exact_entropy(self): - with self.test_session(): - dist = normal_lib.Normal(loc=1.11, scale=2.22) - mc_entropy = entropy.entropy_shannon( - dist, form=entropy.ELBOForms.analytic_entropy) - exact_entropy = dist.entropy() - self.assertEqual(exact_entropy.get_shape(), mc_entropy.get_shape()) - self.assertAllClose(exact_entropy.eval(), mc_entropy.eval()) - - def test_normal_entropy_sample_form_gets_approximate_answer(self): - # Tested by showing we get a good answer that is not exact. - with self.test_session(): - dist = normal_lib.Normal(loc=1.11, scale=2.22) - mc_entropy = entropy.entropy_shannon( - dist, n=1000, form=entropy.ELBOForms.sample, seed=0) - exact_entropy = dist.entropy() - - self.assertEqual(exact_entropy.get_shape(), mc_entropy.get_shape()) - - # Relative tolerance (rtol) chosen 2 times as large as minimim needed to - # pass. - self.assertAllClose(exact_entropy.eval(), mc_entropy.eval(), rtol=0.01) - - # Make sure there is some error, proving we used samples - self.assertLess(0.0001, math_ops.abs(exact_entropy - mc_entropy).eval()) - - def test_default_entropy_falls_back_on_sample_if_analytic_not_available(self): - # Tested by showing we get a good answer that is not exact. - with self.test_session(): - # NormalNoEntropy is like a Normal, but does not have .entropy method, so - # we are forced to fall back on sample entropy. - dist_no_entropy = NormalNoEntropy(loc=1.11, scale=2.22) - dist_yes_entropy = normal_lib.Normal(loc=1.11, scale=2.22) - - mc_entropy = entropy.entropy_shannon( - dist_no_entropy, n=1000, form=entropy.ELBOForms.sample, seed=0) - exact_entropy = dist_yes_entropy.entropy() - - self.assertEqual(exact_entropy.get_shape(), mc_entropy.get_shape()) - - # Relative tolerance (rtol) chosen 2 times as large as minimim needed to - # pass. - self.assertAllClose(exact_entropy.eval(), mc_entropy.eval(), rtol=0.01) - - # Make sure there is some error, proving we used samples - self.assertLess(0.0001, math_ops.abs(exact_entropy - mc_entropy).eval()) - - -class RenyiRatioTest(test.TestCase): - """Show renyi_ratio is minimized when the distributions match.""" - - def setUp(self): - self._rng = np.random.RandomState(0) - - def test_fitting_two_dimensional_normal_n_equals_1000(self): - # Minmizing Renyi divergence should allow us to make one normal match - # another one exactly. - n = 1000 - mu_true = np.array([1.0, -1.0], dtype=np.float64) - chol_true = np.array([[2.0, 0.0], [0.5, 1.0]], dtype=np.float64) - with self.test_session() as sess: - target = mvn_tril_lib.MultivariateNormalTriL(mu_true, chol_true) - - # Set up q distribution by defining mean/covariance as Variables - mu = variables.Variable( - np.zeros(mu_true.shape), dtype=mu_true.dtype, name='mu') - mat = variables.Variable( - np.zeros(chol_true.shape), dtype=chol_true.dtype, name='mat') - chol = distribution_util.matrix_diag_transform( - mat, transform=nn_ops.softplus) - q = mvn_tril_lib.MultivariateNormalTriL(mu, chol) - for alpha in [0.25, 0.75]: - - negative_renyi_divergence = entropy.renyi_ratio( - log_p=target.log_prob, q=q, n=n, alpha=alpha, seed=0) - train_op = get_train_op( - math_ops.reduce_mean(-negative_renyi_divergence), - optimizer='SGD', - learning_rate=0.5, - decay=0.1) - - variables.global_variables_initializer().run() - renyis = [] - for step in range(1000): - sess.run(train_op) - if step in [1, 5, 100]: - renyis.append(negative_renyi_divergence.eval()) - - # This optimization should maximize the renyi divergence. - _assert_monotonic_increasing(renyis, atol=0) - - # Relative tolerance (rtol) chosen 2 times as large as minimim needed to - # pass. - self.assertAllClose(target.loc.eval(), q.loc.eval(), rtol=0.06) - self.assertAllClose(target.scale.to_dense().eval(), - q.scale.to_dense().eval(), - rtol=0.1) - - def test_divergence_between_identical_distributions_is_zero(self): - n = 1000 - vector_shape = (2, 3) - with self.test_session(): - q = mvn_diag_lib.MultivariateNormalDiag( - loc=self._rng.rand(*vector_shape), - scale_diag=self._rng.rand(*vector_shape)) - for alpha in [0.25, 0.75]: - - negative_renyi_divergence = entropy.renyi_ratio( - log_p=q.log_prob, q=q, n=n, alpha=alpha, seed=0) - - self.assertEqual((2,), negative_renyi_divergence.get_shape()) - self.assertAllClose(np.zeros(2), negative_renyi_divergence.eval()) - - -class RenyiAlphaTest(test.TestCase): - - def test_with_three_alphas(self): - with self.test_session(): - for dtype in (dtypes.float32, dtypes.float64): - alpha_min = constant_op.constant(0.0, dtype=dtype) - alpha_max = 0.5 - decay_time = 3 - - alpha_0 = entropy.renyi_alpha( - 0, decay_time, alpha_min=alpha_min, alpha_max=alpha_max) - alpha_1 = entropy.renyi_alpha( - 1, decay_time, alpha_min=alpha_min, alpha_max=alpha_max) - alpha_2 = entropy.renyi_alpha( - 2, decay_time, alpha_min=alpha_min, alpha_max=alpha_max) - alpha_3 = entropy.renyi_alpha( - 3, decay_time, alpha_min=alpha_min, alpha_max=alpha_max) - - # Alpha should start at alpha_max. - self.assertAllClose(alpha_max, alpha_0.eval(), atol=1e-5) - # Alpha should finish at alpha_min. - self.assertAllClose(alpha_min.eval(), alpha_3.eval(), atol=1e-5) - # In between, alpha should be monotonically decreasing. - _assert_monotonic_decreasing( - [alpha_0.eval(), alpha_1.eval(), alpha_2.eval(), alpha_3.eval()]) - - def test_non_scalar_input_raises(self): - with self.test_session(): - # Good values here - step = 0 - alpha_min = 0.0 - alpha_max = 0.5 - decay_time = 3 - - # Use one bad value inside each check. - # The "bad" value is always the non-scalar one. - with self.assertRaisesRegexp(ValueError, 'must be scalar'): - entropy.renyi_alpha( - [step], decay_time, alpha_min=alpha_min, alpha_max=alpha_max).eval() - - with self.assertRaisesRegexp(ValueError, 'must be scalar'): - entropy.renyi_alpha( - step, [decay_time], alpha_min=alpha_min, alpha_max=alpha_max).eval() - - with self.assertRaisesRegexp(ValueError, 'must be scalar'): - entropy.renyi_alpha( - step, decay_time, alpha_min=[alpha_min], alpha_max=alpha_max).eval() - - with self.assertRaisesRegexp(ValueError, 'must be scalar'): - entropy.renyi_alpha( - step, decay_time, alpha_min=alpha_min, alpha_max=[alpha_max]).eval() - - def test_input_with_wrong_sign_raises(self): - with self.test_session(): - # Good values here - step = 0 - alpha_min = 0.0 - alpha_max = 0.5 - decay_time = 3 - - # Use one bad value inside each check. - # The "bad" value is always the non-scalar one. - with self.assertRaisesOpError('decay_time must be positive'): - entropy.renyi_alpha( - step, 0.0, alpha_min=alpha_min, alpha_max=alpha_max).eval() - - with self.assertRaisesOpError('step must be non-negative'): - entropy.renyi_alpha( - -1, decay_time, alpha_min=alpha_min, alpha_max=alpha_max).eval() - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b1f108e5f01e4945ee83d8262f1d99877f0fe9f0 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/hmc_test.py @@ -0,0 +1,349 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Hamiltonian Monte Carlo. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import special +from scipy import stats + +from tensorflow.contrib.bayesflow.python.ops import hmc + +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging + + +# TODO(b/66964210): Test float16. +class HMCTest(test.TestCase): + + def setUp(self): + self._shape_param = 5. + self._rate_param = 10. + self._expected_x = (special.digamma(self._shape_param) + - np.log(self._rate_param)) + self._expected_exp_x = self._shape_param / self._rate_param + + random_seed.set_random_seed(10003) + np.random.seed(10003) + + def _log_gamma_log_prob(self, x, event_dims=()): + """Computes log-pdf of a log-gamma random variable. + + Args: + x: Value of the random variable. + event_dims: Dimensions not to treat as independent. + + Returns: + log_prob: The log-pdf up to a normalizing constant. + """ + return math_ops.reduce_sum(self._shape_param * x - + self._rate_param * math_ops.exp(x), + event_dims) + + def _log_gamma_log_prob_grad(self, x, event_dims=()): + """Computes log-pdf and gradient of a log-gamma random variable. + + Args: + x: Value of the random variable. + event_dims: Dimensions not to treat as independent. Default is (), + i.e., all dimensions are independent. + + Returns: + log_prob: The log-pdf up to a normalizing constant. + grad: The gradient of the log-pdf with respect to x. + """ + return (math_ops.reduce_sum(self._shape_param * x - + self._rate_param * math_ops.exp(x), + event_dims), + self._shape_param - self._rate_param * math_ops.exp(x)) + + def _n_event_dims(self, x_shape, event_dims): + return np.prod([int(x_shape[i]) for i in event_dims]) + + def _integrator_conserves_energy(self, x, event_dims, sess, + feed_dict=None): + def potential_and_grad(x): + log_prob, grad = self._log_gamma_log_prob_grad(x, event_dims) + return -log_prob, -grad + + step_size = array_ops.placeholder(np.float32, [], name='step_size') + hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps') + + if feed_dict is None: + feed_dict = {} + feed_dict[hmc_lf_steps] = 1000 + + m = random_ops.random_normal(array_ops.shape(x)) + potential_0, grad_0 = potential_and_grad(x) + old_energy = potential_0 + 0.5 * math_ops.reduce_sum(m * m, + event_dims) + + _, new_m, potential_1, _ = ( + hmc.leapfrog_integrator(step_size, hmc_lf_steps, x, + m, potential_and_grad, grad_0)) + + new_energy = potential_1 + 0.5 * math_ops.reduce_sum(new_m * new_m, + event_dims) + + x_shape = sess.run(x, feed_dict).shape + n_event_dims = self._n_event_dims(x_shape, event_dims) + feed_dict[step_size] = 0.1 / n_event_dims + old_energy_val, new_energy_val = sess.run([old_energy, new_energy], + feed_dict) + logging.vlog(1, 'average energy change: {}'.format( + abs(old_energy_val - new_energy_val).mean())) + + self.assertAllEqual(np.ones_like(new_energy_val, dtype=np.bool), + abs(old_energy_val - new_energy_val) < 1.) + + def _integrator_conserves_energy_wrapper(self, event_dims): + """Tests the long-term energy conservation of the leapfrog integrator. + + The leapfrog integrator is symplectic, so for sufficiently small step + sizes it should be possible to run it more or less indefinitely without + the energy of the system blowing up or collapsing. + + Args: + event_dims: A tuple of dimensions that should not be treated as + independent. This allows for multiple chains to be run independently + in parallel. Default is (), i.e., all dimensions are independent. + """ + with self.test_session() as sess: + x_ph = array_ops.placeholder(np.float32, name='x_ph') + + feed_dict = {x_ph: np.zeros([50, 10, 2])} + self._integrator_conserves_energy(x_ph, event_dims, sess, feed_dict) + + def testIntegratorEnergyConservationNullShape(self): + self._integrator_conserves_energy_wrapper([]) + + def testIntegratorEnergyConservation1(self): + self._integrator_conserves_energy_wrapper([1]) + + def testIntegratorEnergyConservation2(self): + self._integrator_conserves_energy_wrapper([2]) + + def testIntegratorEnergyConservation12(self): + self._integrator_conserves_energy_wrapper([1, 2]) + + def testIntegratorEnergyConservation012(self): + self._integrator_conserves_energy_wrapper([0, 1, 2]) + + def _chain_gets_correct_expectations(self, x, event_dims, sess, + feed_dict=None): + def log_gamma_log_prob(x): + return self._log_gamma_log_prob(x, event_dims) + + step_size = array_ops.placeholder(np.float32, [], name='step_size') + hmc_lf_steps = array_ops.placeholder(np.int32, [], name='hmc_lf_steps') + hmc_n_steps = array_ops.placeholder(np.int32, [], name='hmc_n_steps') + + if feed_dict is None: + feed_dict = {} + feed_dict.update({step_size: 0.1, + hmc_lf_steps: 2, + hmc_n_steps: 300}) + + sample_chain, acceptance_prob_chain = hmc.chain([hmc_n_steps], + step_size, + hmc_lf_steps, + x, log_gamma_log_prob, + event_dims) + + acceptance_probs, samples = sess.run([acceptance_prob_chain, sample_chain], + feed_dict) + samples = samples[feed_dict[hmc_n_steps] // 2:] + expected_x_est = samples.mean() + expected_exp_x_est = np.exp(samples).mean() + + logging.vlog(1, 'True E[x, exp(x)]: {}\t{}'.format( + self._expected_x, self._expected_exp_x)) + logging.vlog(1, 'Estimated E[x, exp(x)]: {}\t{}'.format( + expected_x_est, expected_exp_x_est)) + self.assertNear(expected_x_est, self._expected_x, 2e-2) + self.assertNear(expected_exp_x_est, self._expected_exp_x, 2e-2) + self.assertTrue((acceptance_probs > 0.5).all()) + self.assertTrue((acceptance_probs <= 1.0).all()) + + def _chain_gets_correct_expectations_wrapper(self, event_dims): + with self.test_session() as sess: + x_ph = array_ops.placeholder(np.float32, name='x_ph') + + feed_dict = {x_ph: np.zeros([50, 10, 2])} + self._chain_gets_correct_expectations(x_ph, event_dims, sess, + feed_dict) + + def testHMCChainExpectationsNullShape(self): + self._chain_gets_correct_expectations_wrapper([]) + + def testHMCChainExpectations1(self): + self._chain_gets_correct_expectations_wrapper([1]) + + def testHMCChainExpectations2(self): + self._chain_gets_correct_expectations_wrapper([2]) + + def testHMCChainExpectations12(self): + self._chain_gets_correct_expectations_wrapper([1, 2]) + + def _kernel_leaves_target_invariant(self, initial_draws, event_dims, + sess, feed_dict=None): + def log_gamma_log_prob(x): + return self._log_gamma_log_prob(x, event_dims) + + def fake_log_prob(x): + """Cooled version of the target distribution.""" + return 1.1 * log_gamma_log_prob(x) + + step_size = array_ops.placeholder(np.float32, [], name='step_size') + + if feed_dict is None: + feed_dict = {} + + feed_dict[step_size] = 0.4 + + sample, acceptance_probs, _, _ = hmc.kernel(step_size, 5, initial_draws, + log_gamma_log_prob, event_dims) + bad_sample, bad_acceptance_probs, _, _ = hmc.kernel( + step_size, 5, initial_draws, fake_log_prob, event_dims) + (acceptance_probs_val, bad_acceptance_probs_val, initial_draws_val, + updated_draws_val, fake_draws_val) = sess.run([acceptance_probs, + bad_acceptance_probs, + initial_draws, sample, + bad_sample], feed_dict) + # Confirm step size is small enough that we usually accept. + self.assertGreater(acceptance_probs_val.mean(), 0.5) + self.assertGreater(bad_acceptance_probs_val.mean(), 0.5) + # Confirm step size is large enough that we sometimes reject. + self.assertLess(acceptance_probs_val.mean(), 0.99) + self.assertLess(bad_acceptance_probs_val.mean(), 0.99) + _, ks_p_value_true = stats.ks_2samp(initial_draws_val.flatten(), + updated_draws_val.flatten()) + _, ks_p_value_fake = stats.ks_2samp(initial_draws_val.flatten(), + fake_draws_val.flatten()) + logging.vlog(1, 'acceptance rate for true target: {}'.format( + acceptance_probs_val.mean())) + logging.vlog(1, 'acceptance rate for fake target: {}'.format( + bad_acceptance_probs_val.mean())) + logging.vlog(1, 'K-S p-value for true target: {}'.format(ks_p_value_true)) + logging.vlog(1, 'K-S p-value for fake target: {}'.format(ks_p_value_fake)) + # Make sure that the MCMC update hasn't changed the empirical CDF much. + self.assertGreater(ks_p_value_true, 1e-3) + # Confirm that targeting the wrong distribution does + # significantly change the empirical CDF. + self.assertLess(ks_p_value_fake, 1e-6) + + def _kernel_leaves_target_invariant_wrapper(self, event_dims): + """Tests that the kernel leaves the target distribution invariant. + + Draws some independent samples from the target distribution, + applies an iteration of the MCMC kernel, then runs a + Kolmogorov-Smirnov test to determine if the distribution of the + MCMC-updated samples has changed. + + We also confirm that running the kernel with a different log-pdf + does change the target distribution. (And that we can detect that.) + + Args: + event_dims: A tuple of dimensions that should not be treated as + independent. This allows for multiple chains to be run independently + in parallel. Default is (), i.e., all dimensions are independent. + """ + with self.test_session() as sess: + initial_draws = np.log(np.random.gamma(self._shape_param, + size=[50000, 2, 2])) + initial_draws -= np.log(self._rate_param) + x_ph = array_ops.placeholder(np.float32, name='x_ph') + + feed_dict = {x_ph: initial_draws} + + self._kernel_leaves_target_invariant(x_ph, event_dims, sess, + feed_dict) + + def testKernelLeavesTargetInvariantNullShape(self): + self._kernel_leaves_target_invariant_wrapper([]) + + def testKernelLeavesTargetInvariant1(self): + self._kernel_leaves_target_invariant_wrapper([1]) + + def testKernelLeavesTargetInvariant2(self): + self._kernel_leaves_target_invariant_wrapper([2]) + + def testKernelLeavesTargetInvariant12(self): + self._kernel_leaves_target_invariant_wrapper([1, 2]) + + def _ais_gets_correct_log_normalizer(self, init, event_dims, sess, + feed_dict=None): + def proposal_log_prob(x): + return math_ops.reduce_sum(-0.5 * x * x - 0.5 * np.log(2*np.pi), + event_dims) + + def target_log_prob(x): + return self._log_gamma_log_prob(x, event_dims) + + if feed_dict is None: + feed_dict = {} + + w, _, _ = hmc.ais_chain(200, 0.5, 2, init, target_log_prob, + proposal_log_prob, event_dims) + + w_val = sess.run(w, feed_dict) + init_shape = sess.run(init, feed_dict).shape + normalizer_multiplier = np.prod([init_shape[i] for i in event_dims]) + + true_normalizer = -self._shape_param * np.log(self._rate_param) + true_normalizer += special.gammaln(self._shape_param) + true_normalizer *= normalizer_multiplier + + n_weights = np.prod(w_val.shape) + normalized_w = np.exp(w_val - true_normalizer) + standard_error = np.std(normalized_w) / np.sqrt(n_weights) + logging.vlog(1, 'True normalizer {}, estimated {}, n_weights {}'.format( + true_normalizer, np.log(normalized_w.mean()) + true_normalizer, + n_weights)) + self.assertNear(normalized_w.mean(), 1.0, 4.0 * standard_error) + + def _ais_gets_correct_log_normalizer_wrapper(self, event_dims): + """Tests that AIS yields reasonable estimates of normalizers.""" + with self.test_session() as sess: + x_ph = array_ops.placeholder(np.float32, name='x_ph') + + initial_draws = np.random.normal(size=[30, 2, 1]) + feed_dict = {x_ph: initial_draws} + + self._ais_gets_correct_log_normalizer(x_ph, event_dims, sess, + feed_dict) + + def testAISNullShape(self): + self._ais_gets_correct_log_normalizer_wrapper([]) + + def testAIS1(self): + self._ais_gets_correct_log_normalizer_wrapper([1]) + + def testAIS2(self): + self._ais_gets_correct_log_normalizer_wrapper([2]) + + def testAIS12(self): + self._ais_gets_correct_log_normalizer_wrapper([1, 2]) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py new file mode 100644 index 0000000000000000000000000000000000000000..63d93fad64d077aa385b72428665e841b6784b90 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/kernel_tests/metropolis_hastings_test.py @@ -0,0 +1,179 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for metropolis_hastings.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib.bayesflow.python.ops import metropolis_hastings_impl as mh +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class McmcStepTest(test.TestCase): + + def test_density_increasing_step_accepted(self): + """Tests that if a transition increases density, it is always accepted.""" + target_log_density = lambda x: - x * x + state = variable_scope.get_variable('state', initializer=10.) + state_log_density = variable_scope.get_variable( + 'state_log_density', + initializer=target_log_density(state.initialized_value())) + log_accept_ratio = variable_scope.get_variable( + 'log_accept_ratio', initializer=0.) + + get_next_proposal = lambda x: (x - 1., None) + step = mh.evolve(state, state_log_density, log_accept_ratio, + target_log_density, get_next_proposal, seed=1234) + init = variables.initialize_all_variables() + with self.test_session() as sess: + sess.run(init) + for j in range(9): + sess.run(step) + sample = sess.run(state) + sample_log_density = sess.run(state_log_density) + self.assertAlmostEqual(sample, 9 - j) + self.assertAlmostEqual(sample_log_density, - (9 - j) * (9 - j)) + + def test_sample_properties(self): + """Tests that the samples converge to the target distribution.""" + + def target_log_density(x): + """Log-density corresponding to a normal distribution with mean = 4.""" + return - (x - 2.0) * (x - 2.0) * 0.5 + + # Use the uniform random walker to generate proposals. + proposal_fn = mh.uniform_random_proposal( + step_size=1.0, seed=1234) + + state = variable_scope.get_variable('state', initializer=0.0) + state_log_density = variable_scope.get_variable( + 'state_log_density', + initializer=target_log_density(state.initialized_value())) + + log_accept_ratio = variable_scope.get_variable( + 'log_accept_ratio', initializer=0.) + # Random walk MCMC converges slowly so need to put in enough iterations. + num_iterations = 5000 + step = mh.evolve(state, state_log_density, log_accept_ratio, + target_log_density, proposal_fn, seed=4321) + + init = variables.global_variables_initializer() + + sample_sum, sample_sq_sum = 0.0, 0.0 + with self.test_session() as sess: + sess.run(init) + for _ in np.arange(num_iterations): + # Allow for the mixing of the chain and discard these samples. + sess.run(step) + for _ in np.arange(num_iterations): + sess.run(step) + sample = sess.run(state) + sample_sum += sample + sample_sq_sum += sample * sample + + sample_mean = sample_sum / num_iterations + sample_variance = sample_sq_sum / num_iterations - sample_mean * sample_mean + # The samples have large autocorrelation which reduces the effective sample + # size. + self.assertAlmostEqual(sample_mean, 2.0, delta=0.1) + self.assertAlmostEqual(sample_variance, 1.0, delta=0.1) + + def test_normal_proposals(self): + """Tests that the normal proposals are correctly distributed.""" + + initial_points = array_ops.ones([10000], dtype=dtypes.float32) + proposal_fn = mh.normal_random_proposal( + scale=2.0, seed=1234) + proposal_points, _ = proposal_fn(initial_points) + + with self.test_session() as sess: + sample = sess.run(proposal_points) + + # It is expected that the elements in proposal_points have the same mean as + # initial_points and have the standard deviation that was supplied to the + # proposal scheme. + self.assertAlmostEqual(np.mean(sample), 1.0, delta=0.1) + self.assertAlmostEqual(np.std(sample), 2.0, delta=0.1) + + def test_docstring_example(self): + """Tests the simplified docstring example with multiple chains.""" + + n = 2 # dimension of the problem + + # Generate 300 initial values randomly. Each of these would be an + # independent starting point for a Markov chain. + state = variable_scope.get_variable( + 'state', initializer=random_ops.random_normal( + [300, n], mean=3.0, dtype=dtypes.float32, seed=42)) + + # Computes the log(p(x)) for the unit normal density and ignores the + # normalization constant. + def log_density(x): + return - math_ops.reduce_sum(x * x, reduction_indices=-1) / 2.0 + + # Initial log-density value + state_log_density = variable_scope.get_variable( + 'state_log_density', + initializer=log_density(state.initialized_value())) + + # A variable to store the log_acceptance_ratio: + log_acceptance_ratio = variable_scope.get_variable( + 'log_acceptance_ratio', + initializer=array_ops.zeros([300], dtype=dtypes.float32)) + + # Generates random proposals by moving each coordinate uniformly and + # independently in a box of size 2 centered around the current value. + # Returns the new point and also the log of the Hastings ratio (the + # ratio of the probability of going from the proposal to origin and the + # probability of the reverse transition). When this ratio is 1, the value + # may be omitted and replaced by None. + def random_proposal(x): + return (x + random_ops.random_uniform( + array_ops.shape(x), minval=-1, maxval=1, + dtype=x.dtype, seed=12)), None + + # Create the op to propagate the chain for 100 steps. + stepper = mh.evolve( + state, state_log_density, log_acceptance_ratio, + log_density, random_proposal, n_steps=100, seed=123) + init = variables.initialize_all_variables() + with self.test_session() as sess: + sess.run(init) + # Run the chains for a total of 1000 steps. + for _ in range(10): + sess.run(stepper) + samples = sess.run(state) + covariance = np.eye(n) + # Verify that the estimated mean and covariance are close to the true + # values. + self.assertAlmostEqual( + np.max(np.abs(np.mean(samples, 0) + - np.zeros(n))), 0, + delta=0.1) + self.assertAlmostEqual( + np.max(np.abs(np.reshape(np.cov(samples, rowvar=False), [n**2]) + - np.reshape(covariance, [n**2]))), 0, + delta=0.2) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py deleted file mode 100644 index 9b1f482b34967082d6ac44494123879fb8fb0ee3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_gradient_estimators_test.py +++ /dev/null @@ -1,206 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for stochastic graphs.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -from tensorflow.contrib import distributions -from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import gradient_checker -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - -st = stochastic_tensor -sge = stochastic_gradient_estimators -dists = distributions - - -def _vimco(loss): - """Python implementation of VIMCO.""" - n = loss.shape[0] - log_loss = np.log(loss) - geometric_mean = [] - for j in range(n): - geometric_mean.append( - np.exp(np.mean([log_loss[i, :] for i in range(n) if i != j], 0))) - geometric_mean = np.array(geometric_mean) - - learning_signal = [] - for j in range(n): - learning_signal.append(np.sum([loss[i, :] for i in range(n) if i != j], 0)) - learning_signal = np.array(learning_signal) - - local_learning_signal = np.log(1 / n * (learning_signal + geometric_mean)) - - # log_mean - local_learning_signal - log_mean = np.log(np.mean(loss, 0)) - advantage = log_mean - local_learning_signal - - return advantage - - -class StochasticGradientEstimatorsTest(test.TestCase): - - def setUp(self): - self._p = constant_op.constant(0.999999) - self._final_loss = constant_op.constant(3.2) - - def _testScoreFunction(self, loss_fn, expected): - x = st.StochasticTensor(dists.Bernoulli(probs=self._p), loss_fn=loss_fn) - sf = x.loss(self._final_loss) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertAllClose(*sess.run([expected, sf])) - - def testScoreFunction(self): - expected = math_ops.log(self._p) * self._final_loss - self._testScoreFunction(sge.score_function, expected) - - def testScoreFunctionWithConstantBaseline(self): - b = constant_op.constant(9.8) - expected = math_ops.log(self._p) * (self._final_loss - b) - self._testScoreFunction( - sge.get_score_function_with_constant_baseline(b), expected) - - def testScoreFunctionWithBaselineFn(self): - b = constant_op.constant(9.8) - - def baseline_fn(stoch_tensor, loss): - self.assertTrue(isinstance(stoch_tensor, st.StochasticTensor)) - self.assertTrue(isinstance(loss, ops.Tensor)) - return b - - expected = math_ops.log(self._p) * (self._final_loss - b) - self._testScoreFunction( - sge.get_score_function_with_baseline(baseline_fn), expected) - - def testScoreFunctionWithMeanBaseline(self): - ema_decay = 0.8 - num_steps = 6 - x = st.StochasticTensor( - dists.Bernoulli(probs=self._p), - loss_fn=sge.get_score_function_with_baseline( - sge.get_mean_baseline(ema_decay))) - sf = x.loss(self._final_loss) - - # Expected EMA value - ema = 0. - for _ in range(num_steps): - ema -= (1. - ema_decay) * (ema - self._final_loss) - - # Baseline is EMA with bias correction - bias_correction = 1. - ema_decay**num_steps - baseline = ema / bias_correction - expected = math_ops.log(self._p) * (self._final_loss - baseline) - - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - for _ in range(num_steps - 1): - sess.run(sf) # run to update EMA - self.assertAllClose(*sess.run([expected, sf])) - - def testScoreFunctionWithAdvantageFn(self): - b = constant_op.constant(9.8) - - def advantage_fn(stoch_tensor, loss): - self.assertTrue(isinstance(stoch_tensor, st.StochasticTensor)) - self.assertTrue(isinstance(loss, ops.Tensor)) - return loss - b - - expected = math_ops.log(self._p) * (self._final_loss - b) - self._testScoreFunction( - sge.get_score_function_with_advantage(advantage_fn), expected) - - def testVIMCOAdvantageFn(self): - # simple_loss: (3, 2) with 3 samples, batch size 2 - simple_loss = np.array( - [[1.0, 1.5], - [1e-6, 1e4], - [2.0, 3.0]]) - # random_loss: (100, 50, 64) with 100 samples, batch shape (50, 64) - random_loss = 100 * np.random.rand(100, 50, 64) - - advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=False) - - with self.test_session() as sess: - for loss in [simple_loss, random_loss]: - expected = _vimco(loss) - loss_t = constant_op.constant(loss, dtype=dtypes.float32) - advantage_t = advantage_fn(None, loss_t) # ST is not used - advantage = sess.run(advantage_t) - self.assertEqual(expected.shape, advantage_t.get_shape()) - self.assertAllClose(expected, advantage, atol=5e-5) - - def testVIMCOAdvantageGradients(self): - loss = np.log( - [[1.0, 1.5], - [1e-6, 1e4], - [2.0, 3.0]]) - advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=True) - - with self.test_session(): - loss_t = constant_op.constant(loss, dtype=dtypes.float64) - advantage_t = advantage_fn(None, loss_t) # ST is not used - gradient_error = gradient_checker.compute_gradient_error( - loss_t, - loss_t.get_shape().as_list(), - advantage_t, - advantage_t.get_shape().as_list(), - x_init_value=loss) - self.assertLess(gradient_error, 1e-3) - - def testVIMCOAdvantageWithSmallProbabilities(self): - theta_value = np.random.rand(10, 100000) - # Test with float16 dtype to ensure stability even in this extreme case. - theta = constant_op.constant(theta_value, dtype=dtypes.float16) - advantage_fn = sge.get_vimco_advantage_fn(have_log_loss=True) - - with self.test_session() as sess: - log_loss = -math_ops.reduce_sum(theta, [1]) - advantage_t = advantage_fn(None, log_loss) - grad_t = gradients_impl.gradients(advantage_t, theta)[0] - advantage, grad = sess.run((advantage_t, grad_t)) - self.assertTrue(np.all(np.isfinite(advantage))) - self.assertTrue(np.all(np.isfinite(grad))) - - def testScoreFunctionWithMeanBaselineHasUniqueVarScope(self): - ema_decay = 0.8 - x = st.StochasticTensor( - dists.Bernoulli(probs=self._p), - loss_fn=sge.get_score_function_with_baseline( - sge.get_mean_baseline(ema_decay))) - y = st.StochasticTensor( - dists.Bernoulli(probs=self._p), - loss_fn=sge.get_score_function_with_baseline( - sge.get_mean_baseline(ema_decay))) - sf_x = x.loss(self._final_loss) - sf_y = y.loss(self._final_loss) - with self.test_session() as sess: - # Smoke test - sess.run(variables.global_variables_initializer()) - sess.run([sf_x, sf_y]) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py deleted file mode 100644 index 44e27db03b18d0e6a789db676bea684c10dcfca7..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_graph_test.py +++ /dev/null @@ -1,246 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for stochastic graphs.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib import distributions as distributions_lib -from tensorflow.contrib.bayesflow.python.ops import stochastic_graph_impl -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gradients_impl -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import test - -st = stochastic_tensor -sg = stochastic_graph_impl -distributions = distributions_lib - - -class NormalNotParam(distributions.Normal): - - @property - def reparameterization_type(self): - return distributions.NOT_REPARAMETERIZED - - -class TestSurrogateLosses(test.TestCase): - - def testPathwiseDerivativeDoesNotAddSurrogateLosses(self): - with self.test_session(): - mu = [0.0, 0.1, 0.2] - sigma = constant_op.constant([1.1, 1.2, 1.3]) - with st.value_type(st.SampleValue()): - prior = st.StochasticTensor(distributions.Normal(loc=mu, scale=sigma)) - likelihood = st.StochasticTensor( - distributions.Normal( - loc=prior, scale=sigma)) - self.assertEqual( - prior.distribution.reparameterization_type, - distributions.FULLY_REPARAMETERIZED) - self.assertEqual( - likelihood.distribution.reparameterization_type, - distributions.FULLY_REPARAMETERIZED) - - loss = math_ops.square(array_ops.identity(likelihood) - [0.0, 0.1, 0.2]) - sum_loss = math_ops.reduce_sum(loss) - - surrogate_loss = sg.surrogate_loss([loss]) - with self.assertRaisesRegexp(ValueError, "dimensionality 1 or greater"): - _ = sg.surrogate_loss([sum_loss]) - surrogate_from_both = sg.surrogate_loss( - [loss, sum_loss * array_ops.ones_like(loss)]) - - # Pathwise derivative terms do not require add'l surrogate loss terms. - with self.test_session() as sess: - self.assertAllClose(*sess.run([loss, surrogate_loss])) - self.assertAllClose(*sess.run([(loss + sum_loss), surrogate_from_both])) - - def _testSurrogateLoss(self, session, losses, expected_addl_terms, xs): - surrogate_loss = sg.surrogate_loss(losses) - expected_surrogate_loss = math_ops.add_n(losses + expected_addl_terms) - self.assertAllClose(*session.run([surrogate_loss, expected_surrogate_loss])) - - # Test backprop - expected_grads = gradients_impl.gradients(ys=expected_surrogate_loss, xs=xs) - surrogate_grads = gradients_impl.gradients(ys=surrogate_loss, xs=xs) - self.assertEqual(len(expected_grads), len(surrogate_grads)) - grad_values = session.run(expected_grads + surrogate_grads) - n_grad = len(expected_grads) - self.assertAllClose(grad_values[:n_grad], grad_values[n_grad:]) - - def testSurrogateLoss(self): - with self.test_session() as sess: - mu = constant_op.constant([0.0, 0.1, 0.2]) - sigma = constant_op.constant([1.1, 1.2, 1.3]) - with st.value_type(st.SampleValue()): - prior = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma)) - likelihood = st.StochasticTensor(NormalNotParam(loc=prior, scale=sigma)) - prior_2 = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma)) - - loss = math_ops.square(array_ops.identity(likelihood) - mu) - part_loss = math_ops.square(array_ops.identity(prior) - mu) - sum_loss = math_ops.reduce_sum(loss) - loss_nodeps = math_ops.square(array_ops.identity(prior_2) - mu) - - # For ground truth, use the stop-gradient versions of the losses - loss_nograd = array_ops.stop_gradient(loss) - loss_nodeps_nograd = array_ops.stop_gradient(loss_nodeps) - sum_loss_nograd = array_ops.stop_gradient(sum_loss) - - # These score functions should ignore prior_2 - self._testSurrogateLoss( - session=sess, - losses=[loss], - expected_addl_terms=[ - likelihood.distribution.log_prob( - likelihood.value()) * loss_nograd, - prior.distribution.log_prob(prior.value()) * loss_nograd - ], - xs=[mu, sigma]) - - self._testSurrogateLoss( - session=sess, - losses=[loss, part_loss], - expected_addl_terms=[ - likelihood.distribution.log_prob( - likelihood.value()) * loss_nograd, - (prior.distribution.log_prob(prior.value()) * - array_ops.stop_gradient(part_loss + loss)) - ], - xs=[mu, sigma]) - - self._testSurrogateLoss( - session=sess, - losses=[sum_loss * array_ops.ones_like(loss)], - expected_addl_terms=[( - likelihood.distribution.log_prob(likelihood.value()) * - sum_loss_nograd), prior.distribution.log_prob(prior.value()) * - sum_loss_nograd], - xs=[mu, sigma]) - - self._testSurrogateLoss( - session=sess, - losses=[loss, sum_loss * array_ops.ones_like(loss)], - expected_addl_terms=[( - likelihood.distribution.log_prob(likelihood.value()) * - array_ops.stop_gradient(loss + sum_loss)), - (prior.distribution.log_prob(prior.value()) * - array_ops.stop_gradient(loss + sum_loss))], - xs=[mu, sigma]) - - # These score functions should ignore prior and likelihood - self._testSurrogateLoss( - session=sess, - losses=[loss_nodeps], - expected_addl_terms=[(prior_2.distribution.log_prob(prior_2.value()) * - loss_nodeps_nograd)], - xs=[mu, sigma]) - - # These score functions should include all terms selectively - self._testSurrogateLoss( - session=sess, - losses=[loss, loss_nodeps], - # We can't guarantee ordering of output losses in this case. - expected_addl_terms=[( - likelihood.distribution.log_prob(likelihood.value()) * - loss_nograd), prior.distribution.log_prob(prior.value()) * - loss_nograd, - (prior_2.distribution.log_prob(prior_2.value()) * - loss_nodeps_nograd)], - xs=[mu, sigma]) - - def testNoSurrogateLoss(self): - with self.test_session(): - mu = constant_op.constant([0.0, 0.1, 0.2]) - sigma = constant_op.constant([1.1, 1.2, 1.3]) - with st.value_type(st.SampleValue()): - dt = st.StochasticTensor( - NormalNotParam( - loc=mu, scale=sigma), loss_fn=None) - self.assertEqual(None, dt.loss(constant_op.constant([2.0]))) - - def testExplicitStochasticTensors(self): - with self.test_session() as sess: - mu = constant_op.constant([0.0, 0.1, 0.2]) - sigma = constant_op.constant([1.1, 1.2, 1.3]) - with st.value_type(st.SampleValue()): - dt1 = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma)) - dt2 = st.StochasticTensor(NormalNotParam(loc=mu, scale=sigma)) - loss = math_ops.square(array_ops.identity(dt1)) + 10. + dt2 - - sl_all = sg.surrogate_loss([loss]) - sl_dt1 = sg.surrogate_loss([loss], stochastic_tensors=[dt1]) - sl_dt2 = sg.surrogate_loss([loss], stochastic_tensors=[dt2]) - - dt1_term = dt1.distribution.log_prob(dt1) * loss - dt2_term = dt2.distribution.log_prob(dt2) * loss - - self.assertAllClose(*sess.run( - [sl_all, sum([loss, dt1_term, dt2_term])])) - self.assertAllClose(*sess.run([sl_dt1, sum([loss, dt1_term])])) - self.assertAllClose(*sess.run([sl_dt2, sum([loss, dt2_term])])) - - -class StochasticDependenciesMapTest(test.TestCase): - - def testBuildsMapOfUpstreamNodes(self): - dt1 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) - dt2 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) - out1 = dt1.value() + 1. - out2 = dt2.value() + 2. - x = out1 + out2 - y = out2 * 3. - dep_map = sg._stochastic_dependencies_map([x, y]) - self.assertEqual(dep_map[dt1], set([x])) - self.assertEqual(dep_map[dt2], set([x, y])) - - def testHandlesStackedStochasticNodes(self): - dt1 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) - out1 = dt1.value() + 1. - dt2 = st.StochasticTensor(distributions.Normal(loc=out1, scale=1.)) - x = dt2.value() + 2. - dt3 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) - y = dt3.value() * 3. - dep_map = sg._stochastic_dependencies_map([x, y]) - self.assertEqual(dep_map[dt1], set([x])) - self.assertEqual(dep_map[dt2], set([x])) - self.assertEqual(dep_map[dt3], set([y])) - - def testTraversesControlInputs(self): - dt1 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) - logits = dt1.value() * 3. - dt2 = st.StochasticTensor(distributions.Bernoulli(logits=logits)) - dt3 = st.StochasticTensor(distributions.Normal(loc=0., scale=1.)) - x = dt3.value() - y = array_ops.ones((2, 2)) * 4. - z = array_ops.ones((2, 2)) * 3. - out = control_flow_ops.cond( - math_ops.cast(dt2, dtypes.bool), lambda: math_ops.add(x, y), - lambda: math_ops.square(z)) - out += 5. - dep_map = sg._stochastic_dependencies_map([out]) - self.assertEqual(dep_map[dt1], set([out])) - self.assertEqual(dep_map[dt2], set([out])) - self.assertEqual(dep_map[dt3], set([out])) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py deleted file mode 100644 index 6d0cff4678972719cb5c565bc409041e298beadb..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_tensor_test.py +++ /dev/null @@ -1,239 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for stochastic graphs.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor_impl -from tensorflow.python.framework import constant_op -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops.distributions import normal -from tensorflow.python.platform import test - -sge = stochastic_gradient_estimators -st = stochastic_tensor_impl - - -class StochasticTensorTest(test.TestCase): - - def testConstructionAndValue(self): - with self.test_session() as sess: - mu = [0.0, 0.1, 0.2] - sigma = constant_op.constant([1.1, 1.2, 1.3]) - sigma2 = constant_op.constant([0.1, 0.2, 0.3]) - - prior_default = st.StochasticTensor( - normal.Normal(loc=mu, scale=sigma)) - self.assertTrue(isinstance(prior_default.value_type, st.SampleValue)) - prior_0 = st.StochasticTensor( - normal.Normal(loc=mu, scale=sigma), - dist_value_type=st.SampleValue()) - self.assertTrue(isinstance(prior_0.value_type, st.SampleValue)) - - with st.value_type(st.SampleValue()): - prior = st.StochasticTensor(normal.Normal(loc=mu, scale=sigma)) - self.assertTrue(isinstance(prior.value_type, st.SampleValue)) - likelihood = st.StochasticTensor( - normal.Normal(loc=prior, scale=sigma2)) - self.assertTrue(isinstance(likelihood.value_type, st.SampleValue)) - - coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION) - self.assertEqual(coll, [prior_default, prior_0, prior, likelihood]) - - # Also works: tf.convert_to_tensor(prior) - prior_default = array_ops.identity(prior_default) - prior_0 = array_ops.identity(prior_0) - prior = array_ops.identity(prior) - likelihood = array_ops.identity(likelihood) - - # Mostly a smoke test for now... - prior_0_val, prior_val, prior_default_val, _ = sess.run( - [prior_0, prior, prior_default, likelihood]) - - self.assertEqual(prior_0_val.shape, prior_val.shape) - self.assertEqual(prior_default_val.shape, prior_val.shape) - # These are different random samples from the same distribution, - # so the values should differ. - self.assertGreater(np.abs(prior_0_val - prior_val).sum(), 1e-6) - self.assertGreater(np.abs(prior_default_val - prior_val).sum(), 1e-6) - - def testMeanValue(self): - with self.test_session() as sess: - mu = [0.0, -1.0, 1.0] - sigma = constant_op.constant([1.1, 1.2, 1.3]) - - with st.value_type(st.MeanValue()): - prior = st.StochasticTensor(normal.Normal(loc=mu, scale=sigma)) - self.assertTrue(isinstance(prior.value_type, st.MeanValue)) - - prior_mean = prior.mean() - prior_value = prior.value() - - prior_mean_val, prior_value_val = sess.run([prior_mean, prior_value]) - self.assertAllEqual(prior_mean_val, mu) - self.assertAllEqual(prior_mean_val, prior_value_val) - - def testSampleValueScalar(self): - with self.test_session() as sess: - mu = [[0.0, -1.0, 1.0], [0.0, -1.0, 1.0]] - sigma = constant_op.constant([[1.1, 1.2, 1.3], [1.1, 1.2, 1.3]]) - - with st.value_type(st.SampleValue()): - prior_single = st.StochasticTensor( - normal.Normal(loc=mu, scale=sigma)) - - prior_single_value = prior_single.value() - self.assertEqual(prior_single_value.get_shape(), (2, 3)) - - prior_single_value_val = sess.run([prior_single_value])[0] - self.assertEqual(prior_single_value_val.shape, (2, 3)) - - with st.value_type(st.SampleValue(1)): - prior_single = st.StochasticTensor( - normal.Normal(loc=mu, scale=sigma)) - self.assertTrue(isinstance(prior_single.value_type, st.SampleValue)) - - prior_single_value = prior_single.value() - self.assertEqual(prior_single_value.get_shape(), (1, 2, 3)) - - prior_single_value_val = sess.run([prior_single_value])[0] - self.assertEqual(prior_single_value_val.shape, (1, 2, 3)) - - with st.value_type(st.SampleValue(2)): - prior_double = st.StochasticTensor( - normal.Normal(loc=mu, scale=sigma)) - - prior_double_value = prior_double.value() - self.assertEqual(prior_double_value.get_shape(), (2, 2, 3)) - - prior_double_value_val = sess.run([prior_double_value])[0] - self.assertEqual(prior_double_value_val.shape, (2, 2, 3)) - - def testDistributionEntropy(self): - with self.test_session() as sess: - mu = [0.0, -1.0, 1.0] - sigma = constant_op.constant([1.1, 1.2, 1.3]) - with st.value_type(st.MeanValue()): - prior = st.StochasticTensor(normal.Normal(loc=mu, scale=sigma)) - entropy = prior.entropy() - deep_entropy = prior.distribution.entropy() - expected_deep_entropy = normal.Normal( - loc=mu, scale=sigma).entropy() - entropies = sess.run([entropy, deep_entropy, expected_deep_entropy]) - self.assertAllEqual(entropies[2], entropies[0]) - self.assertAllEqual(entropies[1], entropies[0]) - - def testSurrogateLoss(self): - with self.test_session(): - mu = [[3.0, -4.0, 5.0], [6.0, -7.0, 8.0]] - sigma = constant_op.constant(1.0) - - # With default - with st.value_type(st.MeanValue(stop_gradient=True)): - dt = st.StochasticTensor(normal.Normal(loc=mu, scale=sigma)) - loss = dt.loss([constant_op.constant(2.0)]) - self.assertTrue(loss is not None) - self.assertAllClose( - dt.distribution.log_prob(mu).eval() * 2.0, loss.eval()) - - # With passed-in loss_fn. - dt = st.StochasticTensor( - normal.Normal(loc=mu, scale=sigma), - dist_value_type=st.MeanValue(stop_gradient=True), - loss_fn=sge.get_score_function_with_constant_baseline( - baseline=constant_op.constant(8.0))) - loss = dt.loss([constant_op.constant(2.0)]) - self.assertTrue(loss is not None) - self.assertAllClose((dt.distribution.log_prob(mu) * (2.0 - 8.0)).eval(), - loss.eval()) - - -class ValueTypeTest(test.TestCase): - - def testValueType(self): - type_mean = st.MeanValue() - type_reshape = st.SampleValue() - type_full = st.SampleValue() - with st.value_type(type_mean): - self.assertEqual(st.get_current_value_type(), type_mean) - with st.value_type(type_reshape): - self.assertEqual(st.get_current_value_type(), type_reshape) - with st.value_type(type_full): - self.assertEqual(st.get_current_value_type(), type_full) - self.assertEqual(st.get_current_value_type(), type_mean) - with self.assertRaisesRegexp(ValueError, "No value type currently set"): - st.get_current_value_type() - - -class ObservedStochasticTensorTest(test.TestCase): - - def testConstructionAndValue(self): - with self.test_session() as sess: - mu = [0.0, 0.1, 0.2] - sigma = constant_op.constant([1.1, 1.2, 1.3]) - obs = array_ops.zeros((2, 3)) - z = st.ObservedStochasticTensor( - normal.Normal(loc=mu, scale=sigma), value=obs) - [obs_val, z_val] = sess.run([obs, z.value()]) - self.assertAllEqual(obs_val, z_val) - - coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION) - self.assertEqual(coll, [z]) - - def testConstructionWithUnknownShapes(self): - mu = array_ops.placeholder(dtypes.float32) - sigma = array_ops.placeholder(dtypes.float32) - obs = array_ops.placeholder(dtypes.float32) - z = st.ObservedStochasticTensor( - normal.Normal(loc=mu, scale=sigma), value=obs) - - mu2 = array_ops.placeholder(dtypes.float32, shape=[None]) - sigma2 = array_ops.placeholder(dtypes.float32, shape=[None]) - obs2 = array_ops.placeholder(dtypes.float32, shape=[None, None]) - z2 = st.ObservedStochasticTensor( - normal.Normal(loc=mu2, scale=sigma2), value=obs2) - - coll = ops.get_collection(st.STOCHASTIC_TENSOR_COLLECTION) - self.assertEqual(coll, [z, z2]) - - def testConstructionErrors(self): - mu = [0., 0.] - sigma = [1., 1.] - self.assertRaises( - ValueError, - st.ObservedStochasticTensor, - normal.Normal(loc=mu, scale=sigma), - value=array_ops.zeros((3,))) - self.assertRaises( - ValueError, - st.ObservedStochasticTensor, - normal.Normal(loc=mu, scale=sigma), - value=array_ops.zeros((3, 1))) - self.assertRaises( - ValueError, - st.ObservedStochasticTensor, - normal.Normal(loc=mu, scale=sigma), - value=array_ops.zeros((1, 2), dtype=dtypes.int32)) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py deleted file mode 100644 index 9ee59a03ca76c6095e34b869d9b175e2c9223cd7..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/stochastic_variables_test.py +++ /dev/null @@ -1,168 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for stochastic graphs.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np -from tensorflow.contrib import distributions -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor -from tensorflow.contrib.bayesflow.python.ops import stochastic_variables -from tensorflow.contrib.bayesflow.python.ops import variational_inference_impl -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import variable_scope -from tensorflow.python.ops import variables -from tensorflow.python.platform import test - -sv = stochastic_variables -st = stochastic_tensor -vi = variational_inference_impl -dist = distributions - - -class StochasticVariablesTest(test.TestCase): - - def testStochasticVariables(self): - shape = (10, 20) - with variable_scope.variable_scope( - "stochastic_variables", - custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusScale)): - v = variable_scope.get_variable("sv", shape) - - self.assertTrue(isinstance(v, st.StochasticTensor)) - self.assertTrue(isinstance(v.distribution, dist.NormalWithSoftplusScale)) - - self.assertEqual( - {"stochastic_variables/sv_loc", "stochastic_variables/sv_scale"}, - set([v.op.name for v in variables.global_variables()])) - self.assertEqual( - set(variables.trainable_variables()), set(variables.global_variables())) - - v = ops.convert_to_tensor(v) - self.assertEqual(list(shape), v.get_shape().as_list()) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertEqual(shape, sess.run(v).shape) - - def testStochasticVariablesWithConstantInitializer(self): - shape = (10, 20) - with variable_scope.variable_scope( - "stochastic_variables", - custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusScale, - dist_kwargs={"validate_args": True}, - param_initializers={ - "loc": np.ones(shape) * 4., - "scale": np.ones(shape) * 2. - })): - v = variable_scope.get_variable("sv") - - for var in variables.global_variables(): - if "loc" in var.name: - mu_var = var - if "scale" in var.name: - sigma_var = var - - v = ops.convert_to_tensor(v) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertAllEqual(np.ones(shape) * 4., sess.run(mu_var)) - self.assertAllEqual(np.ones(shape) * 2., sess.run(sigma_var)) - self.assertEqual(shape, sess.run(v).shape) - - def testStochasticVariablesWithCallableInitializer(self): - shape = (10, 20) - - def sigma_init(shape, dtype, partition_info): - _ = partition_info - return array_ops.ones(shape, dtype=dtype) * 2. - - with variable_scope.variable_scope( - "stochastic_variables", - custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusScale, - dist_kwargs={"validate_args": True}, - param_initializers={ - "loc": np.ones( - shape, dtype=np.float32) * 4., - "scale": sigma_init - })): - v = variable_scope.get_variable("sv", shape) - - for var in variables.global_variables(): - if "loc" in var.name: - mu_var = var - if "scale" in var.name: - sigma_var = var - - v = ops.convert_to_tensor(v) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertAllEqual(np.ones(shape) * 4., sess.run(mu_var)) - self.assertAllEqual(np.ones(shape) * 2., sess.run(sigma_var)) - self.assertEqual(shape, sess.run(v).shape) - - def testStochasticVariablesWithPrior(self): - shape = (10, 20) - prior = dist.Normal(0., 1.) - with variable_scope.variable_scope( - "stochastic_variables", - custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusScale, prior=prior)): - w = variable_scope.get_variable("weights", shape) - - x = random_ops.random_uniform((8, 10)) - y = math_ops.matmul(x, w) - - prior_map = vi._find_variational_and_priors(y, None) - self.assertEqual(prior_map[w], prior) - elbo = vi.elbo(y, keep_batch_dim=False) - - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(elbo) - - def testStochasticVariablesWithCallablePriorInitializer(self): - - def prior_init(shape, dtype): - return dist.Normal( - array_ops.zeros(shape, dtype), array_ops.ones(shape, dtype)) - - with variable_scope.variable_scope( - "stochastic_variables", - custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusScale, prior=prior_init)): - w = variable_scope.get_variable("weights", (10, 20)) - - x = random_ops.random_uniform((8, 10)) - y = math_ops.matmul(x, w) - - prior_map = vi._find_variational_and_priors(y, None) - self.assertTrue(isinstance(prior_map[w], dist.Normal)) - elbo = vi.elbo(y, keep_batch_dim=False) - - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - sess.run(elbo) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py b/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py deleted file mode 100644 index fff6b74b2efed27abd7b25cbe0e8e8b3904767e1..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/kernel_tests/variational_inference_test.py +++ /dev/null @@ -1,146 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for variational inference.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib import distributions as distributions_lib -from tensorflow.contrib import layers -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor -from tensorflow.contrib.bayesflow.python.ops import variational_inference_impl -from tensorflow.python.framework import constant_op -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variables -from tensorflow.python.ops.distributions import kullback_leibler -from tensorflow.python.ops.distributions import normal -from tensorflow.python.platform import test - -st = stochastic_tensor -vi = variational_inference_impl -distributions = distributions_lib - - -class NormalNoEntropy(distributions.Normal): - - def entropy(self): - raise NotImplementedError("entropy not implemented") - - -# For mini-VAE -def inference_net(x, latent_size): - return layers.linear(x, latent_size) - - -def generative_net(z, data_size): - return layers.linear(z, data_size) - - -def mini_vae(): - x = [[-6., 3., 6.], [-8., 4., 8.]] - prior = distributions.Normal(loc=0., scale=1.) - variational = st.StochasticTensor( - distributions.Normal( - loc=inference_net(x, 1), scale=1.)) - vi.register_prior(variational, prior) - px = distributions.Normal(loc=generative_net(variational, 3), scale=1.) - log_likelihood = math_ops.reduce_sum(px.log_prob(x), 1) - log_likelihood = array_ops.expand_dims(log_likelihood, -1) - return x, prior, variational, px, log_likelihood - - -class VariationalInferenceTest(test.TestCase): - - def testDefaultVariationalAndPrior(self): - _, prior, variational, _, log_likelihood = mini_vae() - elbo = vi.elbo(log_likelihood) - expected_elbo = log_likelihood - kullback_leibler.kl_divergence( - variational.distribution, prior) - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertAllEqual(*sess.run([expected_elbo, elbo])) - - def testExplicitVariationalAndPrior(self): - with self.test_session() as sess: - _, _, variational, _, log_likelihood = mini_vae() - prior = normal.Normal(loc=3., scale=2.) - elbo = vi.elbo( - log_likelihood, variational_with_prior={variational: prior}) - expected_elbo = log_likelihood - kullback_leibler.kl_divergence( - variational.distribution, prior) - sess.run(variables.global_variables_initializer()) - self.assertAllEqual(*sess.run([expected_elbo, elbo])) - - def testExplicitForms(self): - _, prior, variational, _, log_likelihood = mini_vae() - - elbos = [] - forms = vi.ELBOForms - for form in [ - forms.default, forms.analytic_kl, forms.sample, forms.analytic_entropy - ]: - elbo = vi.elbo( - log_likelihood=log_likelihood, - variational_with_prior={variational: prior}, - form=form) - elbos.append(elbo) - - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - log_likelihood_shape = array_ops.shape(log_likelihood).eval() - for elbo in elbos: - elbo.eval() - elbo_shape = array_ops.shape(elbo).eval() - self.assertAllEqual(log_likelihood_shape, elbo_shape) - self.assertEqual(elbo.dtype, log_likelihood.dtype) - - def testDefaultsSampleKLWithoutAnalyticKLOrEntropy(self): - x = constant_op.constant([[-6., 3., 6.]]) - - prior = distributions.Bernoulli(0.5) - variational = st.StochasticTensor( - NormalNoEntropy( - loc=inference_net(x, 1), scale=1.)) - vi.register_prior(variational, prior) - px = distributions.Normal(loc=generative_net(variational, 3), scale=1.) - log_likelihood = math_ops.reduce_sum(px.log_prob(x), 1) - - # No analytic KL available between prior and variational distributions. - with self.assertRaisesRegexp(NotImplementedError, "No KL"): - distributions.kl_divergence(variational.distribution, prior) - - elbo = vi.elbo( - variational_with_prior={variational: prior}, - log_likelihood=log_likelihood) - expected_elbo = log_likelihood + prior.log_prob( - variational) - variational.distribution.log_prob(variational) - - with self.test_session() as sess: - sess.run(variables.global_variables_initializer()) - self.assertAllEqual(*sess.run([expected_elbo, elbo])) - - def testElboWithLogJoint(self): - with self.test_session() as sess: - _, prior, variational, _, log_likelihood = mini_vae() - log_joint = log_likelihood + prior.log_prob(variational) - elbo = vi.elbo_with_log_joint(log_joint) - sess.run(variables.global_variables_initializer()) - elbo.eval() - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py b/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py deleted file mode 100644 index 4a7679fb436b91c9ae70daf85552099e5b710cbc..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/entropy_impl.py +++ /dev/null @@ -1,386 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Support for Entropy Ops. See ${python/contrib.bayesflow.entropy}. - -@@elbo_ratio -@@entropy_shannon -@@renyi_ratio -@@renyi_alpha -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import math - -from tensorflow.contrib.bayesflow.python.ops import monte_carlo_impl as monte_carlo -from tensorflow.contrib.bayesflow.python.ops import variational_inference -from tensorflow.contrib.bayesflow.python.ops.monte_carlo_impl import _get_samples as get_samples -from tensorflow.python.framework import ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import tf_logging as logging - - -# Make utility functions from monte_carlo available. -# pylint: disable=protected-access -_get_samples = get_samples -_logspace_mean = monte_carlo._logspace_mean -_sample_mean = monte_carlo._sample_mean - -# pylint: enable=protected-access - -__all__ = [ - 'elbo_ratio', - 'entropy_shannon', - 'renyi_ratio', - 'renyi_alpha', -] - -ELBOForms = variational_inference.ELBOForms # pylint: disable=invalid-name - - -def elbo_ratio(log_p, - q, - z=None, - n=None, - seed=None, - form=None, - name='elbo_ratio'): - r"""Estimate of the ratio appearing in the `ELBO` and `KL` divergence. - - With `p(z) := exp{log_p(z)}`, this `Op` returns an approximation of - - ``` - E_q[ Log[p(Z) / q(Z)] ] - ``` - - The term `E_q[ Log[p(Z)] ]` is always computed as a sample mean. - The term `E_q[ Log[q(z)] ]` can be computed with samples, or an exact formula - if `q.entropy()` is defined. This is controlled with the kwarg `form`. - - This log-ratio appears in different contexts: - - #### `KL[q || p]` - - If `log_p(z) = Log[p(z)]` for distribution `p`, this `Op` approximates - the negative Kullback-Leibler divergence. - - ``` - elbo_ratio(log_p, q, n=100) = -1 * KL[q || p], - KL[q || p] = E[ Log[q(Z)] - Log[p(Z)] ] - ``` - - Note that if `p` is a `Distribution`, then - `distributions.kl_divergence(q, p)` may be defined and available as an - exact result. - - #### ELBO - - If `log_p(z) = Log[p(z, x)]` is the log joint of a distribution `p`, this is - the Evidence Lower BOund (ELBO): - - ``` - ELBO ~= E[ Log[p(Z, x)] - Log[q(Z)] ] - = Log[p(x)] - KL[q || p] - <= Log[p(x)] - ``` - - User supplies either `Tensor` of samples `z`, or number of samples to draw `n` - - Args: - log_p: Callable mapping samples from `q` to `Tensors` with - shape broadcastable to `q.batch_shape`. - For example, `log_p` works "just like" `q.log_prob`. - q: `tf.contrib.distributions.Distribution`. - z: `Tensor` of samples from `q`, produced by `q.sample(n)` for some `n`. - n: Integer `Tensor`. Number of samples to generate if `z` is not provided. - seed: Python integer to seed the random number generator. - form: Either `ELBOForms.analytic_entropy` (use formula for entropy of `q`) - or `ELBOForms.sample` (sample estimate of entropy), or `ELBOForms.default` - (attempt analytic entropy, fallback on sample). - Default value is `ELBOForms.default`. - name: A name to give this `Op`. - - Returns: - Scalar `Tensor` holding sample mean KL divergence. `shape` is the batch - shape of `q`, and `dtype` is the same as `q`. - - Raises: - ValueError: If `form` is not handled by this function. - """ - form = ELBOForms.default if form is None else form - - with ops.name_scope(name, values=[n, z]): - z = _get_samples(q, z, n, seed) - - entropy = entropy_shannon(q, z=z, form=form) - - # If log_p(z) = Log[p(z)], cross entropy = -E_q[log(p(Z))] - negative_cross_entropy = _sample_mean(log_p(z)) - - return entropy + negative_cross_entropy - - -def entropy_shannon(p, - z=None, - n=None, - seed=None, - form=None, - name='entropy_shannon'): - r"""Monte Carlo or deterministic computation of Shannon's entropy. - - Depending on the kwarg `form`, this `Op` returns either the analytic entropy - of the distribution `p`, or the sampled entropy: - - ``` - -n^{-1} sum_{i=1}^n p.log_prob(z_i), where z_i ~ p, - \approx - E_p[ Log[p(Z)] ] - = Entropy[p] - ``` - - User supplies either `Tensor` of samples `z`, or number of samples to draw `n` - - Args: - p: `tf.contrib.distributions.Distribution` - z: `Tensor` of samples from `p`, produced by `p.sample(n)` for some `n`. - n: Integer `Tensor`. Number of samples to generate if `z` is not provided. - seed: Python integer to seed the random number generator. - form: Either `ELBOForms.analytic_entropy` (use formula for entropy of `q`) - or `ELBOForms.sample` (sample estimate of entropy), or `ELBOForms.default` - (attempt analytic entropy, fallback on sample). - Default value is `ELBOForms.default`. - name: A name to give this `Op`. - - Returns: - A `Tensor` with same `dtype` as `p`, and shape equal to `p.batch_shape`. - - Raises: - ValueError: If `form` not handled by this function. - ValueError: If `form` is `ELBOForms.analytic_entropy` and `n` was provided. - """ - form = ELBOForms.default if form is None else form - - if n is not None and form == ELBOForms.analytic_entropy: - raise ValueError('If form == ELBOForms.analytic_entropy, n must be None.') - - with ops.name_scope(name, values=[n, z]): - # Entropy: -E_p[log(p(Z))]. - entropy = None - - # Try analytic path - if form in [ELBOForms.default, ELBOForms.analytic_entropy]: - try: - entropy = p.entropy() - logging.info('Using analytic entropy(p:%s)', p) - except NotImplementedError as e: - if form == ELBOForms.analytic_entropy: - raise e - elif form != ELBOForms.sample: - raise ValueError('ELBOForm not handled by this function: %s' % form) - - # Sample path - if entropy is None: - logging.info('Using sampled entropy(p:%s)', p) - if z is None: - z = p.sample(n, seed=seed) - entropy = -monte_carlo.expectation(p.log_prob, z) - - return entropy - - -def renyi_ratio(log_p, q, alpha, z=None, n=None, seed=None, name='renyi_ratio'): - r"""Monte Carlo estimate of the ratio appearing in Renyi divergence. - - This can be used to compute the Renyi (alpha) divergence, or a log evidence - approximation based on Renyi divergence. - - #### Definition - - With `z_i` iid samples from `q`, and `exp{log_p(z)} = p(z)`, this `Op` returns - the (biased for finite `n`) estimate: - - ``` - (1 - alpha)^{-1} Log[ n^{-1} sum_{i=1}^n ( p(z_i) / q(z_i) )^{1 - alpha}, - \approx (1 - alpha)^{-1} Log[ E_q[ (p(Z) / q(Z))^{1 - alpha} ] ] - ``` - - This ratio appears in different contexts: - - #### Renyi divergence - - If `log_p(z) = Log[p(z)]` is the log prob of a distribution, and - `alpha > 0`, `alpha != 1`, this `Op` approximates `-1` times Renyi divergence: - - ``` - # Choose reasonably high n to limit bias, see below. - renyi_ratio(log_p, q, alpha, n=100) - \approx -1 * D_alpha[q || p], where - D_alpha[q || p] := (1 - alpha)^{-1} Log E_q[(p(Z) / q(Z))^{1 - alpha}] - ``` - - The Renyi (or "alpha") divergence is non-negative and equal to zero iff - `q = p`. Various limits of `alpha` lead to different special case results: - - ``` - alpha D_alpha[q || p] - ----- --------------- - --> 0 Log[ int_{q > 0} p(z) dz ] - = 0.5, -2 Log[1 - Hel^2[q || p]], (\propto squared Hellinger distance) - --> 1 KL[q || p] - = 2 Log[ 1 + chi^2[q || p] ], (\propto squared Chi-2 divergence) - --> infty Log[ max_z{q(z) / p(z)} ], (min description length principle). - ``` - - See "Renyi Divergence Variational Inference", by Li and Turner. - - #### Log evidence approximation - - If `log_p(z) = Log[p(z, x)]` is the log of the joint distribution `p`, this is - an alternative to the ELBO common in variational inference. - - ``` - L_alpha(q, p) = Log[p(x)] - D_alpha[q || p] - ``` - - If `q` and `p` have the same support, and `0 < a <= b < 1`, one can show - `ELBO <= D_b <= D_a <= Log[p(x)]`. Thus, this `Op` allows a smooth - interpolation between the ELBO and the true evidence. - - #### Stability notes - - Note that when `1 - alpha` is not small, the ratio `(p(z) / q(z))^{1 - alpha}` - is subject to underflow/overflow issues. For that reason, it is evaluated in - log-space after centering. Nonetheless, infinite/NaN results may occur. For - that reason, one may wish to shrink `alpha` gradually. See the `Op` - `renyi_alpha`. Using `float64` will also help. - - - #### Bias for finite sample size - - Due to nonlinearity of the logarithm, for random variables `{X_1,...,X_n}`, - `E[ Log[sum_{i=1}^n X_i] ] != Log[ E[sum_{i=1}^n X_i] ]`. As a result, this - estimate is biased for finite `n`. For `alpha < 1`, it is non-decreasing - with `n` (in expectation). For example, if `n = 1`, this estimator yields the - same result as `elbo_ratio`, and as `n` increases the expected value - of the estimator increases. - - #### Call signature - - User supplies either `Tensor` of samples `z`, or number of samples to draw `n` - - Args: - log_p: Callable mapping samples from `q` to `Tensors` with - shape broadcastable to `q.batch_shape`. - For example, `log_p` works "just like" `q.log_prob`. - q: `tf.contrib.distributions.Distribution`. - `float64` `dtype` recommended. - `log_p` and `q` should be supported on the same set. - alpha: `Tensor` with shape `q.batch_shape` and values not equal to 1. - z: `Tensor` of samples from `q`, produced by `q.sample` for some `n`. - n: Integer `Tensor`. The number of samples to use if `z` is not provided. - Note that this can be highly biased for small `n`, see docstring. - seed: Python integer to seed the random number generator. - name: A name to give this `Op`. - - Returns: - renyi_result: The scaled log of sample mean. `Tensor` with `shape` equal - to batch shape of `q`, and `dtype` = `q.dtype`. - """ - with ops.name_scope(name, values=[alpha, n, z]): - z = _get_samples(q, z, n, seed) - - # Evaluate sample mean in logspace. Note that _logspace_mean will compute - # (among other things) the mean of q.log_prob(z), which could also be - # obtained with q.entropy(). However, DON'T use analytic entropy, because - # that increases variance, and could result in NaN/Inf values of a sensitive - # term. - - # log_values - # = (1 - alpha) * ( Log p - Log q ) - log_values = (1. - alpha) * (log_p(z) - q.log_prob(z)) - - # log_mean_values - # = Log[ E[ values ] ] - # = Log[ E[ (p / q)^{1-alpha} ] ] - log_mean_values = _logspace_mean(log_values) - - return log_mean_values / (1. - alpha) - - -def renyi_alpha(step, - decay_time, - alpha_min, - alpha_max=0.99999, - name='renyi_alpha'): - r"""Exponentially decaying `Tensor` appropriate for Renyi ratios. - - When minimizing the Renyi divergence for `0 <= alpha < 1` (or maximizing the - Renyi equivalent of elbo) in high dimensions, it is not uncommon to experience - `NaN` and `inf` values when `alpha` is far from `1`. - - For that reason, it is often desirable to start the optimization with `alpha` - very close to 1, and reduce it to a final `alpha_min` according to some - schedule. The user may even want to optimize using `elbo_ratio` for - some fixed time before switching to Renyi based methods. - - This `Op` returns an `alpha` decaying exponentially with step: - - ``` - s(step) = (exp{step / decay_time} - 1) / (e - 1) - t(s) = max(0, min(s, 1)), (smooth growth from 0 to 1) - alpha(t) = (1 - t) alpha_min + t alpha_max - ``` - - Args: - step: Non-negative scalar `Tensor`. Typically the global step or an - offset version thereof. - decay_time: Positive scalar `Tensor`. - alpha_min: `float` or `double` `Tensor`. - The minimal, final value of `alpha`, achieved when `step >= decay_time` - alpha_max: `Tensor` of same `dtype` as `alpha_min`. - The maximal, beginning value of `alpha`, achieved when `step == 0` - name: A name to give this `Op`. - - Returns: - alpha: A `Tensor` of same `dtype` as `alpha_min`. - """ - with ops.name_scope(name, values=[step, decay_time, alpha_min, alpha_max]): - alpha_min = ops.convert_to_tensor(alpha_min, name='alpha_min') - dtype = alpha_min.dtype - - alpha_max = ops.convert_to_tensor(alpha_max, dtype=dtype, name='alpha_max') - decay_time = math_ops.cast(decay_time, dtype) - step = math_ops.cast(step, dtype) - - check_scalars = [ - check_ops.assert_rank(step, 0, message='step must be scalar'), - check_ops.assert_rank( - decay_time, 0, message='decay_time must be scalar'), - check_ops.assert_rank(alpha_min, 0, message='alpha_min must be scalar'), - check_ops.assert_rank(alpha_max, 0, message='alpha_max must be scalar'), - ] - check_sign = [ - check_ops.assert_non_negative( - step, message='step must be non-negative'), - check_ops.assert_positive( - decay_time, message='decay_time must be positive'), - ] - - with ops.control_dependencies(check_scalars + check_sign): - theta = (math_ops.exp(step / decay_time) - 1.) / (math.e - 1.) - theta = math_ops.minimum(math_ops.maximum(theta, 0.), 1.) - return alpha_max * (1. - theta) + alpha_min * theta diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc.py b/tensorflow/contrib/bayesflow/python/ops/hmc.py new file mode 100644 index 0000000000000000000000000000000000000000..977d42fc16bb91777a76c45ac24f3c5dc587f5fe --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/hmc.py @@ -0,0 +1,34 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +from tensorflow.contrib.bayesflow.python.ops.hmc_impl import * # pylint: disable=wildcard-import,unused-wildcard-import,g-importing-member +from tensorflow.python.util import all_util + +_allowed_symbols = [ + 'chain', + 'kernel', + 'leapfrog_integrator', + 'leapfrog_step', + 'ais_chain' +] + +all_util.remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..333dce929530adceb30dcb63653a5bd009c059e0 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/hmc_impl.py @@ -0,0 +1,635 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Hamiltonian Monte Carlo, a gradient-based MCMC algorithm. + +@@chain +@@update +@@leapfrog_integrator +@@leapfrog_step +@@ais_chain +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import tf_logging as logging + +__all__ = [ + 'chain', + 'kernel', + 'leapfrog_integrator', + 'leapfrog_step', + 'ais_chain' +] + + +def _make_potential_and_grad(target_log_prob_fn): + def potential_and_grad(x): + log_prob_result = -target_log_prob_fn(x) + grad_result = gradients_impl.gradients(math_ops.reduce_sum(log_prob_result), + x)[0] + return log_prob_result, grad_result + return potential_and_grad + + +def chain(n_iterations, step_size, n_leapfrog_steps, initial_x, + target_log_prob_fn, event_dims=(), name=None): + """Runs multiple iterations of one or more Hamiltonian Monte Carlo chains. + + Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) + algorithm that takes a series of gradient-informed steps to produce + a Metropolis proposal. This function samples from an HMC Markov + chain whose initial state is `initial_x` and whose stationary + distribution has log-density `target_log_prob_fn()`. + + This function can update multiple chains in parallel. It assumes + that all dimensions of `initial_x` not specified in `event_dims` are + independent, and should therefore be updated independently. The + output of `target_log_prob_fn()` should sum log-probabilities across + all event dimensions. Slices along dimensions not in `event_dims` + may have different target distributions; this is up to + `target_log_prob_fn()`. + + This function basically just wraps `hmc.kernel()` in a tf.scan() loop. + + Args: + n_iterations: Integer number of Markov chain updates to run. + step_size: Scalar step size or array of step sizes for the + leapfrog integrator. Broadcasts to the shape of + `initial_x`. Larger step sizes lead to faster progress, but + too-large step sizes make rejection exponentially more likely. + When possible, it's often helpful to match per-variable step + sizes to the standard deviations of the target distribution in + each variable. + n_leapfrog_steps: Integer number of steps to run the leapfrog + integrator for. Total progress per HMC step is roughly + proportional to step_size * n_leapfrog_steps. + initial_x: Tensor of initial state(s) of the Markov chain(s). + target_log_prob_fn: Python callable which takes an argument like `initial_x` + and returns its (possibly unnormalized) log-density under the target + distribution. + event_dims: List of dimensions that should not be treated as + independent. This allows for multiple chains to be run independently + in parallel. Default is (), i.e., all dimensions are independent. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + acceptance_probs: Tensor with the acceptance probabilities for each + iteration. Has shape matching `target_log_prob_fn(initial_x)`. + chain_states: Tensor with the state of the Markov chain at each iteration. + Has shape `[n_iterations, initial_x.shape[0],...,initial_x.shape[-1]`. + + #### Examples: + + ```python + # Sampling from a standard normal (note `log_joint()` is unnormalized): + def log_joint(x): + return tf.reduce_sum(-0.5 * tf.square(x)) + chain, acceptance_probs = hmc.chain(1000, 0.5, 2, tf.zeros(10), log_joint, + event_dims=[0]) + # Discard first half of chain as warmup/burn-in + warmed_up = chain[500:] + mean_est = tf.reduce_mean(warmed_up, 0) + var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est) + ``` + + ```python + # Sampling from a diagonal-variance Gaussian: + variances = tf.linspace(1., 3., 10) + def log_joint(x): + return tf.reduce_sum(-0.5 / variances * tf.square(x)) + chain, acceptance_probs = hmc.chain(1000, 0.5, 2, tf.zeros(10), log_joint, + event_dims=[0]) + # Discard first half of chain as warmup/burn-in + warmed_up = chain[500:] + mean_est = tf.reduce_mean(warmed_up, 0) + var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est) + ``` + + ```python + # Sampling from factor-analysis posteriors with known factors W: + # mu[i, j] ~ Normal(0, 1) + # x[i] ~ Normal(matmul(mu[i], W), I) + def log_joint(mu, x, W): + prior = -0.5 * tf.reduce_sum(tf.square(mu), 1) + x_mean = tf.matmul(mu, W) + likelihood = -0.5 * tf.reduce_sum(tf.square(x - x_mean), 1) + return prior + likelihood + chain, acceptance_probs = hmc.chain(1000, 0.1, 2, + tf.zeros([x.shape[0], W.shape[0]]), + lambda mu: log_joint(mu, x, W), + event_dims=[1]) + # Discard first half of chain as warmup/burn-in + warmed_up = chain[500:] + mean_est = tf.reduce_mean(warmed_up, 0) + var_est = tf.reduce_mean(tf.square(warmed_up), 0) - tf.square(mean_est) + ``` + + ```python + # Sampling from the posterior of a Bayesian regression model.: + + # Run 100 chains in parallel, each with a different initialization. + initial_beta = tf.random_normal([100, x.shape[1]]) + chain, acceptance_probs = hmc.chain(1000, 0.1, 10, initial_beta, + log_joint_partial, event_dims=[1]) + # Discard first halves of chains as warmup/burn-in + warmed_up = chain[500:] + # Averaging across samples within a chain and across chains + mean_est = tf.reduce_mean(warmed_up, [0, 1]) + var_est = tf.reduce_mean(tf.square(warmed_up), [0, 1]) - tf.square(mean_est) + ``` + """ + with ops.name_scope(name, 'hmc_chain', [n_iterations, step_size, + n_leapfrog_steps, initial_x]): + initial_x = ops.convert_to_tensor(initial_x, name='initial_x') + non_event_shape = array_ops.shape(target_log_prob_fn(initial_x)) + + def body(a, _): + updated_x, acceptance_probs, log_prob, grad = kernel( + step_size, n_leapfrog_steps, a[0], target_log_prob_fn, event_dims, + a[2], a[3]) + return updated_x, acceptance_probs, log_prob, grad + + potential_and_grad = _make_potential_and_grad(target_log_prob_fn) + potential, grad = potential_and_grad(initial_x) + return functional_ops.scan(body, array_ops.zeros(n_iterations), + (initial_x, array_ops.zeros(non_event_shape), + -potential, -grad))[:2] + + +def ais_chain(n_iterations, step_size, n_leapfrog_steps, initial_x, + target_log_prob_fn, proposal_log_prob_fn, event_dims=(), + name=None): + """Runs annealed importance sampling (AIS) to estimate normalizing constants. + + This routine uses Hamiltonian Monte Carlo to sample from a series of + distributions that slowly interpolates between an initial "proposal" + distribution + + `exp(proposal_log_prob_fn(x) - proposal_log_normalizer)` + + and the target distribution + + `exp(target_log_prob_fn(x) - target_log_normalizer)`, + + accumulating importance weights along the way. The product of these + importance weights gives an unbiased estimate of the ratio of the + normalizing constants of the initial distribution and the target + distribution: + + E[exp(w)] = exp(target_log_normalizer - proposal_log_normalizer). + + Args: + n_iterations: Integer number of Markov chain updates to run. More + iterations means more expense, but smoother annealing between q + and p, which in turn means exponentially lower variance for the + normalizing constant estimator. + step_size: Scalar step size or array of step sizes for the + leapfrog integrator. Broadcasts to the shape of + `initial_x`. Larger step sizes lead to faster progress, but + too-large step sizes make rejection exponentially more likely. + When possible, it's often helpful to match per-variable step + sizes to the standard deviations of the target distribution in + each variable. + n_leapfrog_steps: Integer number of steps to run the leapfrog + integrator for. Total progress per HMC step is roughly + proportional to step_size * n_leapfrog_steps. + initial_x: Tensor of initial state(s) of the Markov chain(s). Must + be a sample from q, or results will be incorrect. + target_log_prob_fn: Python callable which takes an argument like `initial_x` + and returns its (possibly unnormalized) log-density under the target + distribution. + proposal_log_prob_fn: Python callable that returns the log density of the + initial distribution. + event_dims: List of dimensions that should not be treated as + independent. This allows for multiple chains to be run independently + in parallel. Default is (), i.e., all dimensions are independent. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + ais_weights: Tensor with the estimated weight(s). Has shape matching + `target_log_prob_fn(initial_x)`. + chain_states: Tensor with the state(s) of the Markov chain(s) the final + iteration. Has shape matching `initial_x`. + acceptance_probs: Tensor with the acceptance probabilities for the final + iteration. Has shape matching `target_log_prob_fn(initial_x)`. + + #### Examples: + + ```python + # Estimating the normalizing constant of a log-gamma distribution: + def proposal_log_prob(x): + # Standard normal log-probability. This is properly normalized. + return tf.reduce_sum(-0.5 * tf.square(x) - 0.5 * np.log(2 * np.pi), 1) + def target_log_prob(x): + # Unnormalized log-gamma(2, 3) distribution. + # True normalizer is (lgamma(2) - 2 * log(3)) * x.shape[1] + return tf.reduce_sum(2. * x - 3. * tf.exp(x), 1) + # Run 100 AIS chains in parallel + initial_x = tf.random_normal([100, 20]) + w, _, _ = hmc.ais_chain(1000, 0.2, 2, initial_x, target_log_prob, + proposal_log_prob, event_dims=[1]) + log_normalizer_estimate = tf.reduce_logsumexp(w) - np.log(100) + ``` + + ```python + # Estimating the marginal likelihood of a Bayesian regression model: + base_measure = -0.5 * np.log(2 * np.pi) + def proposal_log_prob(x): + # Standard normal log-probability. This is properly normalized. + return tf.reduce_sum(-0.5 * tf.square(x) + base_measure, 1) + def regression_log_joint(beta, x, y): + # This function returns a vector whose ith element is log p(beta[i], y | x). + # Each row of beta corresponds to the state of an independent Markov chain. + log_prior = tf.reduce_sum(-0.5 * tf.square(beta) + base_measure, 1) + means = tf.matmul(beta, x, transpose_b=True) + log_likelihood = tf.reduce_sum(-0.5 * tf.square(y - means) + + base_measure, 1) + return log_prior + log_likelihood + def log_joint_partial(beta): + return regression_log_joint(beta, x, y) + # Run 100 AIS chains in parallel + initial_beta = tf.random_normal([100, x.shape[1]]) + w, beta_samples, _ = hmc.ais_chain(1000, 0.1, 2, initial_beta, + log_joint_partial, proposal_log_prob, + event_dims=[1]) + log_normalizer_estimate = tf.reduce_logsumexp(w) - np.log(100) + ``` + """ + with ops.name_scope(name, 'hmc_ais_chain', + [n_iterations, step_size, n_leapfrog_steps, initial_x]): + non_event_shape = array_ops.shape(target_log_prob_fn(initial_x)) + + beta_series = math_ops.linspace(0., 1., n_iterations+1)[1:] + def _body(a, beta): # pylint: disable=missing-docstring + def log_prob_beta(x): + return ((1 - beta) * proposal_log_prob_fn(x) + + beta * target_log_prob_fn(x)) + last_x = a[0] + w = a[2] + w += (1. / n_iterations) * (target_log_prob_fn(last_x) - + proposal_log_prob_fn(last_x)) + # TODO(b/66917083): There's an opportunity for gradient reuse here. + updated_x, acceptance_probs, _, _ = kernel(step_size, n_leapfrog_steps, + last_x, log_prob_beta, + event_dims) + return updated_x, acceptance_probs, w + + x, acceptance_probs, w = functional_ops.scan( + _body, beta_series, (initial_x, array_ops.zeros(non_event_shape), + array_ops.zeros(non_event_shape))) + return w[-1], x[-1], acceptance_probs[-1] + + +def kernel(step_size, n_leapfrog_steps, x, target_log_prob_fn, event_dims=(), + x_log_prob=None, x_grad=None, name=None): + """Runs one iteration of Hamiltonian Monte Carlo. + + Hamiltonian Monte Carlo (HMC) is a Markov chain Monte Carlo (MCMC) + algorithm that takes a series of gradient-informed steps to produce + a Metropolis proposal. This function applies one step of HMC to + randomly update the variable `x`. + + This function can update multiple chains in parallel. It assumes + that all dimensions of `x` not specified in `event_dims` are + independent, and should therefore be updated independently. The + output of `target_log_prob_fn()` should sum log-probabilities across + all event dimensions. Slices along dimensions not in `event_dims` + may have different target distributions; for example, if + `event_dims == (1,)`, then `x[0, :]` could have a different target + distribution from x[1, :]. This is up to `target_log_prob_fn()`. + + Args: + step_size: Scalar step size or array of step sizes for the + leapfrog integrator. Broadcasts to the shape of + `x`. Larger step sizes lead to faster progress, but + too-large step sizes make rejection exponentially more likely. + When possible, it's often helpful to match per-variable step + sizes to the standard deviations of the target distribution in + each variable. + n_leapfrog_steps: Integer number of steps to run the leapfrog + integrator for. Total progress per HMC step is roughly + proportional to step_size * n_leapfrog_steps. + x: Tensor containing the value(s) of the random variable(s) to update. + target_log_prob_fn: Python callable which takes an argument like `initial_x` + and returns its (possibly unnormalized) log-density under the target + distribution. + event_dims: List of dimensions that should not be treated as + independent. This allows for multiple chains to be run independently + in parallel. Default is (), i.e., all dimensions are independent. + x_log_prob (optional): Tensor containing the cached output of a previous + call to `target_log_prob_fn()` evaluated at `x` (such as that provided by + a previous call to `kernel()`). Providing `x_log_prob` and + `x_grad` saves one gradient computation per call to `kernel()`. + x_grad (optional): Tensor containing the cached gradient of + `target_log_prob_fn()` evaluated at `x` (such as that provided by + a previous call to `kernel()`). Providing `x_log_prob` and + `x_grad` saves one gradient computation per call to `kernel()`. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + updated_x: The updated variable(s) x. Has shape matching `initial_x`. + acceptance_probs: Tensor with the acceptance probabilities for the final + iteration. This is useful for diagnosing step size problems etc. Has + shape matching `target_log_prob_fn(initial_x)`. + new_log_prob: The value of `target_log_prob_fn()` evaluated at `updated_x`. + new_grad: The value of the gradient of `target_log_prob_fn()` evaluated at + `updated_x`. + + #### Examples: + + ```python + # Tuning acceptance rates: + target_accept_rate = 0.631 + def target_log_prob(x): + # Standard normal + return tf.reduce_sum(-0.5 * tf.square(x)) + initial_x = tf.zeros([10]) + initial_log_prob = target_log_prob(initial_x) + initial_grad = tf.gradients(initial_log_prob, initial_x)[0] + # Algorithm state + x = tf.Variable(initial_x, name='x') + step_size = tf.Variable(1., name='step_size') + last_log_prob = tf.Variable(initial_log_prob, name='last_log_prob') + last_grad = tf.Variable(initial_grad, name='last_grad') + # Compute updates + new_x, acceptance_prob, log_prob, grad = hmc.kernel(step_size, 3, x, + target_log_prob, + event_dims=[0], + x_log_prob=last_log_prob) + x_update = tf.assign(x, new_x) + log_prob_update = tf.assign(last_log_prob, log_prob) + grad_update = tf.assign(last_grad, grad) + step_size_update = tf.assign(step_size, + tf.where(acceptance_prob > target_accept_rate, + step_size * 1.01, step_size / 1.01)) + adaptive_updates = [x_update, log_prob_update, grad_update, step_size_update] + sampling_updates = [x_update, log_prob_update, grad_update] + + sess = tf.Session() + sess.run(tf.global_variables_initializer()) + # Warm up the sampler and adapt the step size + for i in xrange(500): + sess.run(adaptive_updates) + # Collect samples without adapting step size + samples = np.zeros([500, 10]) + for i in xrange(500): + x_val, _ = sess.run([new_x, sampling_updates]) + samples[i] = x_val + ``` + + ```python + # Empirical-Bayes estimation of a hyperparameter by MCMC-EM: + + # Problem setup + N = 150 + D = 10 + x = np.random.randn(N, D).astype(np.float32) + true_sigma = 0.5 + true_beta = true_sigma * np.random.randn(D).astype(np.float32) + y = x.dot(true_beta) + np.random.randn(N).astype(np.float32) + + def log_prior(beta, log_sigma): + return tf.reduce_sum(-0.5 / tf.exp(2 * log_sigma) * tf.square(beta) - + log_sigma) + def regression_log_joint(beta, log_sigma, x, y): + # This function returns log p(beta | log_sigma) + log p(y | x, beta). + means = tf.matmul(tf.expand_dims(beta, 0), x, transpose_b=True) + means = tf.squeeze(means) + log_likelihood = tf.reduce_sum(-0.5 * tf.square(y - means)) + return log_prior(beta, log_sigma) + log_likelihood + def log_joint_partial(beta): + return regression_log_joint(beta, log_sigma, x, y) + # Our estimate of log(sigma) + log_sigma = tf.Variable(0., name='log_sigma') + # The state of the Markov chain + beta = tf.Variable(tf.random_normal([x.shape[1]]), name='beta') + new_beta, _, _, _ = hmc.kernel(0.1, 5, beta, log_joint_partial, + event_dims=[0]) + beta_update = tf.assign(beta, new_beta) + optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) + with tf.control_dependencies([beta_update]): + log_sigma_update = optimizer.minimize(-log_prior(beta, log_sigma), + var_list=[log_sigma]) + + sess = tf.Session() + sess.run(tf.global_variables_initializer()) + log_sigma_history = np.zeros(1000) + for i in xrange(1000): + log_sigma_val, _ = sess.run([log_sigma, log_sigma_update]) + log_sigma_history[i] = log_sigma_val + # Should converge to something close to true_sigma + plt.plot(np.exp(log_sigma_history)) + ``` + """ + with ops.name_scope(name, 'hmc_kernel', [step_size, n_leapfrog_steps, x]): + potential_and_grad = _make_potential_and_grad(target_log_prob_fn) + + x_shape = array_ops.shape(x) + m = random_ops.random_normal(x_shape) + + kinetic_0 = 0.5 * math_ops.reduce_sum(math_ops.square(m), event_dims) + + if (x_log_prob is not None) and (x_grad is not None): + log_potential_0, grad_0 = -x_log_prob, -x_grad # pylint: disable=invalid-unary-operand-type + else: + if x_log_prob is not None: + logging.warn('x_log_prob was provided, but x_grad was not,' + ' so x_log_prob was not used.') + if x_grad is not None: + logging.warn('x_grad was provided, but x_log_prob was not,' + ' so x_grad was not used.') + log_potential_0, grad_0 = potential_and_grad(x) + + new_x, new_m, log_potential_1, grad_1 = leapfrog_integrator( + step_size, n_leapfrog_steps, x, m, potential_and_grad, grad_0) + + kinetic_1 = 0.5 * math_ops.reduce_sum(math_ops.square(new_m), event_dims) + + # TODO(mhoffman): It seems like there may be an opportunity for nans here. + # I'm delaying addressing this because we're going to refactor this part + # to use the more general Metropolis abstraction anyway. + acceptance_probs = math_ops.exp(math_ops.minimum(0., log_potential_0 - + log_potential_1 + + kinetic_0 - kinetic_1)) + accepted = math_ops.cast( + random_ops.random_uniform(array_ops.shape(acceptance_probs)) < + acceptance_probs, np.float32) + new_log_prob = (-log_potential_0 * (1. - accepted) - + log_potential_1 * accepted) + + # TODO(b/65738010): This should work, but it doesn't for now. + # reduced_shape = math_ops.reduced_shape(x_shape, event_dims) + reduced_shape = array_ops.shape(math_ops.reduce_sum(x, event_dims, + keep_dims=True)) + accepted = array_ops.reshape(accepted, reduced_shape) + new_x = x * (1. - accepted) + new_x * accepted + new_grad = -grad_0 * (1. - accepted) - grad_1 * accepted + + return new_x, acceptance_probs, new_log_prob, new_grad + + +def leapfrog_integrator(step_size, n_steps, initial_position, initial_momentum, + potential_and_grad, initial_grad, name=None): + """Applies `n_steps` steps of the leapfrog integrator. + + This just wraps `leapfrog_step()` in a `tf.while_loop()`, reusing + gradient computations where possible. + + Args: + step_size: Scalar step size or array of step sizes for the + leapfrog integrator. Broadcasts to the shape of + `initial_position`. Larger step sizes lead to faster progress, but + too-large step sizes lead to larger discretization error and + worse energy conservation. + n_steps: Number of steps to run the leapfrog integrator. + initial_position: Tensor containing the value(s) of the position variable(s) + to update. + initial_momentum: Tensor containing the value(s) of the momentum variable(s) + to update. + potential_and_grad: Python callable that takes a position tensor like + `initial_position` and returns the potential energy and its gradient at + that position. + initial_grad: Tensor with the value of the gradient of the potential energy + at `initial_position`. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + updated_position: Updated value of the position. + updated_momentum: Updated value of the momentum. + new_potential: Potential energy of the new position. Has shape matching + `potential_and_grad(initial_position)`. + new_grad: Gradient from potential_and_grad() evaluated at the new position. + Has shape matching `initial_position`. + + Example: Simple quadratic potential. + ```python + def potential_and_grad(position): + return tf.reduce_sum(0.5 * tf.square(position)), position + position = tf.placeholder(np.float32) + momentum = tf.placeholder(np.float32) + potential, grad = potential_and_grad(position) + new_position, new_momentum, new_potential, new_grad = hmc.leapfrog_integrator( + 0.1, 3, position, momentum, potential_and_grad, grad) + + sess = tf.Session() + position_val = np.random.randn(10) + momentum_val = np.random.randn(10) + potential_val, grad_val = sess.run([potential, grad], + {position: position_val}) + positions = np.zeros([100, 10]) + for i in xrange(100): + position_val, momentum_val, potential_val, grad_val = sess.run( + [new_position, new_momentum, new_potential, new_grad], + {position: position_val, momentum: momentum_val}) + positions[i] = position_val + # Should trace out sinusoidal dynamics. + plt.plot(positions[:, 0]) + ``` + """ + def leapfrog_wrapper(step_size, x, m, grad, l): + x, m, _, grad = leapfrog_step(step_size, x, m, potential_and_grad, grad) + return step_size, x, m, grad, l + 1 + + def counter_fn(a, b, c, d, counter): # pylint: disable=unused-argument + return counter < n_steps + + with ops.name_scope(name, 'leapfrog_integrator', + [step_size, n_steps, initial_position, initial_momentum, + initial_grad]): + _, new_x, new_m, new_grad, _ = control_flow_ops.while_loop( + counter_fn, leapfrog_wrapper, [step_size, initial_position, + initial_momentum, initial_grad, + array_ops.constant(0)], back_prop=False) + # We're counting on the runtime to eliminate this redundant computation. + new_potential, new_grad = potential_and_grad(new_x) + return new_x, new_m, new_potential, new_grad + + +def leapfrog_step(step_size, position, momentum, potential_and_grad, grad, + name=None): + """Applies one step of the leapfrog integrator. + + Assumes a simple quadratic kinetic energy function: 0.5 * ||momentum||^2. + + Args: + step_size: Scalar step size or array of step sizes for the + leapfrog integrator. Broadcasts to the shape of + `position`. Larger step sizes lead to faster progress, but + too-large step sizes lead to larger discretization error and + worse energy conservation. + position: Tensor containing the value(s) of the position variable(s) + to update. + momentum: Tensor containing the value(s) of the momentum variable(s) + to update. + potential_and_grad: Python callable that takes a position tensor like + `position` and returns the potential energy and its gradient at that + position. + grad: Tensor with the value of the gradient of the potential energy + at `position`. + name: Python `str` name prefixed to Ops created by this function. + + Returns: + updated_position: Updated value of the position. + updated_momentum: Updated value of the momentum. + new_potential: Potential energy of the new position. Has shape matching + `potential_and_grad(position)`. + new_grad: Gradient from potential_and_grad() evaluated at the new position. + Has shape matching `position`. + + Example: Simple quadratic potential. + ```python + def potential_and_grad(position): + # Simple quadratic potential + return tf.reduce_sum(0.5 * tf.square(position)), position + position = tf.placeholder(np.float32) + momentum = tf.placeholder(np.float32) + potential, grad = potential_and_grad(position) + new_position, new_momentum, new_potential, new_grad = hmc.leapfrog_step( + 0.1, position, momentum, potential_and_grad, grad) + + sess = tf.Session() + position_val = np.random.randn(10) + momentum_val = np.random.randn(10) + potential_val, grad_val = sess.run([potential, grad], + {position: position_val}) + positions = np.zeros([100, 10]) + for i in xrange(100): + position_val, momentum_val, potential_val, grad_val = sess.run( + [new_position, new_momentum, new_potential, new_grad], + {position: position_val, momentum: momentum_val}) + positions[i] = position_val + # Should trace out sinusoidal dynamics. + plt.plot(positions[:, 0]) + ``` + """ + with ops.name_scope(name, 'leapfrog_step', [step_size, position, momentum, + grad]): + momentum -= 0.5 * step_size * grad + position += step_size * momentum + potential, grad = potential_and_grad(position) + momentum -= 0.5 * step_size * grad + + return position, momentum, potential, grad diff --git a/tensorflow/contrib/bayesflow/python/ops/entropy.py b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py similarity index 82% rename from tensorflow/contrib/bayesflow/python/ops/entropy.py rename to tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py index a22e1c1d4e098439760267fca1374f986e45be8f..7bdeaa862d5bb64fa8940df453c7aa2b66023eda 100644 --- a/tensorflow/contrib/bayesflow/python/ops/entropy.py +++ b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Support for Entropy Ops. See ${python/contrib.bayesflow.entropy}.""" +"""Functions to create a Markov Chain Monte Carlo Metropolis step.""" from __future__ import absolute_import from __future__ import division @@ -20,12 +20,14 @@ from __future__ import print_function # go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.entropy_impl import * +from tensorflow.contrib.bayesflow.python.ops.metropolis_hastings_impl import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'ELBOForms', 'elbo_ratio', 'entropy_shannon', 'renyi_ratio', 'renyi_alpha' + 'evolve', + 'uniform_random_proposal', + 'normal_random_proposal', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..dc1ac68ce009fa46d6c05a3200a29d9fdf245707 --- /dev/null +++ b/tensorflow/contrib/bayesflow/python/ops/metropolis_hastings_impl.py @@ -0,0 +1,426 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functions to create a Markov Chain Monte Carlo Metropolis step. + +@@evolve +@@uniform_random_proposal +@@normal_random_proposal +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import state_ops + +__all__ = [ + 'evolve', + 'uniform_random_proposal', + 'normal_random_proposal', +] + + +def _single_iteration(current_state, current_log_density, + log_unnormalized_prob_fn, proposal_fn, seed=None, + name='None'): + """Performs a single Metropolis-Hastings step. + + Args: + current_state: Float-like `Tensor` (i.e., `dtype` is either + `tf.float16`, `tf.float32` or `tf.float64`) of any shape that can + be consumed by the `log_unnormalized_prob_fn` and `proposal_fn` + callables. + current_log_density: Float-like `Tensor` with `dtype` and shape equivalent + to `log_unnormalized_prob_fn(current_state)`, i.e., matching the result of + `log_unnormalized_prob_fn` invoked at `current_state`. + log_unnormalized_prob_fn: A Python callable evaluated at + `current_state` and returning a float-like `Tensor` of log target-density + up to a normalizing constant. In other words, + `log_unnormalized_prob_fn(x) = log(g(x))`, where + `target_density = g(x)/Z` for some constant `A`. The shape of the input + tensor is the same as the shape of the `current_state`. The shape of the + output tensor is either + (a). Same as the input shape if the density being sampled is one + dimensional, or + (b). If the density is defined for `events` of shape + `event_shape = [E1, E2, ... Ee]`, then the input tensor should be of + shape `batch_shape + event_shape`, where `batch_shape = [B1, ..., Bb]` + and the result must be of shape [B1, ..., Bb]. For example, if the + distribution that is being sampled is a 10 dimensional normal, + then the input tensor may be of shape [100, 10] or [30, 20, 10]. The + last dimension will then be 'consumed' by `log_unnormalized_prob_fn` + and it should return tensors of shape [100] and [30, 20] respectively. + proposal_fn: A callable accepting a real valued `Tensor` of current sample + points and returning a tuple of two `Tensors`. The first element of the + pair is a `Tensor` containing the proposal state and should have + the same shape as the input `Tensor`. The second element of the pair gives + the log of the ratio of the probability of transitioning from the + proposal points to the input points and the probability of transitioning + from the input points to the proposal points. If the proposal is + symmetric (e.g., random walk, where the proposal is either + normal or uniform centered at `current_state`), i.e., + Probability(Proposal -> Current) = Probability(Current -> Proposal) + the second value should be set to `None` instead of explicitly supplying a + tensor of zeros. In addition to being convenient, this also leads to a + more efficient graph. + seed: `int` or None. The random seed for this `Op`. If `None`, no seed is + applied. + name: Python `str` name prefix for ops managed by this function. + + Returns: + next_state: `Tensor` with `dtype` and shape matching `current_state`. + Created by propagating the chain by one step, starting from + `current_state`. + next_log_density: `Tensor` with `dtype` and shape matching + `current_log_density`, which is equal to the value of the unnormalized + `log_unnormalized_prob_fn` computed at `next_state`. + log_accept_ratio: `Tensor` with `dtype` and shape matching + `current_log_density`. Stands for the log of Metropolis-Hastings + acceptance ratio used in generating the `next_state`. + """ + + with ops.name_scope(name, 'single_iteration', [current_state]): + # The proposed state and the log of the corresponding Hastings ratio. + proposal_state, log_transit_ratio = proposal_fn(current_state) + + # If the log ratio is None, assume that the transitions are symmetric, + # i.e., Prob(Current -> Proposed) = Prob(Proposed -> Current). + if log_transit_ratio is None: + log_transit_ratio = 0. + + # Log-density of the proposal state. + proposal_log_density = log_unnormalized_prob_fn(proposal_state) + + # Ops to compute the log of the acceptance ratio. Recall that the + # acceptance ratio is: [Prob(Proposed) / Prob(Current)] * + # [Prob(Proposed -> Current) / Prob(Current -> Proposed)]. The log of the + # second term is the log_transit_ratio. + with ops.name_scope('accept_reject'): + # The log of the acceptance ratio. + log_accept_ratio = (proposal_log_density - current_log_density + + log_transit_ratio) + + # A proposal is accepted or rejected depending on the acceptance ratio. + # If the acceptance ratio is greater than 1 then it is always accepted. + # If the acceptance ratio is less than 1 then the proposal is accepted + # with probability = acceptance ratio. As we are working in log space to + # prevent over/underflows, this logic is expressed in log terms below. + # If a proposal is accepted we place a True in the acceptance state + # tensor and if it is to be rejected we place a False. + # The log_draws below have to be compared to the log_accept_ratio so we + # make sure that they have the same data type. + log_draws = math_ops.log(random_ops.random_uniform( + array_ops.shape(current_log_density), seed=seed, + dtype=log_accept_ratio.dtype)) + is_proposal_accepted = log_draws < log_accept_ratio + + # The acceptance state decides which elements of the current state are to + # be replaced with the corresponding elements in the proposal state. + with ops.name_scope(name, 'metropolis_single_step', + [current_state, current_log_density]): + next_log_density = array_ops.where(is_proposal_accepted, + proposal_log_density, + current_log_density) + next_state = array_ops.where(is_proposal_accepted, proposal_state, + current_state) + + return next_state, next_log_density, log_accept_ratio + + +def evolve(initial_sample, + initial_log_density, + initial_log_accept_ratio, + log_unnormalized_prob_fn, + proposal_fn, + n_steps=1, + seed=None, + name=None): + """Performs `n_steps` of the Metropolis-Hastings update. + + Given a probability density function, `f(x)` and a proposal scheme which + generates new points from old, this `Op` returns a tensor + which may be used to generate approximate samples from the target distribution + using the Metropolis-Hastings algorithm. These samples are from a Markov chain + whose equilibrium distribution matches the target distribution. + + The probability distribution may have an unknown normalization constan. + We parameterize the probability density as follows: + ``` + f(x) = exp(L(x) + constant) + ``` + Here `L(x)` is any continuous function with an (possibly unknown but finite) + upper bound, i.e. there exists a number beta such that + `L(x)< beta < infinity` for all x. The constant is the normalization needed + to make `f(x)` a probability density (as opposed to just a finite measure). + + Although `initial_sample` can be arbitrary, a poor choice may result in a + slow-to-mix chain. In many cases the best choice is the one that maximizes + the target density, i.e., choose `initial_sample` such that + `f(initial_sample) >= f(x)` for all `x`. + + + If the support of the distribution is a strict subset of R^n (but of non zero + measure), then the unnormalized log-density `L(x)` should return `-infinity` + outside the support domain. This effectively forces the sampler to only + explore points in the regions of finite support. + + Usage: + This function is meant to be wrapped up with some of the common proposal + schemes (e.g. random walk, Langevin diffusion etc) to produce a more user + friendly interface. However, it may also be used to create bespoke samplers. + + The following example, demonstrates the use to generate a 1000 uniform random + walk Metropolis samplers run in parallel for the normal target distribution. + ```python + n = 3 # dimension of the problem + + # Generate 1000 initial values randomly. Each of these would be an + # independent starting point for a Markov chain. + state = tf.get_variable( + 'state',initializer=tf.random_normal([1000, n], mean=3.0, + dtype=tf.float64, seed=42)) + + # Computes the log(p(x)) for the unit normal density and ignores the + # normalization constant. + def log_density(x): + return - tf.reduce_sum(x * x, reduction_indices=-1) / 2.0 + + # Initial log-density value + state_log_density = tf.get_variable( + 'state_log_density', initializer=log_density(state.initialized_value())) + + # A variable to store the log_acceptance_ratio: + log_acceptance_ratio = tf.get_variable( + 'log_acceptance_ratio', initializer=tf.zeros([1000], dtype=tf.float64)) + + # Generates random proposals by moving each coordinate uniformly and + # independently in a box of size 2 centered around the current value. + # Returns the new point and also the log of the Hastings ratio (the + # ratio of the probability of going from the proposal to origin and the + # probability of the reverse transition). When this ratio is 1, the value + # may be omitted and replaced by None. + def random_proposal(x): + return (x + tf.random_uniform(tf.shape(x), minval=-1, maxval=1, + dtype=x.dtype, seed=12)), None + + # Create the op to propagate the chain for 100 steps. + stepper = mh.evolve( + state, state_log_density, log_acceptance_ratio, + log_density, random_proposal, n_steps=100, seed=123) + init = tf.initialize_all_variables() + with tf.Session() as sess: + sess.run(init) + # Run the chains for a total of 1000 steps and print out the mean across + # the chains every 100 iterations. + for n_iter in range(10): + # Executing the stepper advances the chain to the next state. + sess.run(stepper) + # Print out the current value of the mean(sample) for every dimension. + print(np.mean(sess.run(state), 0)) + # Estimated covariance matrix + samples = sess.run(state) + print('') + print(np.cov(samples, rowvar=False)) + ``` + + Args: + initial_sample: A float-like `tf.Variable` of any shape that can + be consumed by the `log_unnormalized_prob_fn` and `proposal_fn` + callables. + initial_log_density: Float-like `tf.Variable` with `dtype` and shape + equivalent to `log_unnormalized_prob_fn(initial_sample)`, i.e., matching + the result of `log_unnormalized_prob_fn` invoked at `current_state`. + initial_log_accept_ratio: A `tf.Variable` with `dtype` and shape matching + `initial_log_density`. Stands for the log of Metropolis-Hastings + acceptance ratio after propagating the chain for `n_steps`. + log_unnormalized_prob_fn: A Python callable evaluated at + `current_state` and returning a float-like `Tensor` of log target-density + up to a normalizing constant. In other words, + `log_unnormalized_prob_fn(x) = log(g(x))`, where + `target_density = g(x)/Z` for some constant `A`. The shape of the input + tensor is the same as the shape of the `current_state`. The shape of the + output tensor is either + (a). Same as the input shape if the density being sampled is one + dimensional, or + (b). If the density is defined for `events` of shape + `event_shape = [E1, E2, ... Ee]`, then the input tensor should be of + shape `batch_shape + event_shape`, here `batch_shape = [B1, ..., Bb]` + and the result must be of shape [B1, ..., Bb]. For example, if the + distribution that is being sampled is a 10 dimensional normal, + then the input tensor may be of shape [100, 10] or [30, 20, 10]. The + last dimension will then be 'consumed' by `log_unnormalized_prob_fn` + and it should return tensors of shape [100] and [30, 20] respectively. + proposal_fn: A callable accepting a real valued `Tensor` of current sample + points and returning a tuple of two `Tensors`. The first element of the + pair should be a `Tensor` containing the proposal state and should have + the same shape as the input `Tensor`. The second element of the pair gives + the log of the ratio of the probability of transitioning from the + proposal points to the input points and the probability of transitioning + from the input points to the proposal points. If the proposal is + symmetric, i.e. + Probability(Proposal -> Current) = Probability(Current -> Proposal) + the second value should be set to None instead of explicitly supplying a + tensor of zeros. In addition to being convenient, this also leads to a + more efficient graph. + n_steps: A positive `int` or a scalar `int32` tensor. Sets the number of + iterations of the chain. + seed: `int` or None. The random seed for this `Op`. If `None`, no seed is + applied. + name: A string that sets the name for this `Op`. + + Returns: + forward_step: an `Op` to step the Markov chain forward for `n_steps`. + """ + + with ops.name_scope(name, 'metropolis_hastings', [initial_sample]): + current_state = initial_sample + current_log_density = initial_log_density + log_accept_ratio = initial_log_accept_ratio + + # Stop condition for the while_loop + def stop_condition(i, _): + return i < n_steps + + def step(i, loop_vars): + """Wrap `_single_iteration` for `while_loop`.""" + state = loop_vars[0] + state_log_density = loop_vars[1] + return i + 1, list(_single_iteration(state, state_log_density, + log_unnormalized_prob_fn, + proposal_fn, seed=seed)) + + loop_vars = [current_state, current_log_density, log_accept_ratio] + # Build an `Op` to evolve the Markov chain for `n_steps` + (_, [end_state, end_log_density, end_log_acceptance]) = ( + control_flow_ops.while_loop( + stop_condition, step, + (0, loop_vars), + parallel_iterations=1, swap_memory=1)) + + forward_step = control_flow_ops.group( + state_ops.assign(current_log_density, end_log_density), + state_ops.assign(current_state, end_state), + state_ops.assign(log_accept_ratio, end_log_acceptance)) + + return forward_step + + +def uniform_random_proposal(step_size=1., + seed=None, + name=None): + """Returns a callable that adds a random uniform tensor to the input. + + This function returns a callable that accepts one `Tensor` argument of any + shape and a real data type (i.e. `tf.float32` or `tf.float64`). It adds a + sample from a random uniform distribution drawn from [-stepsize, stepsize] + to its input. It also returns the log of the ratio of the probability of + moving from the input point to the proposed point, but since this log ratio is + identically equal to 0 (because the probability of drawing a value `x` from + the symmetric uniform distribution is the same as the probability of drawing + `-x`), it simply returns None for the second element of the returned tuple. + + Args: + step_size: A positive `float` or a scalar tensor of real dtype + controlling the scale of the uniform distribution. + If step_size = a, then draws are made uniformly from [-a, a]. + seed: `int` or None. The random seed for this `Op`. If `None`, no seed is + applied. + name: A string that sets the name for this `Op`. + + Returns: + proposal_fn: A callable accepting one float-like `Tensor` and returning a + 2-tuple. The first value in the tuple is a `Tensor` of the same shape and + dtype as the input argument and the second element of the tuple is None. + """ + + with ops.name_scope(name, 'uniform_random_proposal', [step_size]): + def proposal_fn(input_state, name=None): + """Adds a uniform perturbation to the input state. + + Args: + input_state: A `Tensor` of any shape and real dtype. + name: A string that sets the name for this `Op`. + + Returns: + proposal_state: A float-like `Tensot` with `dtype` and shape matching + `input_state`. + log_transit_ratio: `None`. Proposal is symmetric. + """ + with ops.name_scope(name, 'proposer', [input_state]): + input_state = ops.convert_to_tensor(input_state, name='input_state') + return input_state + random_ops.random_uniform( + array_ops.shape(input_state), + minval=-step_size, + maxval=step_size, + seed=seed), None + return proposal_fn + + +def normal_random_proposal(scale=1., + seed=None, + name=None): + """Returns a callable that adds a random normal tensor to the input. + + This function returns a callable that accepts one `Tensor` argument of any + shape and a real data type (i.e. `tf.float32` or `tf.float64`). The callable + adds a sample from a normal distribution with the supplied standard deviation + and zero mean to its input argument (called the proposal point). + The callable returns a tuple with the proposal point as the first element. + The second element is identically `None`. It is included so the callable is + compatible with the expected signature of the proposal scheme argument in the + `metropolis_hastings` function. A value of `None` indicates that the + probability of going from the input point to the proposal point is equal to + the probability of going from the proposal point to the input point. + + Args: + scale: A positive `float` or a scalar tensor of any real dtype controlling + the scale of the normal distribution. + seed: `int` or None. The random seed for this `Op`. If `None`, no seed is + applied. + name: A string that sets the name for this `Op`. + + Returns: + proposal_fn: A callable accepting one float-like `Tensor` and returning a + 2-tuple. The first value in the tuple is a `Tensor` of the same shape and + dtype as the input argument and the second element of the tuple is None. + """ + + with ops.name_scope(name, 'normal_random_proposal', [scale]): + def proposal_fn(input_state, name=None): + """Adds a normal perturbation to the input state. + + Args: + input_state: A `Tensor` of any shape and real dtype. + name: A string that sets the name for this `Op`. + + Returns: + proposal_state: A float-like `Tensot` with `dtype` and shape matching + `input_state`. + log_transit_ratio: `None`. Proposal is symmetric. + """ + + with ops.name_scope(name, 'proposer', [input_state]): + input_state = ops.convert_to_tensor(input_state, name='input_state') + return input_state + random_ops.random_normal( + array_ops.shape(input_state), + mean=0., + stddev=scale, + seed=seed), None + return proposal_fn diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py deleted file mode 100644 index 695310837e0f6a58842f45c28608f12fbe162c6e..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/stochastic_gradient_estimators.py +++ /dev/null @@ -1,317 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Stochastic gradient estimators. - -These functions are meant to be used in conjunction with `StochasticTensor` -(`loss_fn` parameter) and `surrogate_loss`. - -See Gradient Estimation Using Stochastic Computation Graphs -(http://arxiv.org/abs/1506.05254) by Schulman et al., eq. 1 and section 4, for -mathematical details. - -## Score function estimator - -The score function is an unbiased estimator of the gradient of `E_p(x)[f(x)]`, -where `f(x)` can be considered to be a "loss" term. It is computed as -`E_p(x)[f(x) grad(log p(x))]`. A constant `b`, referred to here as the -"baseline", can be subtracted from `f(x)` without affecting the expectation. The -term `(f(x) - b)` is referred to here as the "advantage". - -Note that the methods defined in this module actually compute the integrand of -the score function, such that when taking the gradient, the true score function -is computed. - -@@score_function -@@get_score_function_with_baseline -@@get_score_function_with_constant_baseline -@@get_score_function_with_advantage - -## Baseline functions - -Baselines reduce the variance of Monte Carlo estimate of an expectation. The -baseline for a stochastic node can be a function of all non-influenced nodes -(see section 4 of Schulman et al., linked above). Baselines are also known as -"control variates." - -In the context of a MC estimate of `E_p(x)[f(x) - b]`, baseline functions have -the signature `(st, fx) => Tensor`, where `st` is a `StochasticTensor` backed by -the distribution `p(x)` and `fx` is the influenced loss. - -@@get_mean_baseline - -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import variable_scope as vs -from tensorflow.python.training import training -from tensorflow.python.util.all_util import make_all - - -def score_function(stochastic_tensor, value, loss, baseline=None, - name="ScoreFunction"): - """Score function estimator. - - Computes the integrand of the score function with a baseline: - `p.log_prob(value) * (loss - baseline)`. - - It will add a `stop_gradient` to the advantage `(loss - baseline)`. - - Args: - stochastic_tensor: `StochasticTensor` p(x). - value: `Tensor` x. Samples from p(x). - loss: `Tensor`. - baseline: `Tensor` broadcastable to `loss`. - name: name to prepend ops with. - - Returns: - `Tensor` `p.log_prob(x) * (loss - b)`. Taking the gradient yields the score - function estimator. - """ - with ops.name_scope(name, values=[value, loss, baseline]): - value = ops.convert_to_tensor(value) - loss = ops.convert_to_tensor(loss) - if baseline is not None: - baseline = ops.convert_to_tensor(baseline) - advantage = loss - baseline - else: - advantage = loss - - advantage = array_ops.stop_gradient(advantage) - return stochastic_tensor.distribution.log_prob(value) * advantage - - -def get_score_function_with_advantage(advantage_fn=None, - name="ScoreFunctionWithAdvantage"): - """Score function estimator with advantage function. - - Args: - advantage_fn: callable that takes the `StochasticTensor` and the - downstream `loss` and returns a `Tensor` advantage - (e.g. `loss - baseline`). - name: name to prepend ops with. - - Returns: - Callable score function estimator that takes the `StochasticTensor`, the - sampled `value`, and the downstream `loss`, and uses the provided advantage. - """ - - def score_function_with_advantage(stochastic_tensor, value, loss): - with ops.name_scope(name, values=[value, loss]): - advantage = advantage_fn(stochastic_tensor, loss) - advantage = array_ops.stop_gradient(advantage) - return stochastic_tensor.distribution.log_prob(value) * advantage - - return score_function_with_advantage - - -def get_score_function_with_constant_baseline(baseline, name="ScoreFunction"): - """Score function estimator with constant baseline. - - Args: - baseline: `Tensor` to be subtracted from loss. - name: name to prepend ops with. - - Returns: - Callable score function estimator that takes the `StochasticTensor`, the - sampled `value`, and the downstream `loss`, and subtracts the provided - `baseline` from the `loss`. - """ - - def score_function_with_constant_baseline(stochastic_tensor, value, loss): - return score_function(stochastic_tensor, value, loss, baseline, name) - - return score_function_with_constant_baseline - - -def get_score_function_with_baseline(baseline_fn=None, name="ScoreFunction"): - """Score function estimator with baseline function. - - Args: - baseline_fn: callable that takes the `StochasticTensor` and the downstream - `loss` and returns a `Tensor` baseline to be subtracted from the `loss`. - If None, defaults to `get_mean_baseline`, which is an EMA of the loss. - name: name to prepend ops with. - - Returns: - Callable score function estimator that takes the `StochasticTensor`, the - sampled `value`, and the downstream `loss`, and subtracts the provided - `baseline` from the `loss`. - """ - if baseline_fn is None: - baseline_fn = get_mean_baseline() - - def score_function_with_baseline(stochastic_tensor, value, loss): - with ops.name_scope(name): - b = baseline_fn(stochastic_tensor, loss) - return score_function(stochastic_tensor, value, loss, b) - - return score_function_with_baseline - - -def get_mean_baseline(ema_decay=0.99, name=None): - """ExponentialMovingAverage baseline. - - Args: - ema_decay: decay rate for the ExponentialMovingAverage. - name: name for variable scope of the ExponentialMovingAverage. - - Returns: - Callable baseline function that takes the `StochasticTensor` (unused) and - the downstream `loss`, and returns an EMA of the loss. - """ - - def mean_baseline(_, loss): - with vs.variable_scope(name, default_name="MeanBaseline"): - reduced_loss = math_ops.reduce_mean(loss) - - ema = training.ExponentialMovingAverage(decay=ema_decay, zero_debias=True) - update_op = ema.apply([reduced_loss]) - - with ops.control_dependencies([update_op]): - # Using `identity` causes an op to be added in this context, which - # triggers the update. Removing the `identity` means nothing is updated. - baseline = array_ops.identity(ema.average(reduced_loss)) - - return baseline - - return mean_baseline - - -def get_vimco_advantage_fn(have_log_loss=False): - """VIMCO (Variational Inference for Monte Carlo Objectives) baseline. - - Implements VIMCO baseline from the article of the same name: - - https://arxiv.org/pdf/1602.06725v2.pdf - - Given a `loss` tensor (containing non-negative probabilities or ratios), - calculates the advantage VIMCO advantage via Eq. 9 of the above paper. - - The tensor `loss` should be shaped `[n, ...]`, with rank at least 1. Here, - the first axis is considered the single sampling dimension and `n` must - be at least 2. Specifically, the `StochasticTensor` is assumed to have - used the `SampleValue(n)` value type with `n > 1`. - - Args: - have_log_loss: Python `Boolean`. If `True`, the loss is assumed to be the - log loss. If `False` (the default), it is assumed to be a nonnegative - probability or probability ratio. - - Returns: - Callable baseline function that takes the `StochasticTensor` (unused) and - the downstream `loss`, and returns the VIMCO baseline for the loss. - """ - def vimco_advantage_fn(_, loss, name=None): - """Internal VIMCO function. - - Args: - _: ignored `StochasticTensor`. - loss: The loss `Tensor`. - name: Python string, the name scope to use. - - Returns: - The advantage `Tensor`. - """ - with ops.name_scope(name, "VIMCOAdvantage", values=[loss]): - loss = ops.convert_to_tensor(loss) - loss_shape = loss.get_shape() - loss_num_elements = loss_shape[0].value - n = math_ops.cast( - loss_num_elements or array_ops.shape(loss)[0], dtype=loss.dtype) - - if have_log_loss: - log_loss = loss - else: - log_loss = math_ops.log(loss) - - # Calculate L_hat, Eq. (4) -- stably - log_mean = math_ops.reduce_logsumexp(log_loss, [0]) - math_ops.log(n) - - # expand_dims: Expand shape [a, b, c] to [a, 1, b, c] - log_loss_expanded = array_ops.expand_dims(log_loss, [1]) - - # divide: log_loss_sub with shape [a, a, b, c], where - # - # log_loss_sub[i] = log_loss - log_loss[i] - # - # = [ log_loss[j] - log_loss[i] for rows j = 0 ... i - 1 ] - # [ zeros ] - # [ log_loss[j] - log_loss[i] for rows j = i + 1 ... a - 1 ] - # - log_loss_sub = log_loss - log_loss_expanded - - # reduce_sum: Sums each row across all the sub[i]'s; result is: - # reduce_sum[j] = (n - 1) * log_loss[j] - (sum_{i != j} loss[i]) - # divide by (n - 1) to get: - # geometric_reduction[j] = - # log_loss[j] - (sum_{i != j} log_loss[i]) / (n - 1) - geometric_reduction = math_ops.reduce_sum(log_loss_sub, [0]) / (n - 1) - - # subtract this from the original log_loss to get the baseline: - # geometric_mean[j] = exp((sum_{i != j} log_loss[i]) / (n - 1)) - log_geometric_mean = log_loss - geometric_reduction - - ## Equation (9) - - # Calculate sum_{i != j} loss[i] -- via exp(reduce_logsumexp(.)) - # reduce_logsumexp: log-sum-exp each row across all the - # -sub[i]'s, result is: - # - # exp(reduce_logsumexp[j]) = - # 1 + sum_{i != j} exp(log_loss[i] - log_loss[j]) - log_local_learning_reduction = math_ops.reduce_logsumexp( - -log_loss_sub, [0]) - - # convert local_learning_reduction to the sum-exp of the log-sum-exp - # (local_learning_reduction[j] - 1) * exp(log_loss[j]) - # = sum_{i != j} exp(log_loss[i]) - local_learning_log_sum = ( - _logexpm1(log_local_learning_reduction) + log_loss) - - # Add (logaddexp) the local learning signals (Eq. 9) - local_learning_signal = ( - math_ops.reduce_logsumexp( - array_ops.stack((local_learning_log_sum, log_geometric_mean)), - [0]) - - math_ops.log(n)) - - advantage = log_mean - local_learning_signal - - return advantage - - return vimco_advantage_fn - - -def _logexpm1(x): - """Stably calculate log(exp(x)-1).""" - with ops.name_scope("logsumexp1"): - eps = np.finfo(x.dtype.as_numpy_dtype).eps - # Choose a small offset that makes gradient calculations stable for - # float16, float32, and float64. - safe_log = lambda y: math_ops.log(y + eps / 1e8) # For gradient stability - return array_ops.where( - math_ops.abs(x) < eps, - safe_log(x) + x/2 + x*x/24, # small x approximation to log(expm1(x)) - safe_log(math_ops.exp(x) - 1)) - - -__all__ = make_all(__name__) diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_graph_impl.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_graph_impl.py deleted file mode 100644 index b2338bca8c94e0c7c44182f3f6bba7d7e79595e1..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/stochastic_graph_impl.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Classes and helper functions for Stochastic Computation Graphs. - -## Stochastic Computation Graph Helper Functions - -@@surrogate_loss -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import collections - -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor_impl -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import tf_logging as logging - - -def _upstream_stochastic_nodes(tensors): - """Map tensors to the stochastic tensors upstream of them. - - Args: - tensors: a list of Tensors. - - Returns: - A dict that maps the tensors passed in to the `StochasticTensor` objects - upstream of them. - """ - reverse_map = _stochastic_dependencies_map(tensors) - upstream = collections.defaultdict(set) - for st, ts in reverse_map.items(): - for t in ts: - upstream[t].add(st) - return upstream - - -def _stochastic_dependencies_map(fixed_losses, stochastic_tensors=None): - """Map stochastic tensors to the fixed losses that depend on them. - - Args: - fixed_losses: a list of `Tensor`s. - stochastic_tensors: a list of `StochasticTensor`s to map to fixed losses. - If `None`, all `StochasticTensor`s in the graph will be used. - - Returns: - A dict `dependencies` that maps `StochasticTensor` objects to subsets of - `fixed_losses`. - - If `loss in dependencies[st]`, for some `loss` in `fixed_losses` then there - is a direct path from `st.value()` to `loss` in the graph. - """ - stoch_value_collection = stochastic_tensors or ops.get_collection( - stochastic_tensor_impl.STOCHASTIC_TENSOR_COLLECTION) - - if not stoch_value_collection: - return {} - - stoch_value_map = dict( - (node.value(), node) for node in stoch_value_collection) - - # Step backwards through the graph to see which surrogate losses correspond - # to which fixed_losses. - # - # TODO(ebrevdo): Ensure that fixed_losses and stochastic values are in the - # same frame. - stoch_dependencies_map = collections.defaultdict(set) - for loss in fixed_losses: - boundary = set([loss]) - while boundary: - edge = boundary.pop() - edge_stoch_node = stoch_value_map.get(edge, None) - if edge_stoch_node: - stoch_dependencies_map[edge_stoch_node].add(loss) - boundary.update(edge.op.inputs) - - return stoch_dependencies_map - - -def surrogate_loss(sample_losses, - stochastic_tensors=None, - name="SurrogateLoss"): - """Surrogate loss for stochastic graphs. - - This function will call `loss_fn` on each `StochasticTensor` - upstream of `sample_losses`, passing the losses that it influenced. - - Note that currently `surrogate_loss` does not work with `StochasticTensor`s - instantiated in `while_loop`s or other control structures. - - Args: - sample_losses: a list or tuple of final losses. Each loss should be per - example in the batch (and possibly per sample); that is, it should have - dimensionality of 1 or greater. All losses should have the same shape. - stochastic_tensors: a list of `StochasticTensor`s to add loss terms for. - If None, defaults to all `StochasticTensor`s in the graph upstream of - the `Tensor`s in `sample_losses`. - name: the name with which to prepend created ops. - - Returns: - `Tensor` loss, which is the sum of `sample_losses` and the - `loss_fn`s returned by the `StochasticTensor`s. - - Raises: - TypeError: if `sample_losses` is not a list or tuple, or if its elements - are not `Tensor`s. - ValueError: if any loss in `sample_losses` does not have dimensionality 1 - or greater. - """ - with ops.name_scope(name, values=sample_losses): - if not isinstance(sample_losses, (list, tuple)): - raise TypeError("sample_losses must be a list or tuple") - for loss in sample_losses: - if not isinstance(loss, ops.Tensor): - raise TypeError("loss is not a Tensor: %s" % loss) - ndims = loss.get_shape().ndims - if not (ndims is not None and ndims >= 1): - raise ValueError("loss must have dimensionality 1 or greater: %s" % - loss) - - stoch_dependencies_map = _stochastic_dependencies_map( - sample_losses, stochastic_tensors=stochastic_tensors) - if not stoch_dependencies_map: - logging.warn( - "No collection of Stochastic Tensors found for current graph.") - return math_ops.add_n(sample_losses) - - # Iterate through all of the stochastic dependencies, adding - # surrogate terms where necessary. - sample_losses = [ops.convert_to_tensor(loss) for loss in sample_losses] - loss_terms = sample_losses - for (stoch_node, dependent_losses) in stoch_dependencies_map.items(): - dependent_losses = list(dependent_losses) - - logging.info("Losses influenced by StochasticTensor %s: [%s]", - stoch_node.name, ", ".join( - [loss.name for loss in dependent_losses])) - - # Sum up the downstream losses for this ST - influenced_loss = _add_n_or_sum(dependent_losses) - - # Compute surrogate loss term - loss_term = stoch_node.loss(array_ops.stop_gradient(influenced_loss)) - if loss_term is not None: - loss_terms.append(loss_term) - - return _add_n_or_sum(loss_terms) - - -def _add_n_or_sum(terms): - # add_n works for Tensors of the same dtype and shape - shape = terms[0].get_shape() - dtype = terms[0].dtype - - if all(term.get_shape().is_fully_defined() and - term.get_shape().is_compatible_with(shape) and term.dtype == dtype - for term in terms): - return math_ops.add_n(terms) - else: - return sum(terms) diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_tensor_impl.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_tensor_impl.py deleted file mode 100644 index ce5fdd98c69ca6b3482bfafa8859accdf8a78749..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/stochastic_tensor_impl.py +++ /dev/null @@ -1,477 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Classes and helper functions for creating Stochastic Tensors. - -`StochasticTensor` objects wrap `Distribution` objects. Their -values may be samples from the underlying distribution, or the distribution -mean (as governed by `value_type`). These objects provide a `loss` -method for use when sampling from a non-reparameterized distribution. -The `loss`method is used in conjunction with `stochastic_graph.surrogate_loss` -to produce a single differentiable loss in stochastic graphs having -both continuous and discrete stochastic nodes. - -## Stochastic Tensor Classes - -@@BaseStochasticTensor -@@StochasticTensor - -## Stochastic Tensor Value Types - -@@MeanValue -@@SampleValue - -@@value_type -@@get_current_value_type -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import abc -import collections -import contextlib -import threading - -import six - -from tensorflow.contrib.bayesflow.python.ops import stochastic_gradient_estimators as sge -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops.distributions import distribution - -STOCHASTIC_TENSOR_COLLECTION = "_stochastic_tensor_collection_" - - -@six.add_metaclass(abc.ABCMeta) -class BaseStochasticTensor(object): - """Base Class for Tensor-like objects that emit stochastic values.""" - - def __init__(self): - # Add self to this graph's Stochsatic Tensor collection for - # purposes of later performing correct surrogate loss calculation. - ops.add_to_collection(STOCHASTIC_TENSOR_COLLECTION, self) - - @abc.abstractproperty - def name(self): - pass - - @abc.abstractproperty - def dtype(self): - pass - - @abc.abstractproperty - def graph(self): - pass - - @abc.abstractmethod - def value(self, name=None): - pass - - @abc.abstractmethod - def loss(self, sample_loss): - """Returns the term to add to the surrogate loss. - - This method is called by `surrogate_loss`. The input `sample_loss` should - have already had `stop_gradient` applied to it. This is because the - surrogate_loss usually provides a Monte Carlo sample term of the form - `differentiable_surrogate * sample_loss` where `sample_loss` is considered - constant with respect to the input for purposes of the gradient. - - Args: - sample_loss: `Tensor`, sample loss downstream of this `StochasticTensor`. - - Returns: - Either `None` or a `Tensor`. - """ - raise NotImplementedError("surrogate_loss not implemented") - - @staticmethod - def _tensor_conversion_function(v, dtype=None, name=None, as_ref=False): - _ = name - if dtype and not dtype.is_compatible_with(v.dtype): - raise ValueError( - "Incompatible type conversion requested to type '%s' for variable " - "of type '%s'" % (dtype.name, v.dtype.name)) - if as_ref: - raise ValueError("%s: Ref type is not supported." % v) - return v.value() - - -# pylint: disable=protected-access -ops.register_tensor_conversion_function( - BaseStochasticTensor, BaseStochasticTensor._tensor_conversion_function) - -# pylint: enable=protected-access - - -class _StochasticValueType(object): - """Interface for the ValueType classes. - - This is the base class for MeanValue, SampleValue, and their descendants. - """ - - def pushed_above(self, unused_value_type): - pass - - def popped_above(self, unused_value_type): - pass - - def declare_inputs(self, unused_stochastic_tensor, unused_inputs_dict): - pass - - @abc.abstractproperty - def stop_gradient(self): - """Whether the value should be wrapped in stop_gradient. - - StochasticTensors must respect this property. - """ - pass - - -class MeanValue(_StochasticValueType): - - def __init__(self, stop_gradient=False): - self._stop_gradient = stop_gradient - - @property - def stop_gradient(self): - return self._stop_gradient - - -class SampleValue(_StochasticValueType): - """Draw samples, possibly adding new outer dimensions along the way. - - This ValueType draws samples from StochasticTensors run within its - context, increasing the rank according to the requested shape. - - Examples: - - ```python - mu = tf.zeros((2,3)) - sigma = tf.ones((2, 3)) - with sg.value_type(sg.SampleValue()): - st = sg.StochasticTensor( - tf.contrib.distributions.Normal, mu=mu, sigma=sigma) - # draws 1 sample and does not reshape - assertEqual(st.value().get_shape(), (2, 3)) - ``` - - ```python - mu = tf.zeros((2,3)) - sigma = tf.ones((2, 3)) - with sg.value_type(sg.SampleValue(4)): - st = sg.StochasticTensor( - tf.contrib.distributions.Normal, mu=mu, sigma=sigma) - # draws 4 samples each with shape (2, 3) and concatenates - assertEqual(st.value().get_shape(), (4, 2, 3)) - ``` - """ - - def __init__(self, shape=(), stop_gradient=False): - """Sample according to shape. - - For the given StochasticTensor `st` using this value type, - the shape of `st.value()` will match that of - `st.distribution.sample(shape)`. - - Args: - shape: A shape tuple or int32 tensor. The sample shape. - Default is a scalar: take one sample and do not change the size. - stop_gradient: If `True`, StochasticTensors' values are wrapped in - `stop_gradient`, to avoid backpropagation through. - """ - self._shape = shape - self._stop_gradient = stop_gradient - - @property - def shape(self): - return self._shape - - @property - def stop_gradient(self): - return self._stop_gradient - - -# Keeps track of how a StochasticTensor's value should be accessed. -# Used by value_type and get_current_value_type below. -_STOCHASTIC_VALUE_STACK = collections.defaultdict(list) - - -@contextlib.contextmanager -def value_type(dist_value_type): - """Creates a value type context for any StochasticTensor created within. - - Typical usage: - - ``` - with sg.value_type(sg.MeanValue(stop_gradients=True)): - st = sg.StochasticTensor(tf.contrib.distributions.Normal, mu=mu, - sigma=sigma) - ``` - - In the example above, `st.value()` (or equivalently, `tf.identity(st)`) will - be the mean value of the Normal distribution, i.e., `mu` (possibly - broadcasted to the shape of `sigma`). Furthermore, because the `MeanValue` - was marked with `stop_gradients=True`, this value will have been wrapped - in a `stop_gradients` call to disable any possible backpropagation. - - Args: - dist_value_type: An instance of `MeanValue`, `SampleValue`, or - any other stochastic value type. - - Yields: - A context for `StochasticTensor` objects that controls the - value created when they are initialized. - - Raises: - TypeError: if `dist_value_type` is not an instance of a stochastic value - type. - """ - if not isinstance(dist_value_type, _StochasticValueType): - raise TypeError("dist_value_type must be a Distribution Value Type") - thread_id = threading.current_thread().ident - stack = _STOCHASTIC_VALUE_STACK[thread_id] - if stack: - stack[-1].pushed_above(dist_value_type) - stack.append(dist_value_type) - yield - stack.pop() - if stack: - stack[-1].popped_above(dist_value_type) - - -class NoValueTypeSetError(ValueError): - pass - - -def get_current_value_type(): - thread_id = threading.current_thread().ident - if not _STOCHASTIC_VALUE_STACK[thread_id]: - raise NoValueTypeSetError( - "No value type currently set for this thread (%s). Did you forget to " - "wrap 'with stochastic_graph.value_type(...)'?" % thread_id) - return _STOCHASTIC_VALUE_STACK[thread_id][-1] - - -class StochasticTensor(BaseStochasticTensor): - """StochasticTensor is a BaseStochasticTensor backed by a distribution.""" - - def __init__(self, - dist, - name="StochasticTensor", - dist_value_type=None, - loss_fn=sge.score_function): - """Construct a `StochasticTensor`. - - `StochasticTensor` is backed by the `dist` distribution and its `value` - method will return the same value each time it is called. What `value` is - returned is controlled by the `dist_value_type` (defaults to - `SampleValue`). - - Some distributions' sample functions are not differentiable (e.g. a sample - from a discrete distribution like a Bernoulli) and so to differentiate - wrt parameters upstream of the sample requires a gradient estimator like - the score function estimator. This is accomplished by passing a - differentiable `loss_fn` to the `StochasticTensor`, which - defaults to a function whose derivative is the score function estimator. - Calling `stochastic_graph.surrogate_loss(final_losses)` will call - `loss()` on every `StochasticTensor` upstream of final losses. - - `loss()` will return None for `StochasticTensor`s backed by - reparameterized distributions; it will also return None if the value type is - `MeanValueType` or if `loss_fn=None`. - - Args: - dist: an instance of `Distribution`. - name: a name for this `StochasticTensor` and its ops. - dist_value_type: a `_StochasticValueType`, which will determine what the - `value` of this `StochasticTensor` will be. If not provided, the - value type set with the `value_type` context manager will be used. - loss_fn: callable that takes - `(st, st.value(), influenced_loss)`, where - `st` is this `StochasticTensor`, and returns a `Tensor` loss. By - default, `loss_fn` is the `score_function`, or more precisely, the - integral of the score function, such that when the gradient is taken, - the score function results. See the `stochastic_gradient_estimators` - module for additional loss functions and baselines. - - Raises: - TypeError: if `dist` is not an instance of `Distribution`. - TypeError: if `loss_fn` is not `callable`. - """ - if not isinstance(dist, distribution.Distribution): - raise TypeError("dist must be an instance of Distribution") - if dist_value_type is None: - try: - self._value_type = get_current_value_type() - except NoValueTypeSetError: - self._value_type = SampleValue() - else: - # We want to enforce a value type here, but use the value_type() - # context manager to enforce some error checking. - with value_type(dist_value_type): - self._value_type = get_current_value_type() - - if loss_fn is not None and not callable(loss_fn): - raise TypeError("loss_fn must be callable") - self._loss_fn = loss_fn - - with ops.name_scope(name) as scope: - self._name = scope - self._dist = dist - self._value = self._create_value() - - super(StochasticTensor, self).__init__() - - @property - def value_type(self): - return self._value_type - - @property - def distribution(self): - return self._dist - - def _create_value(self): - """Create the value Tensor based on the value type, store as self._value.""" - - if isinstance(self._value_type, MeanValue): - value_tensor = self._dist.mean() - elif isinstance(self._value_type, SampleValue): - value_tensor = self._dist.sample(self._value_type.shape) - else: - raise TypeError("Unrecognized Distribution Value Type: %s", - self._value_type) - - if self._value_type.stop_gradient: - # stop_gradient is being enforced by the value type - return array_ops.stop_gradient(value_tensor) - - if isinstance(self._value_type, MeanValue): - return value_tensor # Using pathwise-derivative for this one. - if self._dist.reparameterization_type == distribution.FULLY_REPARAMETERIZED: - return value_tensor # Using pathwise-derivative for this one. - else: - # Will have to perform some variant of score function - # estimation. Call stop_gradient on the sampler just in case we - # may accidentally leak some gradient from it. - return array_ops.stop_gradient(value_tensor) - - @property - def name(self): - return self._name - - @property - def graph(self): - return self._value.graph - - @property - def dtype(self): - return self._dist.dtype - - def entropy(self, name="entropy"): - return self._dist.entropy(name=name) - - def mean(self, name="mean"): - return self._dist.mean(name=name) - - def value(self, name="value"): - return self._value - - def loss(self, final_loss, name="Loss"): - # Return a loss based on final_loss and the distribution. Returns - # None if pathwise derivatives are supported, if the loss_fn - # was explicitly set to None, or if the value type is MeanValue. - if self._loss_fn is None: - return None - - if (self._dist.reparameterization_type == distribution.FULLY_REPARAMETERIZED - and not self._value_type.stop_gradient): - # Can perform pathwise-derivative on this one; no additional loss needed. - return None - - with ops.name_scope(self.name, values=[final_loss]): - with ops.name_scope(name): - if (self._value_type.stop_gradient or - isinstance(self._value_type, SampleValue)): - return self._loss_fn(self, self._value, final_loss) - elif isinstance(self._value_type, MeanValue): - return None # MeanValue generally provides its own gradient - else: - raise TypeError("Unrecognized Distribution Value Type: %s", - self._value_type) - - -class ObservedStochasticTensor(StochasticTensor): - """A StochasticTensor with an observed value.""" - - # pylint: disable=super-init-not-called - def __init__(self, dist, value, name=None): - """Construct an `ObservedStochasticTensor`. - - `ObservedStochasticTensor` is backed by distribution `dist` and uses the - provided value instead of using the current value type to draw a value from - the distribution. The provided value argument must be appropriately shaped - to have come from the distribution. - - Args: - dist: an instance of `Distribution`. - value: a Tensor containing the observed value - name: a name for this `ObservedStochasticTensor` and its ops. - - Raises: - TypeError: if `dist` is not an instance of `Distribution`. - ValueError: if `value` is not compatible with the distribution. - """ - if not isinstance(dist, distribution.Distribution): - raise TypeError("dist must be an instance of Distribution") - with ops.name_scope(name, "ObservedStochasticTensor", [value]) as scope: - self._name = scope - self._dist = dist - dist_shape = self._dist.batch_shape.concatenate( - self._dist.event_shape) - value = ops.convert_to_tensor(value) - value_shape = value.get_shape() - - if not value_shape.is_compatible_with(dist_shape): - if value_shape.ndims < dist_shape.ndims: - raise ValueError( - "Rank of observed value (%d) must be >= rank of a sample from the" - " distribution (%d)." % (value_shape.ndims, dist_shape.ndims)) - sample_shape = value_shape[(value_shape.ndims - dist_shape.ndims):] - if not sample_shape.is_compatible_with(dist_shape): - raise ValueError( - "Shape of observed value %s is incompatible with the shape of a " - "sample from the distribution %s." % (value_shape, dist_shape)) - if value.dtype != self._dist.dtype: - raise ValueError("Type of observed value (%s) does not match type of " - "distribution (%s)." % (value.dtype, self._dist.dtype)) - self._value = array_ops.identity(value) - # pylint: disable=non-parent-init-called - BaseStochasticTensor.__init__(self) - - def loss(self, final_loss, name=None): - return None - - -__all__ = [ - "BaseStochasticTensor", - "StochasticTensor", - "ObservedStochasticTensor", - "MeanValue", - "SampleValue", - "value_type", - "get_current_value_type", -] diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_variables.py b/tensorflow/contrib/bayesflow/python/ops/stochastic_variables.py deleted file mode 100644 index e16dbec11a188d42615c4e63d9f93925a6df30a3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/stochastic_variables.py +++ /dev/null @@ -1,151 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Custom `get_variable` for stochastic variables. - -@@get_stochastic_variable -@@make_stochastic_variable_getter -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import functools - -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor as st -from tensorflow.contrib.bayesflow.python.ops import variational_inference as vi - - -def get_stochastic_variable(getter, - name, - shape=None, - dist_cls=None, - dist_kwargs=None, - param_initializers=None, - prior=None, - **kwargs): - """Custom variable getter for stochastic variables. - - `get_stochastic_variable` will create variables backing the parameters of a - distribution, defined by `dist_cls`, and return a `StochasticTensor` which - represents a sample from the backing distribution. - - Meant to be passed as the `custom_getter` to a `variable_scope`. Use - `make_stochastic_variable_getter` to partially apply distribution-related - args. - - Usage: - - ```python - - sv = tf.contrib.bayesflow.stochastic_variables - dist = tf.contrib.distributions - - with tf.variable_scope('my_scope', - custom_getter=sv.make_stochastic_variable_getter( - dist_cls=dist.NormalWithSoftplusSigma - param_initializers={ - "sigma": lambda shape, dtype, pi: ( - tf.constant(0.5, dtype=dtype, shape=shape)) - })): - v = tf.get_variable('my_var', (10, 20)) - ``` - - `v` is a `StochasticTensor`, which is a sample from a backing - `NormalWithSoftplusSigma` distribution. Underneath, 2 variables have been - created: `my_var_mu` and `my_var_sigma`. `my_var_sigma` has been appropriately - constrained to be positive by the `NormalWithSoftplusSigma` constructor, and - initialized to a value of 0.5, which results in a sigma of ~1 after the - softplus. The sample will have shape `(10, 20)`. - - Args: - getter: original variable getter. - name: prefix for variable(s) backing distribution parameters. - shape: shape of the sample from the distribution (i.e. shape of the - returned `StochasticTensor`). - dist_cls: subclass of `Distribution` that implements `param_shapes`. Should - accept unconstrained parameters (e.g. `NormalWithSoftplusSigma` accepts - real-valued `sigma` and constrains it to be positive with `softplus`). - dist_kwargs: `dict` of kwargs to be forwarded to `dist_cls`. - param_initializers: `dict` from parameter name to initializer (see - `get_variable` for initializer docs). Will override `initializer` in - `kwargs`. `param_initializers` may contain initializers for only some of - the parameters. Those parameters that do not contain entries will be - initialized by `kwargs['initializer']`, if provided; otherwise, the - default initialization of `getter` will be used. - prior: instance of `Distribution` or a callable - `(TensorShape, dtype) => Distribution`. If provided, will be registered - as the prior for the `StochasticTensor` using - `variational_inference.register_prior`. - **kwargs: kwargs forwarded to `getter`. - - Returns: - `StochasticTensor`, which represents a sample from the backing distribution. - """ - param_initializers = param_initializers or {} - param_shapes = {} - - if shape is not None: - param_shapes = dist_cls.param_static_shapes(shape) - - param_names = set(list(param_shapes.keys()) + list(param_initializers.keys())) - params = {} - for param_name in param_names: - # For each parameter, its param_initializer is used, if provided. Otherwise, - # kwargs['initializer'] is used. If neither were provided, the default - # variable initialization in getter will be used (i.e. getter will be passed - # initializer=None. - original_initializer = kwargs.pop('initializer', None) - param_initializer = param_initializers.get(param_name, None) - if param_initializer is None: - param_initializer = original_initializer - - if callable(param_initializer) or param_initializer is None: - param_shape = param_shapes.get(param_name, None) - else: - param_shape = None - - params[param_name] = getter( - name + '_' + param_name, - shape=param_shape, - initializer=param_initializer, - **kwargs) - - dist_kwargs = dist_kwargs or {} - dist_kwargs.update(params) - sample = st.StochasticTensor(dist_cls(**dist_kwargs)) - - if prior is not None: - if callable(prior): - sample_value = sample.value() - sample_value.get_shape().assert_is_fully_defined() - prior = prior(sample_value.get_shape(), sample_value.dtype) - - vi.register_prior(sample, prior) - - return sample - - -def make_stochastic_variable_getter(dist_cls, - dist_kwargs=None, - param_initializers=None, - prior=None): - """`get_stochastic_variable` with args partially applied.""" - return functools.partial( - get_stochastic_variable, - dist_cls=dist_cls, - dist_kwargs=dist_kwargs, - param_initializers=param_initializers, - prior=prior) diff --git a/tensorflow/contrib/bayesflow/python/ops/variational_inference_impl.py b/tensorflow/contrib/bayesflow/python/ops/variational_inference_impl.py deleted file mode 100644 index 8d932a7c340e21da012d4ab93883735b13e01175..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/bayesflow/python/ops/variational_inference_impl.py +++ /dev/null @@ -1,327 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Variational inference. - -See the ${@python/contrib.bayesflow.variational_inference} guide. - -@@elbo -@@elbo_with_log_joint -@@ELBOForms -@@register_prior -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -from tensorflow.contrib.bayesflow.python.ops import stochastic_graph_impl as sg -from tensorflow.contrib.bayesflow.python.ops import stochastic_tensor_impl as st -from tensorflow.python.framework import ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops.distributions import distribution -from tensorflow.python.ops.distributions import kullback_leibler -from tensorflow.python.platform import tf_logging as logging - -VI_PRIORS = "__vi_priors__" - - -def register_prior(variational, prior): - """Associate a variational `StochasticTensor` with a `Distribution` prior. - - This is a helper function used in conjunction with `elbo` that allows users - to specify the mapping between variational distributions and their priors - without having to pass in `variational_with_prior` explicitly. - - Args: - variational: `StochasticTensor` q(Z). Approximating distribution. - prior: `Distribution` p(Z). Prior distribution. - - Returns: - None - - Raises: - ValueError: if variational is not a `StochasticTensor` or `prior` is not - a `Distribution`. - """ - if not isinstance(variational, st.StochasticTensor): - raise TypeError("variational must be a StochasticTensor") - if not isinstance(prior, distribution.Distribution): - raise TypeError("prior must be a Distribution") - ops.add_to_collection(VI_PRIORS, (variational, prior)) - - -class _ELBOForm(object): - pass - - -class ELBOForms(object): - """Constants to control the `elbo` calculation. - - `analytic_kl` uses the analytic KL divergence between the - variational distribution(s) and the prior(s). - - `analytic_entropy` uses the analytic entropy of the variational - distribution(s). - - `sample` uses the sample KL or the sample entropy is the joint is provided. - - See `elbo` for what is used with `default`. - """ - default, analytic_kl, analytic_entropy, sample = (_ELBOForm() - for _ in range(4)) - - @staticmethod - def check_form(form): - if form not in { - ELBOForms.default, ELBOForms.analytic_kl, ELBOForms.analytic_entropy, - ELBOForms.sample - }: - raise TypeError("form must be an ELBOForms constant") - - -def elbo(log_likelihood, - variational_with_prior=None, - keep_batch_dim=True, - form=None, - name="ELBO"): - r"""Evidence Lower BOund. `log p(x) >= ELBO`. - - Optimization objective for inference of hidden variables by variational - inference. - - This function is meant to be used in conjunction with `StochasticTensor`. - The user should build out the inference network, using `StochasticTensor`s - as latent variables, and the generative network. `elbo` at minimum needs - `p(x|Z)` and assumes that all `StochasticTensor`s upstream of `p(x|Z)` are - the variational distributions. Use `register_prior` to register `Distribution` - priors for each `StochasticTensor`. Alternatively, pass in - `variational_with_prior` specifying all variational distributions and their - priors. - - Mathematical details: - - ``` - log p(x) = log \int p(x, Z) dZ - = log \int \frac {q(Z)p(x, Z)}{q(Z)} dZ - = log E_q[\frac {p(x, Z)}{q(Z)}] - >= E_q[log \frac {p(x, Z)}{q(Z)}] = L[q; p, x] # ELBO - - L[q; p, x] = E_q[log p(x|Z)p(Z)] - E_q[log q(Z)] - = E_q[log p(x|Z)p(Z)] + H[q] (1) - = E_q[log p(x|Z)] - KL(q || p) (2) - - H - Entropy - KL - Kullback-Leibler divergence - ``` - - See section 2.2 of Stochastic Variational Inference by Hoffman et al. for - more, including the ELBO's equivalence to minimizing `KL(q(Z)||p(Z|x))` - in the fully Bayesian setting. https://arxiv.org/pdf/1206.7051.pdf. - - `form` specifies which form of the ELBO is used. `form=ELBOForms.default` - tries, in order of preference: analytic KL, analytic entropy, sampling. - - Multiple entries in the `variational_with_prior` dict implies a factorization. - e.g. `q(Z) = q(z1)q(z2)q(z3)`. - - Args: - log_likelihood: `Tensor` log p(x|Z). - variational_with_prior: dict from `StochasticTensor` q(Z) to - `Distribution` p(Z). If `None`, defaults to all `StochasticTensor` - objects upstream of `log_likelihood` with priors registered with - `register_prior`. - keep_batch_dim: bool. Whether to keep the batch dimension when summing - entropy/KL term. When the sample is per data point, this should be True; - otherwise (e.g. in a Bayesian NN), this should be False. - form: ELBOForms constant. Controls how the ELBO is computed. Defaults to - ELBOForms.default. - name: name to prefix ops with. - - Returns: - `Tensor` ELBO of the same type and shape as `log_likelihood`. - - Raises: - TypeError: if variationals in `variational_with_prior` are not - `StochasticTensor`s or if priors are not `Distribution`s. - TypeError: if form is not a valid ELBOForms constant. - ValueError: if `variational_with_prior` is None and there are no - `StochasticTensor`s upstream of `log_likelihood`. - ValueError: if any variational does not have a prior passed or registered. - """ - if form is None: - form = ELBOForms.default - with ops.name_scope(name): - model = ops.convert_to_tensor(log_likelihood) - variational_with_prior = _find_variational_and_priors( - model, variational_with_prior) - return _elbo(form, log_likelihood, None, variational_with_prior, - keep_batch_dim) - - -def elbo_with_log_joint(log_joint, - variational=None, - keep_batch_dim=True, - form=None, - name="ELBO"): - """Evidence Lower BOund. `log p(x) >= ELBO`. - - This method is for models that have computed `p(x,Z)` instead of `p(x|Z)`. - See `elbo` for further details. - - Because only the joint is specified, analytic KL is not available. - - Args: - log_joint: `Tensor` log p(x, Z). - variational: list of `StochasticTensor` q(Z). If `None`, defaults to all - `StochasticTensor` objects upstream of `log_joint`. - keep_batch_dim: bool. Whether to keep the batch dimension when summing - entropy term. When the sample is per data point, this should be True; - otherwise (e.g. in a Bayesian NN), this should be False. - form: ELBOForms constant. Controls how the ELBO is computed. Defaults to - ELBOForms.default. - name: name to prefix ops with. - - Returns: - `Tensor` ELBO of the same type and shape as `log_joint`. - - Raises: - TypeError: if variationals in `variational` are not `StochasticTensor`s. - TypeError: if form is not a valid ELBOForms constant. - ValueError: if `variational` is None and there are no `StochasticTensor`s - upstream of `log_joint`. - ValueError: if form is ELBOForms.analytic_kl. - """ - if form is None: - form = ELBOForms.default - if form == ELBOForms.analytic_kl: - raise ValueError("ELBOForms.analytic_kl is not available when using " - "elbo_with_log_joint. Use elbo or a different form.") - - with ops.name_scope(name): - model = ops.convert_to_tensor(log_joint) - - variational_with_prior = None - if variational is not None: - variational_with_prior = dict(zip(variational, [None] * len(variational))) - variational_with_prior = _find_variational_and_priors( - model, variational_with_prior, require_prior=False) - return _elbo(form, None, log_joint, variational_with_prior, keep_batch_dim) - - -def _elbo(form, log_likelihood, log_joint, variational_with_prior, - keep_batch_dim): - """Internal implementation of ELBO. Users should use `elbo`. - - Args: - form: ELBOForms constant. Controls how the ELBO is computed. - log_likelihood: `Tensor` log p(x|Z). - log_joint: `Tensor` log p(x, Z). - variational_with_prior: `dict`, varational - distributions to prior distributions. - keep_batch_dim: bool. Whether to keep the batch dimension when reducing - the entropy/KL. - - Returns: - ELBO `Tensor` with same shape and dtype as `log_likelihood`/`log_joint`. - """ - ELBOForms.check_form(form) - - # Order of preference - # 1. Analytic KL: log_likelihood - KL(q||p) - # 2. Analytic entropy: log_likelihood + log p(Z) + H[q], or log_joint + H[q] - # 3. Sample: log_likelihood - (log q(Z) - log p(Z)) = - # log_likelihood + log p(Z) - log q(Z), or log_joint - q(Z) - - def _reduce(val): - if keep_batch_dim: - return val - else: - return math_ops.reduce_sum(val) - - kl_terms = [] - entropy_terms = [] - prior_terms = [] - for q, z, p in [(qz.distribution, qz.value(), pz) - for qz, pz in variational_with_prior.items()]: - # Analytic KL - kl = None - if log_joint is None and form in {ELBOForms.default, ELBOForms.analytic_kl}: - try: - kl = kullback_leibler.kl_divergence(q, p) - logging.info("Using analytic KL between q:%s, p:%s", q, p) - except NotImplementedError as e: - if form == ELBOForms.analytic_kl: - raise e - if kl is not None: - kl_terms.append(-1. * _reduce(kl)) - continue - - # Analytic entropy - entropy = None - if form in {ELBOForms.default, ELBOForms.analytic_entropy}: - try: - entropy = q.entropy() - logging.info("Using analytic entropy for q:%s", q) - except NotImplementedError as e: - if form == ELBOForms.analytic_entropy: - raise e - if entropy is not None: - entropy_terms.append(_reduce(entropy)) - if log_likelihood is not None: - prior = p.log_prob(z) - prior_terms.append(_reduce(prior)) - continue - - # Sample - if form in {ELBOForms.default, ELBOForms.sample}: - entropy = -q.log_prob(z) - entropy_terms.append(_reduce(entropy)) - if log_likelihood is not None: - prior = p.log_prob(z) - prior_terms.append(_reduce(prior)) - - first_term = log_joint if log_joint is not None else log_likelihood - return sum([first_term] + kl_terms + entropy_terms + prior_terms) - - -def _find_variational_and_priors(model, - variational_with_prior, - require_prior=True): - """Find upstream StochasticTensors and match with registered priors.""" - if variational_with_prior is None: - # pylint: disable=protected-access - upstreams = sg._upstream_stochastic_nodes([model]) - # pylint: enable=protected-access - upstreams = list(upstreams[model]) - if not upstreams: - raise ValueError("No upstream stochastic nodes found for tensor: %s", - model) - prior_map = dict(ops.get_collection(VI_PRIORS)) - variational_with_prior = {} - for q in upstreams: - if require_prior and (q not in prior_map or prior_map[q] is None): - raise ValueError("No prior specified for StochasticTensor: %s", q) - variational_with_prior[q] = prior_map.get(q) - - if not all( - [isinstance(q, st.StochasticTensor) for q in variational_with_prior]): - raise TypeError("variationals must be StochasticTensors") - if not all([ - p is None or isinstance(p, distribution.Distribution) - for p in variational_with_prior.values() - ]): - raise TypeError("priors must be Distribution objects") - - return variational_with_prior diff --git a/tensorflow/contrib/boosted_trees/BUILD b/tensorflow/contrib/boosted_trees/BUILD index 30f12d02f2a3a5774b6d3ddf24b0ff9d145cf56f..66a04d42e93331de74b6f3d41f83f071115c1097 100644 --- a/tensorflow/contrib/boosted_trees/BUILD +++ b/tensorflow/contrib/boosted_trees/BUILD @@ -28,7 +28,6 @@ package_group(name = "friends") cc_library( name = "boosted_trees_kernels", deps = [ - ":ensemble_optimizer_ops_kernels", ":model_ops_kernels", ":prediction_ops_kernels", ":quantile_ops_kernels", @@ -42,7 +41,6 @@ cc_library( cc_library( name = "boosted_trees_ops_op_lib", deps = [ - ":ensemble_optimizer_ops_op_lib", ":model_ops_op_lib", ":prediction_ops_op_lib", ":quantile_ops_op_lib", @@ -70,6 +68,10 @@ py_library( srcs = ["python/utils/losses.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:nn", ], ) @@ -79,12 +81,13 @@ py_test( size = "small", srcs = ["python/utils/losses_test.py"], srcs_version = "PY2AND3", - tags = [ - "nomac", # b/63258195 - ], deps = [ ":losses", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", "//third_party/py/numpy", ], ) @@ -96,13 +99,30 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":gen_model_ops_py", "//tensorflow/contrib/boosted_trees:batch_ops_utils_py", "//tensorflow/contrib/boosted_trees:boosted_trees_ops_py", "//tensorflow/contrib/boosted_trees/lib:categorical_split_handler", "//tensorflow/contrib/boosted_trees/lib:ordinal_split_handler", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", + "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn", "//tensorflow/contrib/stateless", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:summary", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/feature_column", ], ) @@ -112,60 +132,39 @@ py_test( srcs = ["python/training/functions/gbdt_batch_test.py"], srcs_version = "PY2AND3", tags = [ - "nomac", # b/63258195 "notsan", # b/62863147 ], deps = [ ":gbdt_batch", ":losses", - "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", - "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py", - "//tensorflow/python:framework_test_lib", - "//third_party/py/numpy", - ], -) - -# Kernel tests - -py_test( - name = "ensemble_optimizer_ops_test", - size = "small", - srcs = ["python/kernel_tests/ensemble_optimizer_ops_test.py"], - srcs_version = "PY2AND3", - tags = [ - "nomac", # b/63258195 - ], - deps = [ - ":ensemble_optimizer_ops_py", ":model_ops_py", + "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/learn", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:resources", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:variables", - "//third_party/py/numpy", ], ) +# Kernel tests + py_test( name = "model_ops_test", size = "small", srcs = ["python/kernel_tests/model_ops_test.py"], srcs_version = "PY2AND3", - tags = [ - "nomac", # b/63258195 - ], deps = [ - ":ensemble_optimizer_ops_py", ":model_ops_py", ":prediction_ops_py", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", @@ -181,9 +180,6 @@ py_test( size = "small", srcs = ["python/kernel_tests/prediction_ops_test.py"], srcs_version = "PY2AND3", - tags = [ - "nomac", # b/63258195 - ], deps = [ ":model_ops_py", ":prediction_ops_py", @@ -201,12 +197,12 @@ py_test( size = "small", srcs = ["python/kernel_tests/quantile_ops_test.py"], srcs_version = "PY2AND3", - tags = [ - "nomac", # b/63258195 - ], deps = [ ":quantile_ops_py", "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", @@ -238,9 +234,6 @@ py_test( size = "small", srcs = ["python/kernel_tests/stats_accumulator_ops_test.py"], srcs_version = "PY2AND3", - tags = [ - "nomac", # b/63258195 - ], deps = [ ":stats_accumulator_ops_py", "//tensorflow/python:framework_ops", @@ -255,11 +248,7 @@ py_test( size = "small", srcs = ["python/kernel_tests/training_ops_test.py"], srcs_version = "PY2AND3", - tags = [ - "nomac", # b/63258195 - ], deps = [ - ":boosted_trees_ops_loader", ":model_ops_py", ":training_ops_py", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", @@ -294,9 +283,8 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/util:util_py", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:resources", + "//tensorflow/python:errors", + "//tensorflow/python:platform", ], ) @@ -304,7 +292,6 @@ py_library( name = "boosted_trees_ops_py", srcs_version = "PY2AND3", deps = [ - ":ensemble_optimizer_ops_py", ":model_ops_py", ":prediction_ops_py", ":quantile_ops_py", @@ -336,21 +323,17 @@ tf_custom_op_py_library( deps = [ ":boosted_trees_ops_loader", ":gen_model_ops_py", - "//tensorflow/contrib/util:util_py", - "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:resources", + "//tensorflow/python:training", ], ) tf_kernel_library( name = "model_ops_kernels", - srcs = [ - "kernels/model_ops.cc", - ], + srcs = ["kernels/model_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:utils", - "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource", "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", @@ -361,14 +344,12 @@ tf_kernel_library( tf_custom_op_library( name = "python/ops/_boosted_trees_ops.so", srcs = [ - "kernels/ensemble_optimizer_ops.cc", "kernels/model_ops.cc", "kernels/prediction_ops.cc", "kernels/quantile_ops.cc", "kernels/split_handler_ops.cc", "kernels/stats_accumulator_ops.cc", "kernels/training_ops.cc", - "ops/ensemble_optimizer_ops.cc", "ops/model_ops.cc", "ops/prediction_ops.cc", "ops/quantile_ops.cc", @@ -416,23 +397,17 @@ tf_custom_op_py_library( deps = [ ":boosted_trees_ops_loader", ":gen_split_handler_ops_py", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:framework_for_generated_wrappers", ], ) tf_kernel_library( name = "split_handler_ops_kernels", - srcs = [ - "kernels/split_handler_ops.cc", - ], + srcs = ["kernels/split_handler_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:feature-column-handlers", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", - "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", ], alwayslink = 1, ) @@ -464,25 +439,21 @@ tf_custom_op_py_library( deps = [ ":boosted_trees_ops_loader", ":gen_training_ops_py", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:framework_for_generated_wrappers", ], ) tf_kernel_library( name = "training_ops_kernels", - srcs = [ - "kernels/training_ops.cc", - ], + srcs = ["kernels/training_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:utils", + "//tensorflow/contrib/boosted_trees/lib:weighted_quantiles", "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", + "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_cc", - "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource", - "//tensorflow/core:framework", + "//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource", "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", ], alwayslink = 1, ) @@ -519,9 +490,7 @@ tf_custom_op_py_library( tf_kernel_library( name = "prediction_ops_kernels", - srcs = [ - "kernels/prediction_ops.cc", - ], + srcs = ["kernels/prediction_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:example_partitioner", "//tensorflow/contrib/boosted_trees/lib:models", @@ -529,7 +498,6 @@ tf_kernel_library( "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource", - "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", ], @@ -561,72 +529,22 @@ tf_custom_op_py_library( ":batch_ops_utils_py", ":boosted_trees_ops_loader", ":gen_quantile_ops_py_wrap", - "//tensorflow/contrib/util:util_py", - "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:resources", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", ], ) tf_kernel_library( name = "quantile_ops_kernels", - srcs = [ - "kernels/quantile_ops.cc", - ], + srcs = ["kernels/quantile_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:utils", "//tensorflow/contrib/boosted_trees/lib:weighted_quantiles", "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_cc", "//tensorflow/contrib/boosted_trees/resources:quantile_stream_resource", "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", - ], - alwayslink = 1, -) - -# Ensemble optimizer ops -tf_gen_op_libs( - op_lib_names = ["ensemble_optimizer_ops"], -) - -tf_gen_op_wrapper_py( - name = "gen_ensemble_optimizer_ops_py", - out = "python/ops/gen_ensemble_optimizer_ops.py", - deps = [ - ":ensemble_optimizer_ops_op_lib", - ], -) - -tf_custom_op_py_library( - name = "ensemble_optimizer_ops_py", - srcs = ["python/ops/ensemble_optimizer_ops.py"], - kernels = [ - ":ensemble_optimizer_ops_kernels", - ":ensemble_optimizer_ops_op_lib", - ], - srcs_version = "PY2AND3", - deps = [ - ":boosted_trees_ops_loader", - ":gen_ensemble_optimizer_ops_py", - "//tensorflow/contrib/util:util_py", - "//tensorflow/python:framework_for_generated_wrappers", - ], -) - -tf_kernel_library( - name = "ensemble_optimizer_ops_kernels", - srcs = [ - "kernels/ensemble_optimizer_ops.cc", - ], - deps = [ - "//tensorflow/contrib/boosted_trees/lib:utils", - "//tensorflow/contrib/boosted_trees/proto:learner_proto_cc", - "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_cc", - "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource", - "//tensorflow/core:framework", - "//tensorflow/core:framework_headers_lib", - "//third_party/eigen3", ], alwayslink = 1, ) @@ -656,8 +574,6 @@ tf_custom_op_py_library( ":batch_ops_utils_py", ":boosted_trees_ops_loader", ":gen_stats_accumulator_ops_py_wrap", - "//tensorflow/contrib/util:util_py", - "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:resources", "//tensorflow/python:training", @@ -666,13 +582,10 @@ tf_custom_op_py_library( tf_kernel_library( name = "stats_accumulator_ops_kernels", - srcs = [ - "kernels/stats_accumulator_ops.cc", - ], + srcs = ["kernels/stats_accumulator_ops.cc"], deps = [ "//tensorflow/contrib/boosted_trees/lib:utils", "//tensorflow/contrib/boosted_trees/resources:stamped_resource", - "//tensorflow/core:framework", "//tensorflow/core:framework_headers_lib", ], alwayslink = 1, @@ -684,7 +597,12 @@ py_library( name = "boosted_trees_pip", deps = [ ":init_py", + "//tensorflow/contrib/boosted_trees:gbdt_batch", + "//tensorflow/contrib/boosted_trees/estimator_batch:custom_export_strategy", "//tensorflow/contrib/boosted_trees/estimator_batch:init_py", + "//tensorflow/contrib/boosted_trees/estimator_batch:trainer_hooks", + "//tensorflow/contrib/boosted_trees/lib:categorical_split_handler", + "//tensorflow/contrib/boosted_trees/lib:ordinal_split_handler", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_py", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_py", diff --git a/tensorflow/contrib/boosted_trees/README.md b/tensorflow/contrib/boosted_trees/README.md new file mode 100644 index 0000000000000000000000000000000000000000..7d30032e539fb16e27f48ea101094fa4d3e9171d --- /dev/null +++ b/tensorflow/contrib/boosted_trees/README.md @@ -0,0 +1,11 @@ +# TF Boosted Trees (TFBT) + +TF Boosted trees is an implementation of a gradient boosting algorithm with +trees used as weak learners. + +## Examples +Folder "examples" demonstrates how TFBT estimators can be used for various +problems. Namely, it contains: +* binary_mnist.py - an example on how to use TFBT for binary classification. +* mnist.py - a multiclass example. +* boston.py - a regression example. \ No newline at end of file diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD index f9e186788f6832b292a690d8d7b04e2f4edd584e..7792c7127c0285dc2eb5b213da054674f6a81d64 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/BUILD +++ b/tensorflow/contrib/boosted_trees/estimator_batch/BUILD @@ -27,13 +27,6 @@ py_library( "__init__.py", ], srcs_version = "PY2AND3", - deps = [ - "custom_export_strategy", - ":custom_loss_head", - ":estimator", - ":model", - ":trainer_hooks", - ], ) py_library( @@ -41,7 +34,12 @@ py_library( srcs = ["model.py"], srcs_version = "PY2AND3", deps = [ + ":trainer_hooks", "//tensorflow/contrib/boosted_trees:gbdt_batch", + "//tensorflow/contrib/boosted_trees:model_ops_py", + "//tensorflow/python:framework_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", ], ) @@ -51,6 +49,10 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/learn", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", ], ) @@ -61,6 +63,15 @@ py_test( srcs_version = "PY2AND3", deps = [ ":trainer_hooks", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:session", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", ], ) @@ -69,6 +80,10 @@ py_library( srcs = ["custom_loss_head.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/learn", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", ], ) @@ -82,6 +97,11 @@ py_library( "//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_py", "//tensorflow/contrib/decision_trees/proto:generic_tree_model_py", "//tensorflow/contrib/learn", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:session", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/saved_model:tag_constants", ], ) @@ -92,8 +112,9 @@ py_test( srcs_version = "PY2AND3", deps = [ ":custom_export_strategy", - "//tensorflow/contrib/decision_trees/proto:generic_tree_model_extensions_py", - "//tensorflow/contrib/decision_trees/proto:generic_tree_model_py", + "//tensorflow/contrib/boosted_trees/proto:tree_config_proto_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", ], ) @@ -103,6 +124,8 @@ py_library( srcs_version = "PY2AND3", deps = [ ":model", - ":trainer_hooks", + "//tensorflow/contrib/boosted_trees:losses", + "//tensorflow/contrib/learn", + "//tensorflow/python:math_ops", ], ) diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py index 7773125c16772fe37369b11532c7f42df3ce166f..ef8dee91b6cc05c4c3dd5eb3c81de4fb65b473e3 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/custom_export_strategy.py @@ -96,7 +96,8 @@ def make_custom_export_strategy(name, def convert_to_universal_format(dtec, sorted_feature_names, num_dense, num_sparse_float, - num_sparse_int): + num_sparse_int, + feature_name_to_proto=None): """Convert GTFlow trees to universal format.""" del num_sparse_int # unused. model_and_features = generic_tree_model_pb2.ModelAndFeatures() @@ -104,7 +105,11 @@ def convert_to_universal_format(dtec, sorted_feature_names, # feature is processed before it's fed to the model (e.g. bucketing # information). As of now, this serves as a list of features the model uses. for feature_name in sorted_feature_names: - model_and_features.features[feature_name].SetInParent() + if not feature_name_to_proto: + model_and_features.features[feature_name].SetInParent() + else: + model_and_features.features[feature_name].CopyFrom( + feature_name_to_proto[feature_name]) model = model_and_features.model model.ensemble.summation_combination_technique.SetInParent() for tree_idx in range(len(dtec.trees)): @@ -144,6 +149,8 @@ def convert_to_universal_format(dtec, sorted_feature_names, split = gtflow_node.sparse_float_binary_split_default_left.split node.default_direction = ( generic_tree_model_pb2.BinaryNode.LEFT) + # TODO(nponomareva): adjust this id assignement when we allow multi- + # column sparse tensors. feature_id = split.feature_column + num_dense inequality_test = node.inequality_left_child_test inequality_test.feature_id.id.value = sorted_feature_names[feature_id] @@ -154,6 +161,8 @@ def convert_to_universal_format(dtec, sorted_feature_names, split = gtflow_node.sparse_float_binary_split_default_right.split node.default_direction = ( generic_tree_model_pb2.BinaryNode.RIGHT) + # TODO(nponomareva): adjust this id assignement when we allow multi- + # column sparse tensors. feature_id = split.feature_column + num_dense inequality_test = node.inequality_left_child_test inequality_test.feature_id.id.value = sorted_feature_names[feature_id] diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py index f8028acbdb0be44b7fd81b96b04b6e24d9060aa6..01752416b347dd0a5e646283b6b5572592df4690 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator.py @@ -19,8 +19,10 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.boosted_trees.estimator_batch import model +from tensorflow.contrib.boosted_trees.python.utils import losses from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import head as head_lib +from tensorflow.python.ops import math_ops class GradientBoostedDecisionTreeClassifier(estimator.Estimator): @@ -65,10 +67,21 @@ class GradientBoostedDecisionTreeClassifier(estimator.Estimator): Raises: ValueError: If learner_config is not valid. """ + if n_classes > 2: + # For multi-class classification, use our loss implementation that + # supports second order derivative. + def loss_fn(labels, logits, weights=None): + result = losses.per_example_maxent_loss( + labels=labels, logits=logits, weights=weights, + num_classes=n_classes) + return math_ops.reduce_mean(result[0]) + else: + loss_fn = None head = head_lib.multi_class_head( n_classes=n_classes, weight_column_name=weight_column_name, - enable_centered_bias=False) + enable_centered_bias=False, + loss_fn=loss_fn) if learner_config.num_classes == 0: learner_config.num_classes = n_classes elif learner_config.num_classes != n_classes: diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/model.py b/tensorflow/contrib/boosted_trees/estimator_batch/model.py index 8cda5c8f2b14f2ec3cfe3702e38b81803dd075f7..c6455a7ea3d18eb358edee034cee58b2bed21024 100644 --- a/tensorflow/contrib/boosted_trees/estimator_batch/model.py +++ b/tensorflow/contrib/boosted_trees/estimator_batch/model.py @@ -93,7 +93,7 @@ def model_builder(features, labels, mode, params, config): learner_config=learner_config, feature_columns=feature_columns, logits_dimension=head.logits_dimension, - features=features) + features=training_features) with ops.name_scope("gbdt", "gbdt_optimizer"): predictions_dict = gbdt_model.predict(mode) logits = predictions_dict["predictions"] diff --git a/tensorflow/contrib/boosted_trees/examples/binary_mnist.py b/tensorflow/contrib/boosted_trees/examples/binary_mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..47ee3d816f41e44f3a2458cf537d4f7dccf7b614 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/examples/binary_mnist.py @@ -0,0 +1,169 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Demonstrates multiclass MNIST TF Boosted trees example. + + This example demonstrates how to run experiments with TF Boosted Trees on + a binary dataset. We use digits 4 and 9 from the original MNIST dataset. + + Example Usage: + python tensorflow/contrib/boosted_trees/examples/binary_mnist.py \ + --output_dir="/tmp/binary_mnist" --depth=4 --learning_rate=0.3 \ + --batch_size=10761 --examples_per_layer=10761 --eval_batch_size=1030 \ + --num_eval_steps=1 --num_trees=10 --l2=1 --vmodule=training_ops=1 + + When training is done, accuracy on eval data is reported. Point tensorboard + to the directory for the run to see how the training progresses: + + tensorboard --logdir=/tmp/binary_mnist + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +import numpy as np +import tensorflow as tf +from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier +from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.learn import learn_runner + + +def get_input_fn(data, + batch_size, + capacity=10000, + min_after_dequeue=3000): + """Input function over MNIST data.""" + # Keep only 4 and 9 digits. + ids = np.where((data.labels == 4) | (data.labels == 9)) + images = data.images[ids] + labels = data.labels[ids] + # Make digit 4 label 1, 9 is 0. + labels = labels == 4 + + def _input_fn(): + """Prepare features and labels.""" + images_batch, labels_batch = tf.train.shuffle_batch( + tensors=[images, + labels.astype(np.int32)], + batch_size=batch_size, + capacity=capacity, + min_after_dequeue=min_after_dequeue, + enqueue_many=True, + num_threads=4) + features_map = {"images": images_batch} + return features_map, labels_batch + + return _input_fn + + +# Main config - creates a TF Boosted Trees Estimator based on flags. +def _get_tfbt(output_dir): + """Configures TF Boosted Trees estimator based on flags.""" + learner_config = learner_pb2.LearnerConfig() + + learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate + learner_config.regularization.l1 = 0.0 + learner_config.regularization.l2 = FLAGS.l2 / FLAGS.examples_per_layer + learner_config.constraints.max_tree_depth = FLAGS.depth + + growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER + learner_config.growing_mode = growing_mode + run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300) + + # Create a TF Boosted trees estimator that can take in custom loss. + estimator = GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + examples_per_layer=FLAGS.examples_per_layer, + model_dir=output_dir, + num_trees=FLAGS.num_trees, + center_bias=False, + config=run_config) + return estimator + + +def _make_experiment_fn(output_dir): + """Creates experiment for gradient boosted decision trees.""" + data = tf.contrib.learn.datasets.mnist.load_mnist() + train_input_fn = get_input_fn(data.train, FLAGS.batch_size) + eval_input_fn = get_input_fn(data.validation, FLAGS.eval_batch_size) + + return tf.contrib.learn.Experiment( + estimator=_get_tfbt(output_dir), + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + train_steps=None, + eval_steps=FLAGS.num_eval_steps, + eval_metrics=None) + + +def main(unused_argv): + learn_runner.run( + experiment_fn=_make_experiment_fn, + output_dir=FLAGS.output_dir, + schedule="train_and_evaluate") + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + parser = argparse.ArgumentParser() + # Define the list of flags that users can change. + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Choose the dir for the output.") + parser.add_argument( + "--batch_size", + type=int, + default=1000, + help="The batch size for reading data.") + parser.add_argument( + "--eval_batch_size", + type=int, + default=1000, + help="Size of the batch for eval.") + parser.add_argument( + "--num_eval_steps", + type=int, + default=1, + help="The number of steps to run evaluation for.") + # Flags for gradient boosted trees config. + parser.add_argument( + "--depth", type=int, default=4, help="Maximum depth of weak learners.") + parser.add_argument( + "--l2", type=float, default=1.0, help="l2 regularization per batch.") + parser.add_argument( + "--learning_rate", + type=float, + default=0.1, + help="Learning rate (shrinkage weight) with which each new tree is added." + ) + parser.add_argument( + "--examples_per_layer", + type=int, + default=1000, + help="Number of examples to accumulate stats for per layer.") + parser.add_argument( + "--num_trees", + type=int, + default=None, + required=True, + help="Number of trees to grow before stopping.") + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/boosted_trees/examples/boston.py b/tensorflow/contrib/boosted_trees/examples/boston.py new file mode 100644 index 0000000000000000000000000000000000000000..2c0a3c4912b82aba88e2f8f1b97a227c894ee2ae --- /dev/null +++ b/tensorflow/contrib/boosted_trees/examples/boston.py @@ -0,0 +1,153 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Demonstrates a regression on Boston housing data. + + This example demonstrates how to run experiments with TF Boosted Trees on + a regression dataset. We split all the data into 20% test and 80% train, + and are using l2 loss and l2 regularization. + + Example Usage: + + python tensorflow/contrib/boosted_trees/examples/boston.py \ + --batch_size=404 --output_dir="/tmp/boston" --depth=4 --learning_rate=0.1 \ + --num_eval_steps=1 --num_trees=500 --l2=4 \ + --vmodule=training_ops=1 + + When training is done, mean squared error on eval data is reported. + Point tensorboard to the directory for the run to see how the training + progresses: + + tensorboard --logdir=/tmp/boston + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys +import tensorflow as tf +from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeRegressor +from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.layers.python.layers import feature_column +from tensorflow.contrib.learn import learn_runner + +_BOSTON_NUM_FEATURES = 13 + + +# Main config - creates a TF Boosted Trees Estimator based on flags. +def _get_tfbt(output_dir, feature_cols): + """Configures TF Boosted Trees estimator based on flags.""" + learner_config = learner_pb2.LearnerConfig() + + learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate + learner_config.regularization.l1 = 0.0 + # Set the regularization per instance in such a way that + # regularization for the full training data is equal to l2 flag. + learner_config.regularization.l2 = FLAGS.l2 / FLAGS.batch_size + learner_config.constraints.max_tree_depth = FLAGS.depth + learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE + + run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300) + + # Create a TF Boosted trees regression estimator. + estimator = GradientBoostedDecisionTreeRegressor( + learner_config=learner_config, + # For the WHOLE_TREE strategy, set the examples_per_layer to be equal to + # batch size. + examples_per_layer=FLAGS.batch_size, + feature_columns=feature_cols, + label_dimension=1, + model_dir=output_dir, + num_trees=FLAGS.num_trees, + center_bias=False, + config=run_config) + return estimator + + +def _make_experiment_fn(output_dir): + """Creates experiment for gradient boosted decision trees.""" + (x_train, y_train), (x_test, + y_test) = tf.keras.datasets.boston_housing.load_data() + + train_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": x_train}, + y=y_train, + batch_size=FLAGS.batch_size, + num_epochs=None, + shuffle=True) + + eval_input_fn = tf.estimator.inputs.numpy_input_fn( + x={"x": x_test}, y=y_test, num_epochs=1, shuffle=False) + + feature_columns = [ + feature_column.real_valued_column("x", dimension=_BOSTON_NUM_FEATURES) + ] + + return tf.contrib.learn.Experiment( + estimator=_get_tfbt(output_dir, feature_columns), + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + train_steps=None, + eval_steps=FLAGS.num_eval_steps, + eval_metrics=None) + + +def main(unused_argv): + learn_runner.run( + experiment_fn=_make_experiment_fn, + output_dir=FLAGS.output_dir, + schedule="train_and_evaluate") + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + parser = argparse.ArgumentParser() + # Define the list of flags that users can change. + parser.add_argument( + "--batch_size", + type=int, + default=1000, + help="The batch size for reading data.") + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Choose the dir for the output.") + parser.add_argument( + "--num_eval_steps", + type=int, + default=1, + help="The number of steps to run evaluation for.") + # Flags for gradient boosted trees config. + parser.add_argument( + "--depth", type=int, default=4, help="Maximum depth of weak learners.") + parser.add_argument( + "--l2", type=float, default=1.0, help="l2 regularization per batch.") + parser.add_argument( + "--learning_rate", + type=float, + default=0.1, + help="Learning rate (shrinkage weight) with which each new tree is added." + ) + parser.add_argument( + "--num_trees", + type=int, + default=None, + required=True, + help="Number of trees to grow before stopping.") + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/boosted_trees/examples/mnist.py b/tensorflow/contrib/boosted_trees/examples/mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..817c6eb3e1a79b38746418db9e5015e65ee70a50 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/examples/mnist.py @@ -0,0 +1,171 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Demonstrates multiclass MNIST TF Boosted trees example. + + This example demonstrates how to run experiments with TF Boosted Trees on + a MNIST dataset. We are using layer by layer boosting with diagonal hessian + strategy for multiclass handling, and cross entropy loss. + + Example Usage: + python tensorflow/contrib/boosted_trees/examples/mnist.py \ + --output_dir="/tmp/mnist" --depth=4 --learning_rate=0.3 --batch_size=60000 \ + --examples_per_layer=60000 --eval_batch_size=10000 --num_eval_steps=1 \ + --num_trees=10 --l2=1 --vmodule=training_ops=1 + + When training is done, accuracy on eval data is reported. Point tensorboard + to the directory for the run to see how the training progresses: + + tensorboard --logdir=/tmp/mnist + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +import numpy as np +import tensorflow as tf +from tensorflow.contrib.boosted_trees.estimator_batch.estimator import GradientBoostedDecisionTreeClassifier +from tensorflow.contrib.boosted_trees.proto import learner_pb2 +from tensorflow.contrib.learn import learn_runner + + +def get_input_fn(dataset_split, + batch_size, + capacity=10000, + min_after_dequeue=3000): + """Input function over MNIST data.""" + + def _input_fn(): + """Prepare features and labels.""" + images_batch, labels_batch = tf.train.shuffle_batch( + tensors=[dataset_split.images, + dataset_split.labels.astype(np.int32)], + batch_size=batch_size, + capacity=capacity, + min_after_dequeue=min_after_dequeue, + enqueue_many=True, + num_threads=4) + features_map = {"images": images_batch} + return features_map, labels_batch + + return _input_fn + + +# Main config - creates a TF Boosted Trees Estimator based on flags. +def _get_tfbt(output_dir): + """Configures TF Boosted Trees estimator based on flags.""" + learner_config = learner_pb2.LearnerConfig() + + num_classes = 10 + + learner_config.learning_rate_tuner.fixed.learning_rate = FLAGS.learning_rate + learner_config.num_classes = num_classes + learner_config.regularization.l1 = 0.0 + learner_config.regularization.l2 = FLAGS.l2 / FLAGS.examples_per_layer + learner_config.constraints.max_tree_depth = FLAGS.depth + + growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER + learner_config.growing_mode = growing_mode + run_config = tf.contrib.learn.RunConfig(save_checkpoints_secs=300) + + learner_config.multi_class_strategy = ( + learner_pb2.LearnerConfig.DIAGONAL_HESSIAN) + + # Create a TF Boosted trees estimator that can take in custom loss. + estimator = GradientBoostedDecisionTreeClassifier( + learner_config=learner_config, + n_classes=num_classes, + examples_per_layer=FLAGS.examples_per_layer, + model_dir=output_dir, + num_trees=FLAGS.num_trees, + center_bias=False, + config=run_config) + return estimator + + +def _make_experiment_fn(output_dir): + """Creates experiment for gradient boosted decision trees.""" + data = tf.contrib.learn.datasets.mnist.load_mnist() + train_input_fn = get_input_fn(data.train, FLAGS.batch_size) + eval_input_fn = get_input_fn(data.validation, FLAGS.eval_batch_size) + + return tf.contrib.learn.Experiment( + estimator=_get_tfbt(output_dir), + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + train_steps=None, + eval_steps=FLAGS.num_eval_steps, + eval_metrics=None) + + +def main(unused_argv): + learn_runner.run( + experiment_fn=_make_experiment_fn, + output_dir=FLAGS.output_dir, + schedule="train_and_evaluate") + + +if __name__ == "__main__": + tf.logging.set_verbosity(tf.logging.INFO) + parser = argparse.ArgumentParser() + # Define the list of flags that users can change. + parser.add_argument( + "--output_dir", + type=str, + required=True, + help="Choose the dir for the output.") + parser.add_argument( + "--batch_size", + type=int, + default=1000, + help="The batch size for reading data.") + parser.add_argument( + "--eval_batch_size", + type=int, + default=1000, + help="Size of the batch for eval.") + parser.add_argument( + "--num_eval_steps", + type=int, + default=1, + help="The number of steps to run evaluation for.") + # Flags for gradient boosted trees config. + parser.add_argument( + "--depth", type=int, default=4, help="Maximum depth of weak learners.") + parser.add_argument( + "--l2", type=float, default=1.0, help="l2 regularization per batch.") + parser.add_argument( + "--learning_rate", + type=float, + default=0.1, + help="Learning rate (shrinkage weight) with which each new tree is added." + ) + parser.add_argument( + "--examples_per_layer", + type=int, + default=1000, + help="Number of examples to accumulate stats for per layer.") + parser.add_argument( + "--num_trees", + type=int, + default=None, + required=True, + help="Number of trees to grow before stopping.") + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/boosted_trees/kernels/ensemble_optimizer_ops.cc b/tensorflow/contrib/boosted_trees/kernels/ensemble_optimizer_ops.cc deleted file mode 100644 index 5cde22901050eadb346d67d49968af925b596bac..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/kernels/ensemble_optimizer_ops.cc +++ /dev/null @@ -1,243 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include -#include - -#include "tensorflow/contrib/boosted_trees/lib/utils/dropout_utils.h" -#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" -#include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/resource_mgr.h" -#include "tensorflow/core/framework/tensor.h" -#include "tensorflow/core/framework/tensor_shape.h" -#include "tensorflow/core/framework/types.h" -#include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/core/refcount.h" -#include "tensorflow/core/lib/core/status.h" -#include "tensorflow/core/platform/mutex.h" -#include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/platform/types.h" - -namespace tensorflow { - -using boosted_trees::models::DecisionTreeEnsembleResource; -using boosted_trees::trees::DecisionTreeEnsembleConfig; -using boosted_trees::utils::DropoutUtils; -using errors::InvalidArgument; - -namespace { - -// Learning rate epsilon. -const float kLearningRateEps = 1e-8; - -} // namespace - -class AddTreesToEnsembleOp : public OpKernel { - public: - explicit AddTreesToEnsembleOp(OpKernelConstruction* const context) - : OpKernel(context) { - // Ensure feature importance lhs inputs are references. - OP_REQUIRES( - context, - IsRefType(context->input_type(kFeatureColumnUsageCountsHandleIdx)), - errors::InvalidArgument( - "Feature usage counts lhs input needs to be a ref type")); - OP_REQUIRES(context, - IsRefType(context->input_type(kFeatureColumnGainsHandleIdx)), - errors::InvalidArgument( - "Feature gains lhs input needs to be a ref type")); - } - - void Compute(OpKernelContext* const context) override { - DecisionTreeEnsembleResource* decision_tree_ensemble_resource; - // Create a reference to the underlying resource using the handle. - OP_REQUIRES_OK( - context, LookupResource( - context, HandleFromInput(context, kTreeEnsembleHandleIdx), - &decision_tree_ensemble_resource)); - // Lock the resource since we're mutating it. - mutex_lock l(*decision_tree_ensemble_resource->get_mutex()); - // Remove the reference at the end of this scope. - core::ScopedUnref unref_me(decision_tree_ensemble_resource); - - // Read feature importance info. - mutex_lock fc_usage_counts_mutex_lock( - *context->input_ref_mutex(kFeatureColumnUsageCountsHandleIdx)); - mutex_lock fc_gains_mutex_lock( - *context->input_ref_mutex(kFeatureColumnGainsHandleIdx)); - Tensor fc_usage_counts_lhs_t = - context->mutable_input(kFeatureColumnUsageCountsHandleIdx, true); - OP_REQUIRES(context, - TensorShapeUtils::IsVector(fc_usage_counts_lhs_t.shape()), - InvalidArgument("Feature usage counts should be a vector.")); - OP_REQUIRES(context, fc_usage_counts_lhs_t.IsInitialized(), - errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", - requested_input(kFeatureColumnUsageCountsHandleIdx))); - - Tensor fc_gains_lhs_t = - context->mutable_input(kFeatureColumnGainsHandleIdx, true); - OP_REQUIRES(context, TensorShapeUtils::IsVector(fc_gains_lhs_t.shape()), - InvalidArgument("Feature gains should be a vector.")); - OP_REQUIRES(context, fc_gains_lhs_t.IsInitialized(), - errors::FailedPrecondition( - "Attempting to use uninitialized variables: ", - requested_input(kFeatureColumnGainsHandleIdx))); - - const Tensor fc_usage_counts_rhs_t = - context->input(kFeatureColumnUsageCountsToAddIdx); - OP_REQUIRES( - context, - fc_usage_counts_lhs_t.shape().IsSameSize(fc_usage_counts_rhs_t.shape()), - errors::InvalidArgument( - "Shapes of both feature usage counts tensors should match.", - " lhs shape= ", fc_usage_counts_lhs_t.shape().DebugString(), - " rhs shape= ", fc_usage_counts_rhs_t.shape().DebugString())); - - const Tensor fc_gains_rhs_t = context->input(kFeatureColumnGainsToAddIdx); - OP_REQUIRES(context, - fc_gains_lhs_t.shape().IsSameSize(fc_gains_rhs_t.shape()), - errors::InvalidArgument( - "Shapes of both feature gains tensors should match.", - " lhs shape= ", fc_gains_lhs_t.shape().DebugString(), - " rhs shape= ", fc_gains_rhs_t.shape().DebugString())); - - // Read in info about trees that were dropped. - Tensor dropped_trees_info_t = context->input(kDropedTreesInfoTensorIdx); - OP_REQUIRES(context, - TensorShapeUtils::IsMatrix(dropped_trees_info_t.shape()), - InvalidArgument("Dropped trees info should be matrix.")); - - const auto& dropout_info = dropped_trees_info_t.matrix(); - - // Parse the passed in tree ensemble. - Tensor tree_ensemble_config_t = context->input(kEnsembleToAddTensorIdx); - OP_REQUIRES( - context, TensorShapeUtils::IsScalar(tree_ensemble_config_t.shape()), - errors::InvalidArgument("Tree ensemble config must be a scalar.")); - // Arena increase spatial locality which reduces the average latency to - // access memory, as working set of pages will be fewer. - // arena has type proto2::Arena*. - auto* arena = - decision_tree_ensemble_resource->mutable_decision_tree_ensemble() - ->GetArena(); - DecisionTreeEnsembleConfig* ensemble_to_add = - protobuf::Arena::CreateMessage(arena); - OP_REQUIRES( - context, ParseProtoUnlimited(ensemble_to_add, - tree_ensemble_config_t.scalar()()), - errors::InvalidArgument("Unable to parse tree ensemble config.")); - - auto* mutable_ensemble = - decision_tree_ensemble_resource->mutable_decision_tree_ensemble(); - - // Read the learning_rate - Tensor learning_rate_t = context->input(kLearningRateTensorIdx); - OP_REQUIRES(context, TensorShapeUtils::IsScalar(learning_rate_t.shape()), - InvalidArgument("Learning rate should be a scalar.")); - - const float learning_rate = learning_rate_t.scalar()(); - if (learning_rate < kLearningRateEps) { - return; - } - // Prepare current weights vec. - std::vector current_weights; - current_weights.reserve(mutable_ensemble->tree_weights_size()); - for (const float weight : mutable_ensemble->tree_weights()) { - current_weights.push_back(weight); - } - const int32 num_dropped = dropped_trees_info_t.dim_size(1); - std::vector dropped_trees; - dropped_trees.reserve(num_dropped); - std::vector dropped_trees_original_weights; - dropped_trees_original_weights.reserve(num_dropped); - for (int i = 0; i < num_dropped; ++i) { - dropped_trees.push_back(dropout_info(0, i)); - dropped_trees_original_weights.push_back(dropout_info(1, i)); - } - - std::vector num_updates; - num_updates.reserve(mutable_ensemble->tree_metadata_size()); - - for (const auto& meta : mutable_ensemble->tree_metadata()) { - num_updates.push_back(meta.num_tree_weight_updates()); - } - - // If there was a dropout, come up with tree weights - const bool was_dropout = !dropped_trees.empty(); - if (was_dropout) { - // New tree/s will be added to the end of the ensemble's tree list. - const int32 new_tree_index = current_weights.size(); - DropoutUtils::GetTreesWeightsForAddingTrees( - dropped_trees, dropped_trees_original_weights, new_tree_index, - ensemble_to_add->trees_size(), ¤t_weights, &num_updates); - - // Update the weights of trees according to current weights; - for (int i = 0; i < mutable_ensemble->trees_size(); ++i) { - mutable_ensemble->set_tree_weights(i, current_weights[i]); - } - } - - // Add the trees from ensemble_to_add to the tree ensemble variable. - int i = mutable_ensemble->trees_size(); - for (auto& tree : *ensemble_to_add->mutable_trees()) { - (*mutable_ensemble->add_trees()).Swap(&tree); - - // New trees were updated only once. - auto* meta = mutable_ensemble->add_tree_metadata(); - meta->set_num_tree_weight_updates(1); - - // When we add complete trees to the ensemble in one step, each tree - // that's added is final. - meta->set_is_finalized(true); - - if (was_dropout) { - mutable_ensemble->add_tree_weights(current_weights[i++]); - } else { - mutable_ensemble->add_tree_weights(learning_rate); - } - } - - // Update the number of updates. - if (was_dropout) { - for (int i = 0; i < num_updates.size(); ++i) { - mutable_ensemble->mutable_tree_metadata(i)->set_num_tree_weight_updates( - num_updates[i]); - } - } - - // Update feature importance. - fc_usage_counts_lhs_t.vec() += fc_usage_counts_rhs_t.vec(); - fc_gains_lhs_t.vec() += learning_rate * fc_gains_rhs_t.vec(); - } - - private: - // Input tensor indices. - // Note that Op definition changes might cause input indices to need - // changing as well. - static const int kTreeEnsembleHandleIdx = 0; - static const int kEnsembleToAddTensorIdx = 1; - static const int kFeatureColumnUsageCountsHandleIdx = 2; - static const int kFeatureColumnUsageCountsToAddIdx = 3; - static const int kFeatureColumnGainsHandleIdx = 4; - static const int kFeatureColumnGainsToAddIdx = 5; - static const int kDropedTreesInfoTensorIdx = 6; - static const int kLearningRateTensorIdx = 7; -}; - -REGISTER_KERNEL_BUILDER(Name("AddTreesToEnsemble").Device(DEVICE_CPU), - AddTreesToEnsembleOp); - -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc index f4ad99f779e0d7fcf207934d77776548214371c1..4b5d5ba0de6c3995ee2da7a44ab0ba099cbf1b35 100644 --- a/tensorflow/contrib/boosted_trees/kernels/model_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/model_ops.cc @@ -15,7 +15,6 @@ #include #include "tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.h" -#include "tensorflow/contrib/boosted_trees/proto/tree_config.pb.h" #include "tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -46,9 +45,8 @@ class CreateTreeEnsembleVariableOp : public OpKernel { OP_REQUIRES_OK(context, context->input("tree_ensemble_config", &tree_ensemble_config_t)); auto* result = new boosted_trees::models::DecisionTreeEnsembleResource(); - result->set_stamp(stamp_token); - if (!ParseProtoUnlimited(result->mutable_decision_tree_ensemble(), - tree_ensemble_config_t->scalar()())) { + if (!result->InitFromSerialized(tree_ensemble_config_t->scalar()(), + stamp_token)) { result->Unref(); OP_REQUIRES(context, false, errors::InvalidArgument( "Unable to parse tree ensemble config.")); @@ -70,17 +68,15 @@ class TreeEnsembleStampTokenOp : public OpKernel { : OpKernel(context) {} void Compute(OpKernelContext* context) override { - boosted_trees::models::DecisionTreeEnsembleResource* - decision_tree_ensemble_resource; + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), - &decision_tree_ensemble_resource)); - tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex()); - core::ScopedUnref unref_me(decision_tree_ensemble_resource); + &ensemble_resource)); + tf_shared_lock l(*ensemble_resource->get_mutex()); + core::ScopedUnref unref_me(ensemble_resource); Tensor* output_stamp_token_t = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), &output_stamp_token_t)); - output_stamp_token_t->scalar()() = - decision_tree_ensemble_resource->stamp(); + output_stamp_token_t->scalar()() = ensemble_resource->stamp(); } }; @@ -91,23 +87,20 @@ class TreeEnsembleSerializeOp : public OpKernel { : OpKernel(context) {} void Compute(OpKernelContext* context) override { - boosted_trees::models::DecisionTreeEnsembleResource* - decision_tree_ensemble_resource; + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), - &decision_tree_ensemble_resource)); - tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex()); - core::ScopedUnref unref_me(decision_tree_ensemble_resource); + &ensemble_resource)); + tf_shared_lock l(*ensemble_resource->get_mutex()); + core::ScopedUnref unref_me(ensemble_resource); Tensor* output_stamp_token_t = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape(), &output_stamp_token_t)); - output_stamp_token_t->scalar()() = - decision_tree_ensemble_resource->stamp(); + output_stamp_token_t->scalar()() = ensemble_resource->stamp(); Tensor* output_config_t = nullptr; OP_REQUIRES_OK( context, context->allocate_output(1, TensorShape(), &output_config_t)); output_config_t->scalar()() = - decision_tree_ensemble_resource->decision_tree_ensemble() - .SerializeAsString(); + ensemble_resource->SerializeAsString(); } }; @@ -118,12 +111,11 @@ class TreeEnsembleDeserializeOp : public OpKernel { : OpKernel(context) {} void Compute(OpKernelContext* context) override { - boosted_trees::models::DecisionTreeEnsembleResource* - decision_tree_ensemble_resource; + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), - &decision_tree_ensemble_resource)); - mutex_lock l(*decision_tree_ensemble_resource->get_mutex()); - core::ScopedUnref unref_me(decision_tree_ensemble_resource); + &ensemble_resource)); + mutex_lock l(*ensemble_resource->get_mutex()); + core::ScopedUnref unref_me(ensemble_resource); // Get the stamp token. const Tensor* stamp_token_t; @@ -135,13 +127,11 @@ class TreeEnsembleDeserializeOp : public OpKernel { OP_REQUIRES_OK(context, context->input("tree_ensemble_config", &tree_ensemble_config_t)); // Deallocate all the previous objects on the resource. - decision_tree_ensemble_resource->Reset(); - decision_tree_ensemble_resource->set_stamp(stamp_token); - boosted_trees::trees::DecisionTreeEnsembleConfig* config = - decision_tree_ensemble_resource->mutable_decision_tree_ensemble(); + ensemble_resource->Reset(); OP_REQUIRES( context, - ParseProtoUnlimited(config, tree_ensemble_config_t->scalar()()), + ensemble_resource->InitFromSerialized( + tree_ensemble_config_t->scalar()(), stamp_token), errors::InvalidArgument("Unable to parse tree ensemble config.")); } }; diff --git a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc index 8ffd7f120b49b09a49fde2ac7319f56a3f03459a..766982b4f2023310e6046619939f83bef63b0302 100644 --- a/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc @@ -59,8 +59,27 @@ const char* kApplyDropoutAttributeName = "apply_dropout"; const char* kApplyAveragingAttributeName = "apply_averaging"; const char* kDropoutInfoOutputTensorName = "drop_out_tree_indices_weights"; const char* kPredictionsTensorName = "predictions"; -const char* kNoDropoutPredictionsTensorName = "no_dropout_predictions"; + +void CalculateTreesToInclude( + const boosted_trees::trees::DecisionTreeEnsembleConfig& config, + const std::vector& trees_to_drop, const int32 num_trees, + const bool only_finalized, std::vector* trees_to_include) { + trees_to_include->reserve(num_trees - trees_to_drop.size()); + + int32 index = 0; + // This assumes that trees_to_drop is a sorted list of tree ids. + for (int32 tree = 0; tree < num_trees; ++tree) { + if ((!trees_to_drop.empty() && index < trees_to_drop.size() && + trees_to_drop[index] == tree) || + (only_finalized && config.tree_metadata_size() > 0 && + !config.tree_metadata(tree).is_finalized())) { + ++index; + continue; + } + trees_to_include->push_back(tree); + } } +} // namespace class GradientTreesPredictionOp : public OpKernel { public: @@ -128,7 +147,7 @@ class GradientTreesPredictionOp : public OpKernel { break; } case AveragingConfig::CONFIG_NOT_SET: { - QCHECK(false) << "We should never get here."; + LOG(QFATAL) << "We should never get here."; break; } } @@ -136,24 +155,23 @@ class GradientTreesPredictionOp : public OpKernel { } void Compute(OpKernelContext* const context) override { - DecisionTreeEnsembleResource* decision_tree_ensemble_resource; + DecisionTreeEnsembleResource* ensemble_resource; // Gets the resource. Grabs the mutex but releases it. OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), - &decision_tree_ensemble_resource)); + &ensemble_resource)); // Release the reference to the resource once we're done using it. - core::ScopedUnref unref_me(decision_tree_ensemble_resource); + core::ScopedUnref unref_me(ensemble_resource); if (use_locking_) { - tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex()); - DoCompute(context, decision_tree_ensemble_resource); + tf_shared_lock l(*ensemble_resource->get_mutex()); + DoCompute(context, ensemble_resource); } else { - DoCompute(context, decision_tree_ensemble_resource); + DoCompute(context, ensemble_resource); } } private: - void DoCompute( - OpKernelContext* context, - DecisionTreeEnsembleResource* decision_tree_ensemble_resource) { + void DoCompute(OpKernelContext* context, + DecisionTreeEnsembleResource* ensemble_resource) { // Read dense float features list; OpInputList dense_float_features_list; OP_REQUIRES_OK(context, TensorUtils::ReadDenseFloatFeatures( @@ -205,41 +223,35 @@ class GradientTreesPredictionOp : public OpKernel { // Do dropout if needed. if (apply_dropout_ && has_dropout_) { - // Read in seed + // Read in seed and cast to uint64. const Tensor* seed_t; OP_REQUIRES_OK(context, context->input(kSeedTensorName, &seed_t)); OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_t->shape()), errors::InvalidArgument("Seed must be a scalar.")); - - // Cast seed to uint64. const uint64 seed = seed_t->scalar()(); - std::vector weights; - for (const float weight : - decision_tree_ensemble_resource->decision_tree_ensemble() - .tree_weights()) { - weights.push_back(weight); - } - std::unordered_set trees_not_to_drop; if (center_bias_) { trees_not_to_drop.insert(0); } - if (decision_tree_ensemble_resource->decision_tree_ensemble() - .has_growing_metadata()) { + if (ensemble_resource->decision_tree_ensemble().has_growing_metadata()) { // We are in batch mode, the last tree is the tree that is being built, // we can't drop it during dropout. - const int32 current_tree = - decision_tree_ensemble_resource->decision_tree_ensemble() - .trees_size() - - 1; - trees_not_to_drop.insert(current_tree); + trees_not_to_drop.insert(ensemble_resource->num_trees() - 1); } + const std::vector weights = ensemble_resource->GetTreeWeights(); OP_REQUIRES_OK(context, DropoutUtils::DropOutTrees( seed, dropout_config_, trees_not_to_drop, weights, &dropped_trees, &original_weights)); } + // Prepare the list of trees to include in the prediction. + std::vector trees_to_include; + CalculateTreesToInclude( + ensemble_resource->decision_tree_ensemble(), dropped_trees, + ensemble_resource->decision_tree_ensemble().trees_size(), + only_finalized_trees_, &trees_to_include); + // Allocate output predictions matrix. Tensor* output_predictions_t = nullptr; OP_REQUIRES_OK( @@ -248,22 +260,13 @@ class GradientTreesPredictionOp : public OpKernel { &output_predictions_t)); auto output_predictions = output_predictions_t->matrix(); - Tensor* output_no_dropout_predictions_t = nullptr; - OP_REQUIRES_OK( - context, context->allocate_output(kNoDropoutPredictionsTensorName, - {batch_size, prediction_vector_size_}, - &output_no_dropout_predictions_t)); - auto output_no_dropout_predictions = - output_no_dropout_predictions_t->matrix(); - // Run predictor. thread::ThreadPool* const worker_threads = context->device()->tensorflow_cpu_worker_threads()->workers; if (apply_averaging_) { DecisionTreeEnsembleConfig adjusted = - decision_tree_ensemble_resource->decision_tree_ensemble(); - + ensemble_resource->decision_tree_ensemble(); const int start_averaging = std::max( 0.0, averaging_config_.config_case() == @@ -271,21 +274,18 @@ class GradientTreesPredictionOp : public OpKernel { ? adjusted.trees_size() - averaging_config_.average_last_n_trees() : adjusted.trees_size() * (1.0 - averaging_config_.average_last_percent_trees())); - const int num_ensembles = adjusted.trees_size() - start_averaging; for (int i = start_averaging; i < adjusted.trees_size(); ++i) { float weight = adjusted.tree_weights(i); adjusted.mutable_tree_weights()->Set( i, weight * (num_ensembles - i + start_averaging) / num_ensembles); } - MultipleAdditiveTrees::Predict( - adjusted, only_finalized_trees_, dropped_trees, batch_features, - worker_threads, output_predictions, output_no_dropout_predictions); + MultipleAdditiveTrees::Predict(adjusted, trees_to_include, batch_features, + worker_threads, output_predictions); } else { MultipleAdditiveTrees::Predict( - decision_tree_ensemble_resource->decision_tree_ensemble(), - only_finalized_trees_, dropped_trees, batch_features, worker_threads, - output_predictions, output_no_dropout_predictions); + ensemble_resource->decision_tree_ensemble(), trees_to_include, + batch_features, worker_threads, output_predictions); } // Output dropped trees and original weights. @@ -327,37 +327,32 @@ class GradientTreesPartitionExamplesOp : public OpKernel { } void Compute(OpKernelContext* const context) override { - DecisionTreeEnsembleResource* decision_tree_ensemble_resource; + DecisionTreeEnsembleResource* ensemble_resource; // Gets the resource. Grabs the mutex but releases it. OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), - &decision_tree_ensemble_resource)); + &ensemble_resource)); // Release the reference to the resource once we're done using it. - core::ScopedUnref unref_me(decision_tree_ensemble_resource); + core::ScopedUnref unref_me(ensemble_resource); if (use_locking_) { - tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex()); - DoCompute(context, decision_tree_ensemble_resource); + tf_shared_lock l(*ensemble_resource->get_mutex()); + DoCompute(context, ensemble_resource); } else { - DoCompute(context, decision_tree_ensemble_resource); + DoCompute(context, ensemble_resource); } } private: - void DoCompute( - OpKernelContext* context, - DecisionTreeEnsembleResource* decision_tree_ensemble_resource) { + void DoCompute(OpKernelContext* context, + DecisionTreeEnsembleResource* ensemble_resource) { // The last non-finalized tree in the ensemble is by convention the // one to partition on. If no such tree exists, a nodeless tree is // created. - const auto& tree_ensemble = - decision_tree_ensemble_resource->decision_tree_ensemble(); - boosted_trees::trees::DecisionTreeConfig empy_tree_config; - const boosted_trees::trees::DecisionTreeConfig* tree_config = - &empy_tree_config; - auto num_trees = tree_ensemble.trees_size(); - if (num_trees > 0 && - !tree_ensemble.tree_metadata(num_trees - 1).is_finalized()) { - tree_config = &tree_ensemble.trees(num_trees - 1); - } + boosted_trees::trees::DecisionTreeConfig empty_tree_config; + const boosted_trees::trees::DecisionTreeConfig& tree_config = + (ensemble_resource->num_trees() <= 0 || + ensemble_resource->LastTreeMetadata()->is_finalized()) + ? empty_tree_config + : *ensemble_resource->LastTree(); // Read dense float features list; OpInputList dense_float_features_list; @@ -412,7 +407,7 @@ class GradientTreesPartitionExamplesOp : public OpKernel { thread::ThreadPool* const worker_threads = context->device()->tensorflow_cpu_worker_threads()->workers; learner::ExamplePartitioner::PartitionExamples( - *tree_config, batch_features, worker_threads->NumThreads(), + tree_config, batch_features, worker_threads->NumThreads(), worker_threads, partition_ids_t->vec().data()); } diff --git a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc index 3ccc36dff891d101733e66aadbe3e5744fd352f9..b08028eb635385357ba13b48d88157936978b6f1 100644 --- a/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc @@ -885,13 +885,16 @@ class BucketizeWithInputBoundariesOp : public OpKernel { VLOG(1) << "boundaries has shape: " << boundaries_tensor.shape().DebugString(); auto boundaries = boundaries_tensor.flat(); - boundaries_.clear(); + std::vector boundaries_vector; + boundaries_vector.reserve(boundaries.size()); for (size_t i = 0; i < boundaries.size(); i++) { - boundaries_.push_back(boundaries(i)); + boundaries_vector.push_back(boundaries(i)); VLOG(1) << "boundaries(" << i << ") : " << boundaries(i); } - OP_REQUIRES(context, std::is_sorted(boundaries_.begin(), boundaries_.end()), - errors::InvalidArgument("Expected sorted boundaries")); + OP_REQUIRES( + context, + std::is_sorted(boundaries_vector.begin(), boundaries_vector.end()), + errors::InvalidArgument("Expected sorted boundaries")); const Tensor& input_tensor = context->input(0); VLOG(1) << "Inputs has shape: " << input_tensor.shape().DebugString() @@ -904,21 +907,20 @@ class BucketizeWithInputBoundariesOp : public OpKernel { auto output = output_tensor->template flat(); for (size_t i = 0; i < input.size(); i++) { - output(i) = CalculateBucketIndex(input(i)); + output(i) = CalculateBucketIndex(input(i), boundaries_vector); } } private: - int32 CalculateBucketIndex(const T value) { - auto first_bigger_it = - std::upper_bound(boundaries_.begin(), boundaries_.end(), value); - int32 index = first_bigger_it - boundaries_.begin(); - CHECK(index >= 0 && index <= boundaries_.size()) + int32 CalculateBucketIndex(const T value, std::vector& boundaries_vector) { + auto first_bigger_it = std::upper_bound(boundaries_vector.begin(), + boundaries_vector.end(), value); + int32 index = first_bigger_it - boundaries_vector.begin(); + CHECK(index >= 0 && index <= boundaries_vector.size()) << "Invalid bucket index: " << index - << " boundaries_.size(): " << boundaries_.size(); + << " boundaries_vector.size(): " << boundaries_vector.size(); return index; } - std::vector boundaries_; }; #define REGISTER_KERNEL(T) \ diff --git a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc index d528757cf99c9c6dc0b3d75c765e47f5cbcff19c..2a5c7949f2d1f68eef1714c47446907038bd7216 100644 --- a/tensorflow/contrib/boosted_trees/kernels/training_ops.cc +++ b/tensorflow/contrib/boosted_trees/kernels/training_ops.cc @@ -24,14 +24,13 @@ using tensorflow::boosted_trees::learner::LearningRateDropoutDrivenConfig; namespace boosted_trees { -using boosted_trees::trees::DecisionTreeEnsembleConfig; +namespace { + +using boosted_trees::learner::LearningRateConfig; +using boosted_trees::trees::Leaf; using boosted_trees::trees::TreeNode; using boosted_trees::trees::TreeNodeMetadata; using boosted_trees::utils::DropoutUtils; -using boosted_trees::learner::LearningRateConfig; -using boosted_trees::trees::Leaf; - -namespace { // SplitCandidate holds the split candidate node along with the stats. struct SplitCandidate { @@ -187,12 +186,11 @@ class CenterTreeEnsembleBiasOp : public OpKernel { void Compute(OpKernelContext* const context) override { // Get decision tree ensemble. - boosted_trees::models::DecisionTreeEnsembleResource* - decision_tree_ensemble_resource; + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), - &decision_tree_ensemble_resource)); - core::ScopedUnref unref_me(decision_tree_ensemble_resource); - mutex_lock l(*decision_tree_ensemble_resource->get_mutex()); + &ensemble_resource)); + core::ScopedUnref unref_me(ensemble_resource); + mutex_lock l(*ensemble_resource->get_mutex()); // Get the stamp token. const Tensor* stamp_token_t; @@ -201,7 +199,7 @@ class CenterTreeEnsembleBiasOp : public OpKernel { // Only the Chief should run this Op and it is guaranteed to be in // a consistent state so the stamps must always match. - CHECK(decision_tree_ensemble_resource->is_stamp_valid(stamp_token)); + CHECK(ensemble_resource->is_stamp_valid(stamp_token)); // Get the next stamp token. const Tensor* next_stamp_token_t; @@ -210,28 +208,19 @@ class CenterTreeEnsembleBiasOp : public OpKernel { int64 next_stamp_token = next_stamp_token_t->scalar()(); CHECK(stamp_token != next_stamp_token); + // Update the ensemble stamp. + ensemble_resource->set_stamp(next_stamp_token); + // Get the delta updates. const Tensor* delta_updates_t; OP_REQUIRES_OK(context, context->input("delta_updates", &delta_updates_t)); - OP_REQUIRES( - context, - delta_updates_t->dim_size(0) + 1 == learner_config_.num_classes(), - errors::InvalidArgument( - "Delta updates size must be consistent with label dimensions.")); auto delta_updates = delta_updates_t->vec(); - - // Update the ensemble stamp. - decision_tree_ensemble_resource->set_stamp(next_stamp_token); + const int64 logits_dimension = delta_updates_t->dim_size(0); // Get the bias. - boosted_trees::trees::Leaf* bias = - RetrieveBias(decision_tree_ensemble_resource); + boosted_trees::trees::Leaf* const bias = + RetrieveBias(ensemble_resource, logits_dimension); CHECK(bias->has_vector()); - OP_REQUIRES( - context, - bias->vector().value_size() + 1 == learner_config_.num_classes(), - errors::InvalidArgument( - "Bias vector size must be consistent with label dimensions.")); // Update the bias. float total_delta = 0; @@ -259,37 +248,29 @@ class CenterTreeEnsembleBiasOp : public OpKernel { private: // Helper method to retrieve the bias from the tree ensemble. boosted_trees::trees::Leaf* RetrieveBias( - boosted_trees::models::DecisionTreeEnsembleResource* - decision_tree_ensemble_resource) { - boosted_trees::trees::DecisionTreeEnsembleConfig* ensemble_config = - decision_tree_ensemble_resource->mutable_decision_tree_ensemble(); - const auto num_trees = ensemble_config->trees_size(); - CHECK(num_trees == ensemble_config->tree_metadata_size() && - num_trees == ensemble_config->tree_weights_size()); + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource, + int64 logits_dimension) { + const int32 num_trees = ensemble_resource->num_trees(); if (num_trees <= 0) { - ensemble_config->mutable_growing_metadata()->set_num_trees_attempted(1); - ensemble_config->mutable_growing_metadata()->set_num_layers_attempted(1); // Add a new bias leaf. - boosted_trees::trees::DecisionTreeConfig* tree_config = - ensemble_config->add_trees(); - auto* leaf = tree_config->add_nodes()->mutable_leaf(); - for (size_t idx = 0; idx + 1 < learner_config_.num_classes(); ++idx) { - leaf->mutable_vector()->add_value(0); + ensemble_resource->IncrementAttempts(); + boosted_trees::trees::DecisionTreeConfig* const tree_config = + ensemble_resource->AddNewTree(1.0); + auto* const leaf = tree_config->add_nodes()->mutable_leaf(); + for (size_t idx = 0; idx < logits_dimension; ++idx) { + leaf->mutable_vector()->add_value(0.0); } - ensemble_config->add_tree_weights(1.0); - boosted_trees::trees::DecisionTreeMetadata* tree_metadata = - ensemble_config->add_tree_metadata(); - tree_metadata->set_num_layers_grown(1); - tree_metadata->set_is_finalized(true); + ensemble_resource->LastTreeMetadata()->set_is_finalized(true); return leaf; } else if (num_trees == 1) { - // Update the existing bias. - CHECK_EQ(ensemble_config->trees(0).nodes_size(), 1); - auto* node = ensemble_config->mutable_trees(0)->mutable_nodes(0); - CHECK(node->node_case() == TreeNode::kLeaf); - return node->mutable_leaf(); + // Confirms that the only tree is a bias and returns its leaf. + boosted_trees::trees::DecisionTreeConfig* const tree_config = + ensemble_resource->LastTree(); + CHECK_EQ(tree_config->nodes_size(), 1); + CHECK_EQ(tree_config->nodes(0).node_case(), TreeNode::kLeaf); + return tree_config->mutable_nodes(0)->mutable_leaf(); } else { - CHECK(false) << "Unable to center bias on an already grown ensemble"; + LOG(FATAL) << "Unable to center bias on an already grown ensemble"; } } @@ -331,12 +312,11 @@ class GrowTreeEnsembleOp : public OpKernel { void Compute(OpKernelContext* const context) override { // Get decision tree ensemble. - boosted_trees::models::DecisionTreeEnsembleResource* - decision_tree_ensemble_resource; + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), - &decision_tree_ensemble_resource)); - core::ScopedUnref unref_me(decision_tree_ensemble_resource); - mutex_lock l(*decision_tree_ensemble_resource->get_mutex()); + &ensemble_resource)); + core::ScopedUnref unref_me(ensemble_resource); + mutex_lock l(*ensemble_resource->get_mutex()); // Get the stamp token. const Tensor* stamp_token_t; @@ -345,7 +325,7 @@ class GrowTreeEnsembleOp : public OpKernel { // Only the Chief should run this Op and it is guaranteed to be in // a consistent state so the stamps must always match. - CHECK(decision_tree_ensemble_resource->is_stamp_valid(stamp_token)); + CHECK(ensemble_resource->is_stamp_valid(stamp_token)); // Get the next stamp token. const Tensor* next_stamp_token_t; @@ -356,7 +336,7 @@ class GrowTreeEnsembleOp : public OpKernel { // Update the ensemble stamp regardless of whether a layer // or tree is actually grown. - decision_tree_ensemble_resource->set_stamp(next_stamp_token); + ensemble_resource->set_stamp(next_stamp_token); // Read the learning_rate. const Tensor* learning_rate_t; @@ -378,16 +358,8 @@ class GrowTreeEnsembleOp : public OpKernel { OP_REQUIRES_OK(context, context->input_list("gains", &gains_list)); OP_REQUIRES_OK(context, context->input_list("splits", &splits_list)); - boosted_trees::trees::DecisionTreeEnsembleConfig* ensemble_config = - decision_tree_ensemble_resource->mutable_decision_tree_ensemble(); - ensemble_config->mutable_growing_metadata()->set_num_layers_attempted( - ensemble_config->growing_metadata().num_layers_attempted() + 1); - const int num_trees = ensemble_config->trees_size(); - if (num_trees <= 0 || - ensemble_config->tree_metadata(num_trees - 1).is_finalized()) { - ensemble_config->mutable_growing_metadata()->set_num_trees_attempted( - ensemble_config->growing_metadata().num_trees_attempted() + 1); - } + // Increment attempt stats. + ensemble_resource->IncrementAttempts(); // Find best splits for each active partition. std::map best_splits; @@ -400,14 +372,12 @@ class GrowTreeEnsembleOp : public OpKernel { return; } - // Update and retrieve the growable tree with its metadata. - boosted_trees::trees::DecisionTreeConfig* tree_config; - boosted_trees::trees::DecisionTreeMetadata* tree_metadata; - - // Updates the tree. If the tree is fully built and dropout was applied, it - // also adjusts the weights of dropped and the last tree. - std::tie(tree_config, tree_metadata) = UpdateAndRetrieveGrowableTree( - decision_tree_ensemble_resource, learning_rate, dropout_seed); + // Update and retrieve the growable tree. + // If the tree is fully built and dropout was applied, it also adjusts the + // weights of dropped and the last tree. + boosted_trees::trees::DecisionTreeConfig* const tree_config = + UpdateAndRetrieveGrowableTree(ensemble_resource, learning_rate, + dropout_seed); // Split tree nodes. for (auto& split_entry : best_splits) { @@ -417,16 +387,14 @@ class GrowTreeEnsembleOp : public OpKernel { // Post-prune finalized tree if needed. if (learner_config_.pruning_mode() == boosted_trees::learner::LearnerConfig::POST_PRUNE && - tree_metadata->is_finalized()) { + ensemble_resource->LastTreeMetadata()->is_finalized()) { VLOG(2) << "Post-pruning finalized tree."; PruneTree(tree_config); // If after post-pruning the whole tree has no gain, remove the tree // altogether from the ensemble. if (tree_config->nodes_size() <= 0) { - ensemble_config->mutable_trees()->RemoveLast(); - ensemble_config->mutable_tree_weights()->RemoveLast(); - ensemble_config->mutable_tree_metadata()->RemoveLast(); + ensemble_resource->RemoveLastTree(); } } } @@ -471,111 +439,88 @@ class GrowTreeEnsembleOp : public OpKernel { } void UpdateTreeWeightsIfDropout( - boosted_trees::trees::DecisionTreeEnsembleConfig* ensemble_config, - boosted_trees::trees::DecisionTreeMetadata* tree_metadata, + boosted_trees::models::DecisionTreeEnsembleResource* const + ensemble_resource, const uint64 dropout_seed) { // It is possible that the tree was built with dropout. If it is the case, - // we need to adjust the tree weight. - if (dropout_was_applied_ && tree_metadata->is_finalized()) { - const int32 num_trees = ensemble_config->trees_size(); - - std::vector dropped_trees; - // Since only chief builds the trees, we are sure that the other tree - // weights didn't change. - std::vector weights; - weights.reserve(num_trees); - std::vector num_updates; - num_updates.reserve(num_trees); - for (int i = 0; i < num_trees; ++i) { - weights.push_back(ensemble_config->tree_weights(i)); - num_updates.push_back( - ensemble_config->tree_metadata(i).num_tree_weight_updates()); - } + // we need to adjust the tree weight, or bail out. + if (!dropout_was_applied_ || + !ensemble_resource->LastTreeMetadata()->is_finalized()) { + return; + } + const int32 num_trees = ensemble_resource->num_trees(); - std::vector dropped_trees_weights; - // Based on seed, figure out what trees were dropped before. - std::unordered_set trees_not_to_drop; - if (center_bias_) { - trees_not_to_drop.insert(0); - } - // Last tree is the current tree that is built. - const int32 current_tree = num_trees - 1; - trees_not_to_drop.insert(current_tree); - - const auto dropout_status = DropoutUtils::DropOutTrees( - dropout_seed, dropout_config_, trees_not_to_drop, weights, - &dropped_trees, &dropped_trees_weights); - CHECK(dropout_status.ok()) - << "Can't figure out what trees were dropped out before, error is " - << dropout_status.error_message(); - - // Now we have dropped trees, update their weights and the current tree - // weight. - if (!dropped_trees.empty()) { - DropoutUtils::GetTreesWeightsForAddingTrees( - dropped_trees, dropped_trees_weights, current_tree, - 1 /* only 1 tree was added */, &weights, &num_updates); - - // Update the weights and num of updates for trees. - for (int i = 0; i < num_trees; ++i) { - ensemble_config->set_tree_weights(i, weights[i]); - ensemble_config->mutable_tree_metadata(i) - ->set_num_tree_weight_updates(num_updates[i]); - } + // Based on seed, figure out what trees were dropped before. + std::unordered_set trees_not_to_drop; + if (center_bias_) { + trees_not_to_drop.insert(0); + } + // Last tree is the current tree that is built. + const int32 current_tree = num_trees - 1; + trees_not_to_drop.insert(current_tree); + + // Since only chief builds the trees, we are sure that the other tree + // weights didn't change. + std::vector weights = ensemble_resource->GetTreeWeights(); + std::vector dropped_trees; + std::vector dropped_trees_weights; + const auto dropout_status = DropoutUtils::DropOutTrees( + dropout_seed, dropout_config_, trees_not_to_drop, weights, + &dropped_trees, &dropped_trees_weights); + CHECK(dropout_status.ok()) + << "Can't figure out what trees were dropped out before, error is " + << dropout_status.error_message(); + + // Now we have dropped trees, update their weights and the current tree + // weight. + if (!dropped_trees.empty()) { + std::vector increment_num_updates(num_trees, 0); + DropoutUtils::GetTreesWeightsForAddingTrees( + dropped_trees, dropped_trees_weights, current_tree, + 1 /* only 1 tree was added */, &weights, &increment_num_updates); + + // Update the weights and num of updates for trees. + for (int i = 0; i < num_trees; ++i) { + ensemble_resource->SetTreeWeight(i, weights[i], + increment_num_updates[i]); } } } - // Helper method to update and retrieve the growable tree which is by - // definition the last tree in the ensemble. - std::pair - UpdateAndRetrieveGrowableTree( - boosted_trees::models::DecisionTreeEnsembleResource* - decision_tree_ensemble_resource, - float learning_rate, const uint64 dropout_seed) { - boosted_trees::trees::DecisionTreeEnsembleConfig* ensemble_config = - decision_tree_ensemble_resource->mutable_decision_tree_ensemble(); - auto num_trees = ensemble_config->trees_size(); - CHECK(num_trees == ensemble_config->tree_metadata_size() && - num_trees == ensemble_config->tree_weights_size()); + // Helper method to update the growable tree which is by definition the last + // tree in the ensemble. + boosted_trees::trees::DecisionTreeConfig* UpdateAndRetrieveGrowableTree( + boosted_trees::models::DecisionTreeEnsembleResource* const + ensemble_resource, + const float learning_rate, const uint64 dropout_seed) { + const auto num_trees = ensemble_resource->num_trees(); if (num_trees <= 0 || - ensemble_config->tree_metadata(num_trees - 1).is_finalized()) { + ensemble_resource->LastTreeMetadata()->is_finalized()) { // Create a new tree with a no-op leaf. - boosted_trees::trees::DecisionTreeConfig* tree_config = - ensemble_config->add_trees(); - ++num_trees; - VLOG(1) << "Adding layer 0 to tree " << num_trees - 1 - << " of ensemble of " << num_trees << " trees."; + boosted_trees::trees::DecisionTreeConfig* const tree_config = + ensemble_resource->AddNewTree(learning_rate); + VLOG(1) << "Adding layer #0 to tree #" << num_trees << " of ensemble of " + << num_trees + 1 << " trees."; tree_config->add_nodes()->mutable_leaf(); - ensemble_config->add_tree_weights(learning_rate); - boosted_trees::trees::DecisionTreeMetadata* tree_metadata = - ensemble_config->add_tree_metadata(); - tree_metadata->set_num_layers_grown(1); + boosted_trees::trees::DecisionTreeMetadata* const tree_metadata = + ensemble_resource->LastTreeMetadata(); tree_metadata->set_is_finalized( learner_config_.constraints().max_tree_depth() <= 1); tree_metadata->set_num_tree_weight_updates(1); - - UpdateTreeWeightsIfDropout(ensemble_config, tree_metadata, dropout_seed); - return std::make_pair(tree_config, tree_metadata); } else { // The growable tree is by definition the last tree in the ensemble. - boosted_trees::trees::DecisionTreeMetadata* tree_metadata = - ensemble_config->mutable_tree_metadata(num_trees - 1); - auto num_layers_grown = tree_metadata->num_layers_grown(); - VLOG(1) << "Adding layer " << num_layers_grown << " to tree " + boosted_trees::trees::DecisionTreeMetadata* const tree_metadata = + ensemble_resource->LastTreeMetadata(); + const auto new_num_layers = tree_metadata->num_layers_grown() + 1; + VLOG(1) << "Adding layer #" << new_num_layers - 1 << " to tree #" << num_trees - 1 << " of ensemble of " << num_trees << " trees."; // Update growable tree metadata. - ++num_layers_grown; - tree_metadata->set_num_layers_grown(num_layers_grown); + tree_metadata->set_num_layers_grown(new_num_layers); tree_metadata->set_is_finalized( - num_layers_grown >= learner_config_.constraints().max_tree_depth()); - auto* tree_config = ensemble_config->mutable_trees(num_trees - 1); - - UpdateTreeWeightsIfDropout(ensemble_config, tree_metadata, dropout_seed); - - return std::make_pair(tree_config, tree_metadata); + new_num_layers >= learner_config_.constraints().max_tree_depth()); } + UpdateTreeWeightsIfDropout(ensemble_resource, dropout_seed); + return ensemble_resource->LastTree(); } // Helper method to merge leaf weights as the tree is being grown. @@ -763,12 +708,11 @@ class TreeEnsembleStatsOp : public OpKernel { void Compute(OpKernelContext* const context) override { // Get decision tree ensemble. - boosted_trees::models::DecisionTreeEnsembleResource* - decision_tree_ensemble_resource; + boosted_trees::models::DecisionTreeEnsembleResource* ensemble_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), - &decision_tree_ensemble_resource)); - core::ScopedUnref unref_me(decision_tree_ensemble_resource); - tf_shared_lock l(*decision_tree_ensemble_resource->get_mutex()); + &ensemble_resource)); + core::ScopedUnref unref_me(ensemble_resource); + tf_shared_lock l(*ensemble_resource->get_mutex()); // Get the stamp token. const Tensor* stamp_token_t; @@ -777,9 +721,9 @@ class TreeEnsembleStatsOp : public OpKernel { // Only the Chief should run this Op and it is guaranteed to be in // a consistent state so the stamps must always match. - CHECK(decision_tree_ensemble_resource->is_stamp_valid(stamp_token)); + CHECK(ensemble_resource->is_stamp_valid(stamp_token)); const boosted_trees::trees::DecisionTreeEnsembleConfig& ensemble_config = - decision_tree_ensemble_resource->decision_tree_ensemble(); + ensemble_resource->decision_tree_ensemble(); // Set tree stats. Tensor* num_trees_t = nullptr; @@ -794,13 +738,13 @@ class TreeEnsembleStatsOp : public OpKernel { context->allocate_output("attempted_trees", TensorShape({}), &attempted_tree_t)); - int num_trees = ensemble_config.trees_size(); + const int num_trees = ensemble_resource->num_trees(); active_tree_t->scalar()() = num_trees; - if (num_trees > 0 && - !ensemble_config.tree_metadata(num_trees - 1).is_finalized()) { - --num_trees; - } - num_trees_t->scalar()() = num_trees; + num_trees_t->scalar()() = + (num_trees <= 0 || + ensemble_resource->LastTreeMetadata()->is_finalized()) + ? num_trees + : num_trees - 1; attempted_tree_t->scalar()() = ensemble_config.growing_metadata().num_trees_attempted(); diff --git a/tensorflow/contrib/boosted_trees/lib/BUILD b/tensorflow/contrib/boosted_trees/lib/BUILD index d4d405c3a9a894e333fdf2278625d510cdeef1fe..107ff0d295bee530c1711a97849fbd3c6cdb2f00 100644 --- a/tensorflow/contrib/boosted_trees/lib/BUILD +++ b/tensorflow/contrib/boosted_trees/lib/BUILD @@ -81,6 +81,18 @@ tf_cc_test( ], ) +tf_cc_test( + name = "example_test", + size = "small", + srcs = ["utils/example_test.cc"], + deps = [ + ":utils", + "//tensorflow/core:tensor_testutil", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "batch_features_test", size = "small", @@ -132,7 +144,6 @@ tf_cc_test( ":random_tree_gen", "//tensorflow/contrib/boosted_trees/resources:decision_tree_ensemble_resource", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:lib", "//tensorflow/core:tensor_testutil", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -149,7 +160,6 @@ cc_library( deps = [ ":utils", "//tensorflow/core:framework_headers_lib", - "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:testlib", ], @@ -197,7 +207,6 @@ tf_cc_test( srcs = ["quantiles/weighted_quantiles_buffer_test.cc"], deps = [ ":weighted_quantiles", - "//tensorflow/core", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -210,7 +219,6 @@ tf_cc_test( srcs = ["quantiles/weighted_quantiles_summary_test.cc"], deps = [ ":weighted_quantiles", - "//tensorflow/core", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", @@ -262,6 +270,8 @@ py_library( srcs = ["learner/batch/base_split_handler.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/contrib/boosted_trees:batch_ops_utils_py", + "//tensorflow/python:control_flow_ops", ], ) @@ -271,9 +281,13 @@ py_library( srcs_version = "PY2AND3", deps = [ ":base_split_handler", - "//tensorflow/contrib/boosted_trees:quantile_ops_py", "//tensorflow/contrib/boosted_trees:split_handler_ops_py", "//tensorflow/contrib/boosted_trees:stats_accumulator_ops_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:math_ops", ], ) @@ -285,7 +299,15 @@ py_test( ":categorical_split_handler", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:resources", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_shape", ], ) @@ -298,7 +320,14 @@ py_library( "//tensorflow/contrib/boosted_trees:quantile_ops_py", "//tensorflow/contrib/boosted_trees:split_handler_ops_py", "//tensorflow/contrib/boosted_trees:stats_accumulator_ops_py", - "//tensorflow/contrib/boosted_trees/proto:quantiles_proto_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:function", + "//tensorflow/python:math_ops", + "//tensorflow/python:sparse_tensor", ], ) @@ -310,7 +339,15 @@ py_test( ":ordinal_split_handler", "//tensorflow/contrib/boosted_trees/proto:learner_proto_py", "//tensorflow/contrib/boosted_trees/proto:split_info_proto_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:resources", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:tensor_shape", ], ) diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc index 16bffd9beccfad352820c805e08bec71f3705f42..43b00d4c6dc2e0066810012292874314215c41be 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc @@ -21,82 +21,14 @@ namespace tensorflow { namespace boosted_trees { namespace models { -namespace { -void CalculateTreesToKeep( - const boosted_trees::trees::DecisionTreeEnsembleConfig& config, - const std::vector& trees_to_drop, const int32 num_trees, - const bool only_finalized, std::vector* trees_to_keep) { - trees_to_keep->reserve(num_trees - trees_to_drop.size()); - - int32 index = 0; - // This assumes that trees_to_drop is a sorted list of tree ids. - for (int32 tree = 0; tree < num_trees; ++tree) { - if ((!trees_to_drop.empty() && index < trees_to_drop.size() && - trees_to_drop[index] == tree) || - (only_finalized && config.tree_metadata_size() > 0 && - !config.tree_metadata(tree).is_finalized())) { - ++index; - continue; - } - trees_to_keep->push_back(tree); - } -} - -void UpdatePredictions( - const int32 index_1, const int32 index_2, const float value, - tensorflow::TTypes::Matrix* output_predictions, - tensorflow::TTypes::Matrix* additional_output_predictions) { - (*output_predictions)(index_1, index_2) += value; - - if (additional_output_predictions != nullptr) { - (*additional_output_predictions)(index_1, index_2) += value; - } -} - -void UpdatePredictionsBasedOnTree( - const boosted_trees::trees::DecisionTreeEnsembleConfig& config, - const int32 tree_idx, const boosted_trees::utils::Example& example, - tensorflow::TTypes::Matrix* output_predictions, - tensorflow::TTypes::Matrix* additional_output_predictions) { - const boosted_trees::trees::DecisionTreeConfig& tree = config.trees(tree_idx); - const float tree_weight = config.tree_weights(tree_idx); - const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example); - QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString(); - const auto& leaf_node = tree.nodes(leaf_idx); - QCHECK(leaf_node.has_leaf()) - << "Invalid leaf node: " << leaf_node.DebugString(); - if (leaf_node.leaf().has_sparse_vector()) { - const auto& leaf = leaf_node.leaf().sparse_vector(); - QCHECK_EQ(leaf.index_size(), leaf.value_size()); - for (size_t class_idx = 0; class_idx < leaf.index_size(); ++class_idx) { - const float value = tree_weight * leaf.value(class_idx); - - UpdatePredictions(example.example_idx, leaf.index(class_idx), value, - output_predictions, additional_output_predictions); - } - } else { - QCHECK(leaf_node.leaf().has_vector()) << "Unknown leaf type"; - const auto& leaf = leaf_node.leaf().vector(); - for (size_t i = 0; i < leaf.value_size(); ++i) { - const float value = tree_weight * leaf.value(i); - UpdatePredictions(example.example_idx, i, value, output_predictions, - additional_output_predictions); - } - } -} - -} // namespace - void MultipleAdditiveTrees::Predict( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, - const bool only_finalized_trees, const std::vector& trees_to_drop, + const std::vector& trees_to_include, const boosted_trees::utils::BatchFeatures& features, - tensorflow::thread::ThreadPool* worker_threads, - tensorflow::TTypes::Matrix output_predictions, - tensorflow::TTypes::Matrix no_dropout_predictions) { + tensorflow::thread::ThreadPool* const worker_threads, + tensorflow::TTypes::Matrix output_predictions) { // Zero out predictions as the model is additive. output_predictions.setZero(); - no_dropout_predictions.setZero(); // Get batch size. const int64 batch_size = features.batch_size(); @@ -104,27 +36,37 @@ void MultipleAdditiveTrees::Predict( return; } - // Prepare the list of trees to keep. - std::vector trees_to_keep; - CalculateTreesToKeep(config, trees_to_drop, config.trees_size(), - only_finalized_trees, &trees_to_keep); - // Lambda for doing a block of work. - auto update_predictions = [&config, &features, &trees_to_keep, &trees_to_drop, - &output_predictions, - &no_dropout_predictions](int64 start, int64 end) { + auto update_predictions = [&config, &features, &trees_to_include, + &output_predictions](int64 start, int64 end) { auto examples_iterable = features.examples_iterable(start, end); for (const auto& example : examples_iterable) { - for (const int32 tree_idx : trees_to_keep) { - UpdatePredictionsBasedOnTree(config, tree_idx, example, - &output_predictions, - &no_dropout_predictions); - } - - // Now do predictions for dropped trees - for (const int32 tree_idx : trees_to_drop) { - UpdatePredictionsBasedOnTree(config, tree_idx, example, - &no_dropout_predictions, nullptr); + for (const int32 tree_idx : trees_to_include) { + const boosted_trees::trees::DecisionTreeConfig& tree = + config.trees(tree_idx); + const float tree_weight = config.tree_weights(tree_idx); + const int leaf_idx = trees::DecisionTree::Traverse(tree, 0, example); + QCHECK(leaf_idx >= 0) << "Invalid tree: " << tree.DebugString(); + const auto& leaf_node = tree.nodes(leaf_idx); + QCHECK(leaf_node.has_leaf()) + << "Invalid leaf node: " << leaf_node.DebugString(); + if (leaf_node.leaf().has_sparse_vector()) { + const auto& leaf = leaf_node.leaf().sparse_vector(); + QCHECK_EQ(leaf.index_size(), leaf.value_size()); + for (size_t logit_dim = 0; logit_dim < leaf.index_size(); + ++logit_dim) { + const float value = tree_weight * leaf.value(logit_dim); + output_predictions(example.example_idx, leaf.index(logit_dim)) += + value; + } + } else { + QCHECK(leaf_node.leaf().has_vector()) << "Unknown leaf type"; + const auto& leaf = leaf_node.leaf().vector(); + for (size_t i = 0; i < leaf.value_size(); ++i) { + const float value = tree_weight * leaf.value(i); + output_predictions(example.example_idx, i) += value; + } + } } } }; diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h index fedade2026137ce43ff6b1cecd21f1e6c1461960..ee29a8aa797b96d41ec2d77bf831ee287d5443e7 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.h @@ -32,15 +32,13 @@ namespace models { class MultipleAdditiveTrees { public: // Predict runs tree ensemble on the given batch and updates - // output predictions accordingly. The method also returns predictions that - // we would get if no dropout was applied. + // output predictions accordingly, for the given list of trees. static void Predict( const boosted_trees::trees::DecisionTreeEnsembleConfig& config, - const bool only_finalized_trees, const std::vector& trees_to_drop, + const std::vector& trees_to_include, const boosted_trees::utils::BatchFeatures& features, - thread::ThreadPool* const thread_pool, - TTypes::Matrix output_predictions, - TTypes::Matrix no_dropout_predictions); + tensorflow::thread::ThreadPool* const worker_threads, + tensorflow::TTypes::Matrix output_predictions); }; } // namespace models diff --git a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc index 5f0924b48f2a57c5ba8af1e564e344e8ffa1b676..4ca18bedb1054ef64c6d4b25bbad04842bab1a6a 100644 --- a/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees_test.cc @@ -57,22 +57,14 @@ TEST_F(MultipleAdditiveTreesTest, Empty) { DecisionTreeEnsembleConfig tree_ensemble_config; auto output_tensor = AsTensor({9.0f, 23.0f}, {2, 1}); auto output_matrix = output_tensor.matrix(); - auto no_dropout_output_matrix = output_tensor.matrix(); // Predict for both instances. tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test", kNumThreadsSingleThreaded); - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {}, batch_features_, &threads, output_matrix, - no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, + &threads, output_matrix); EXPECT_EQ(0, output_matrix(0, 0)); EXPECT_EQ(0, output_matrix(1, 0)); - - // There was no dropout - for (int i = 0; i < 2; ++i) { - EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0)); - } } TEST_F(MultipleAdditiveTreesTest, SingleClass) { @@ -101,89 +93,48 @@ TEST_F(MultipleAdditiveTreesTest, SingleClass) { auto output_tensor = AsTensor({0.0f, 0.0f}, {2, 1}); auto output_matrix = output_tensor.matrix(); - auto no_dropout_output_tensor = AsTensor({0.0f, 0.0f}, {2, 1}); - auto no_dropout_output_matrix = no_dropout_output_tensor.matrix(); - tensorflow::thread::ThreadPool threads(tensorflow::Env::Default(), "test", kNumThreadsSingleThreaded); // Normal case. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {}, batch_features_, &threads, output_matrix, - no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, + batch_features_, &threads, output_matrix); EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). - - // No dropout predictions are the same. - for (int i = 0; i < 2; ++i) { - EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0)); - } } // Weighted case { DecisionTreeEnsembleConfig weighted = tree_ensemble_config; weighted.set_tree_weights(0, 6.0); weighted.set_tree_weights(1, 3.2); - MultipleAdditiveTrees::Predict(weighted, - false, // include non-finalized trees - {}, batch_features_, &threads, output_matrix, - no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads, + output_matrix); // -0.4 (bias) + 0.2 (leaf 2). EXPECT_FLOAT_EQ(-0.4f * 6 + 0.2 * 3.2, output_matrix(0, 0)); // -0.4 (bias) + 0.9 (leaf 1). EXPECT_FLOAT_EQ(-0.4f * 6 + 0.9 * 3.2, output_matrix(1, 0)); - - // No dropout predictions are the same. - for (int i = 0; i < 2; ++i) { - EXPECT_EQ(output_matrix(i, 0), no_dropout_output_matrix(i, 0)); - } } // Drop first tree. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {0}, batch_features_, &threads, - output_matrix, no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_, + &threads, output_matrix); EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 0)); // 0.2 (leaf 2). EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 1). - - // No dropout predictions - EXPECT_FLOAT_EQ( - -0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). - EXPECT_FLOAT_EQ( - 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). } // Drop second tree. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {1}, batch_features_, &threads, - output_matrix, no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_, + &threads, output_matrix); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias). EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias). - - // No dropout predictions - EXPECT_FLOAT_EQ( - -0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). - EXPECT_FLOAT_EQ( - 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). } // Drop all trees. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {0, 1}, batch_features_, &threads, - output_matrix, no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, + &threads, output_matrix); EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.0, output_matrix(1, 0)); - - // No dropout predictions - EXPECT_FLOAT_EQ( - -0.2f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) + 0.2 (leaf 2). - EXPECT_FLOAT_EQ( - 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1). } } @@ -218,37 +169,22 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { auto output_tensor = AsTensor({0.0f, 0.0f, 0.0f, 0.0f}, {2, 2}); auto output_matrix = output_tensor.matrix(); - auto no_dropout_output_tensor = - AsTensor({0.0f, 0.0f, 0.0f, 0.0f}, {2, 2}); - auto no_dropout_output_matrix = no_dropout_output_tensor.matrix(); - // Normal case. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {}, batch_features_, &threads, output_matrix, - no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, + batch_features_, &threads, output_matrix); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias) EXPECT_FLOAT_EQ(-0.5f, output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 1) EXPECT_FLOAT_EQ(-0.7f, output_matrix(1, 1)); // -0.7 (bias) - - // No dropout predictions are the same. - for (int i = 0; i < 2; ++i) { - for (int j = 0; j < 2; ++j) { - EXPECT_EQ(output_matrix(i, j), no_dropout_output_matrix(i, j)); - } - } } // Weighted case. { DecisionTreeEnsembleConfig weighted = tree_ensemble_config; weighted.set_tree_weights(0, 6.0); weighted.set_tree_weights(1, 3.2); - MultipleAdditiveTrees::Predict(weighted, - false, // include non-finalized trees - {}, batch_features_, &threads, output_matrix, - no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(weighted, {0, 1}, batch_features_, &threads, + output_matrix); // bias EXPECT_FLOAT_EQ(-0.4f * 6, output_matrix(0, 0)); // bias + leaf 2 @@ -260,60 +196,30 @@ TEST_F(MultipleAdditiveTreesTest, MultiClass) { } // Dropout first tree. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {0}, batch_features_, &threads, - output_matrix, no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {1}, batch_features_, + &threads, output_matrix); EXPECT_FLOAT_EQ(0.0, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.2f, output_matrix(0, 1)); // 0.2 (leaf 2) EXPECT_FLOAT_EQ(0.9f, output_matrix(1, 0)); // 0.9 (leaf 2) EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 1)); - - // No dropout predictions - EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) - EXPECT_FLOAT_EQ( - -0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) - EXPECT_FLOAT_EQ( - 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2) - EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias) } // Dropout second tree. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {1}, batch_features_, &threads, - output_matrix, no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {0}, batch_features_, + &threads, output_matrix); EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 0)); // -0.4 (bias) EXPECT_FLOAT_EQ(-0.7f, output_matrix(0, 1)); // -0.7 (bias) EXPECT_FLOAT_EQ(-0.4f, output_matrix(1, 0)); // -0.4 (bias) EXPECT_FLOAT_EQ(-0.7f, output_matrix(1, 1)); // -0.7 (bias) - - // No dropout predictions - EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) - EXPECT_FLOAT_EQ( - -0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) - EXPECT_FLOAT_EQ( - 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2) - EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias) } // Drop both trees. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {0, 1}, batch_features_, &threads, - output_matrix, no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {}, batch_features_, + &threads, output_matrix); EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 0)); EXPECT_FLOAT_EQ(0.0f, output_matrix(0, 1)); EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 0)); EXPECT_FLOAT_EQ(0.0f, output_matrix(1, 1)); - - // No dropout predictions - EXPECT_FLOAT_EQ(-0.4f, no_dropout_output_matrix(0, 0)); // -0.4 (bias) - EXPECT_FLOAT_EQ( - -0.5f, no_dropout_output_matrix(0, 1)); // -0.7 (bias) + 0.2 (leaf 2) - EXPECT_FLOAT_EQ( - 0.5f, no_dropout_output_matrix(1, 0)); // -0.4 (bias) + 0.9 (leaf 2) - EXPECT_FLOAT_EQ(-0.7f, no_dropout_output_matrix(1, 1)); // -0.7 (bias) } } @@ -349,29 +255,16 @@ TEST_F(MultipleAdditiveTreesTest, DenseLeaves) { AsTensor({0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, {2, 3}); auto output_matrix = output_tensor.matrix(); - auto no_dropout_output_tensor = - AsTensor({0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f}, {2, 3}); - auto no_dropout_output_matrix = no_dropout_output_tensor.matrix(); - // Normal case. { - MultipleAdditiveTrees::Predict(tree_ensemble_config, - false, // include non-finalized trees - {}, batch_features_, &threads, output_matrix, - no_dropout_output_matrix); + MultipleAdditiveTrees::Predict(tree_ensemble_config, {0, 1}, + batch_features_, &threads, output_matrix); EXPECT_FLOAT_EQ(-0.2f, output_matrix(0, 0)); // -0.4 (tree1) + 0.2 (leaf 2) EXPECT_FLOAT_EQ(-0.4f, output_matrix(0, 1)); // -0.7 (tree1) + 0.3 (leaf 2) EXPECT_FLOAT_EQ(3.4f, output_matrix(0, 2)); // 3.0 -(tree1) + 0.4 (leaf 2) EXPECT_FLOAT_EQ(0.5f, output_matrix(1, 0)); // -0.4 (tree1) + 0.9 (leaf 1) EXPECT_FLOAT_EQ(0.1f, output_matrix(1, 1)); // -0.7 (tree1) + 0.8 (leaf 1) EXPECT_FLOAT_EQ(3.7f, output_matrix(1, 2)); // 3.0 (tree1) + 0.7 (leaf 1) - - // No dropout predictions are the same. - for (int i = 0; i < 2; ++i) { - for (int j = 0; j < 3; ++j) { - EXPECT_EQ(output_matrix(i, j), no_dropout_output_matrix(i, j)); - } - } } } diff --git a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h index dad3b4e10deff7b8fb3a2a393e27a5d7099984a1..c329c6d4f7363a7738b06648943fe1dbd065cce5 100644 --- a/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h +++ b/tensorflow/contrib/boosted_trees/lib/quantiles/weighted_quantiles_summary.h @@ -36,7 +36,7 @@ class WeightedQuantilesSummary { struct SummaryEntry { SummaryEntry(const ValueType& v, const WeightType& w, const WeightType& min, const WeightType& max) { - // Explicitely initialize all of memory (including padding from memory + // Explicitly initialize all of memory (including padding from memory // alignment) to allow the struct to be msan-resistant "plain old data". // // POD = http://en.cppreference.com/w/cpp/concept/PODType diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc index 9968c9c3bf12778b234c75cb1f39e04dee14b52a..f8750e7191673274772fc869c198dd5fbbefbc49 100644 --- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc +++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc @@ -50,10 +50,15 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config, current_node.sparse_float_binary_split_default_left().split(); auto sparse_feature = example.sparse_float_features[split.feature_column()]; - node_id = !sparse_feature.has_value() || - sparse_feature.get_value() <= split.threshold() - ? split.left_id() - : split.right_id(); + // Feature id for the split when multivalent sparse float column, or 0 + // by default. + const int32 feature_id = split.feature_id(); + + node_id = + !sparse_feature[feature_id].has_value() || + sparse_feature[feature_id].get_value() <= split.threshold() + ? split.left_id() + : split.right_id(); break; } case TreeNode::kSparseFloatBinarySplitDefaultRight: { @@ -61,10 +66,14 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config, current_node.sparse_float_binary_split_default_right().split(); auto sparse_feature = example.sparse_float_features[split.feature_column()]; - node_id = sparse_feature.has_value() && - sparse_feature.get_value() <= split.threshold() - ? split.left_id() - : split.right_id(); + // Feature id for the split when multivalent sparse float column, or 0 + // by default. + const int32 feature_id = split.feature_id(); + node_id = + sparse_feature[feature_id].has_value() && + sparse_feature[feature_id].get_value() <= split.threshold() + ? split.left_id() + : split.right_id(); break; } case TreeNode::kCategoricalIdBinarySplit: { @@ -92,7 +101,7 @@ int DecisionTree::Traverse(const DecisionTreeConfig& config, break; } case TreeNode::NODE_NOT_SET: { - QCHECK(false) << "Invalid node in tree: " << current_node.DebugString(); + LOG(QFATAL) << "Invalid node in tree: " << current_node.DebugString(); break; } } @@ -157,7 +166,7 @@ void DecisionTree::LinkChildren(const std::vector& children, break; } case TreeNode::NODE_NOT_SET: { - QCHECK(false) << "A non-set node cannot have children."; + LOG(QFATAL) << "A non-set node cannot have children."; break; } } diff --git a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc index c55d09807eaf3a9c9db1cfbbfdfc66aec8f25155..93924d429c19aef51b6f1d85655de3798a76e3e0 100644 --- a/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/trees/decision_tree_test.cc @@ -27,13 +27,14 @@ class DecisionTreeTest : public ::testing::Test { protected: DecisionTreeTest() : batch_features_(2) { // Create a batch of two examples having one dense float, two sparse float - // and one sparse int features. + // and one sparse int features, and one sparse multi-column float feature + // (SparseFM). // The first example is missing the second sparse feature column and the // second example is missing the first sparse feature column. // This looks like the following: - // Instance | DenseF1 | SparseF1 | SparseF2 | SparseI1 | - // 0 | 7 | -3 | | 3 | - // 1 | -2 | | 4 | | + // Instance | DenseF1 | SparseF1 | SparseF2 | SparseI1 | SparseFM (3 cols) + // 0 | 7 | -3 | | 3 | 3.0 | | 1.0 + // 1 | -2 | | 4 | | 1.5 |3.5| auto dense_float_matrix = test::AsTensor({7.0f, -2.0f}, {2, 1}); auto sparse_float_indices1 = test::AsTensor({0, 0}, {1, 2}); auto sparse_float_values1 = test::AsTensor({-3.0f}); @@ -44,11 +45,21 @@ class DecisionTreeTest : public ::testing::Test { auto sparse_int_indices1 = test::AsTensor({0, 0}, {1, 2}); auto sparse_int_values1 = test::AsTensor({3}); auto sparse_int_shape1 = test::AsTensor({2, 1}); + + // Multivalent sparse feature. + auto multi_sparse_float_indices = + test::AsTensor({0, 0, 0, 2, 1, 0, 1, 1}, {4, 2}); + auto multi_sparse_float_values = + test::AsTensor({3.0f, 1.0f, 1.5f, 3.5f}); + auto multi_sparse_float_shape = test::AsTensor({2, 3}); + TF_EXPECT_OK(batch_features_.Initialize( - {dense_float_matrix}, {sparse_float_indices1, sparse_float_indices2}, - {sparse_float_values1, sparse_float_values2}, - {sparse_float_shape1, sparse_float_shape2}, {sparse_int_indices1}, - {sparse_int_values1}, {sparse_int_shape1})); + {dense_float_matrix}, + {sparse_float_indices1, sparse_float_indices2, + multi_sparse_float_indices}, + {sparse_float_values1, sparse_float_values2, multi_sparse_float_values}, + {sparse_float_shape1, sparse_float_shape2, multi_sparse_float_shape}, + {sparse_int_indices1}, {sparse_int_values1}, {sparse_int_shape1})); } template @@ -121,44 +132,90 @@ TEST_F(DecisionTreeTest, TraverseDenseBinarySplit) { } TEST_F(DecisionTreeTest, TraverseSparseBinarySplit) { - // Test first sparse feature which is missing for the second example. - DecisionTreeConfig tree_config1; - auto* split_node1 = tree_config1.add_nodes() - ->mutable_sparse_float_binary_split_default_left() - ->mutable_split(); - split_node1->set_feature_column(0); - split_node1->set_threshold(-20.0f); - split_node1->set_left_id(1); - split_node1->set_right_id(2); - tree_config1.add_nodes()->mutable_leaf(); - tree_config1.add_nodes()->mutable_leaf(); auto example_iterable = batch_features_.examples_iterable(0, 2); - - // Expect right child to be picked as !(-3 <= -20). - auto example_it = example_iterable.begin(); - EXPECT_EQ(2, DecisionTree::Traverse(tree_config1, 0, *example_it)); - - // Expect left child to be picked as default direction. - EXPECT_EQ(1, DecisionTree::Traverse(tree_config1, 0, *++example_it)); - + // Split on SparseF1. + // Test first sparse feature which is missing for the second example. + { + DecisionTreeConfig tree_config; + auto* split_node = tree_config.add_nodes() + ->mutable_sparse_float_binary_split_default_left() + ->mutable_split(); + split_node->set_feature_column(0); + split_node->set_threshold(-20.0f); + split_node->set_left_id(1); + split_node->set_right_id(2); + tree_config.add_nodes()->mutable_leaf(); + tree_config.add_nodes()->mutable_leaf(); + + // Expect right child to be picked as !(-3 <= -20). + auto example_it = example_iterable.begin(); + EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it)); + + // Expect left child to be picked as default direction. + EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); + } + // Split on SparseF2. // Test second sparse feature which is missing for the first example. - DecisionTreeConfig tree_config2; - auto* split_node2 = tree_config2.add_nodes() - ->mutable_sparse_float_binary_split_default_right() - ->mutable_split(); - split_node2->set_feature_column(1); - split_node2->set_threshold(4.0f); - split_node2->set_left_id(1); - split_node2->set_right_id(2); - tree_config2.add_nodes()->mutable_leaf(); - tree_config2.add_nodes()->mutable_leaf(); - - // Expect right child to be picked as default direction. - example_it = example_iterable.begin(); - EXPECT_EQ(2, DecisionTree::Traverse(tree_config2, 0, *example_it)); - - // Expect left child to be picked as (4 <= 4). - EXPECT_EQ(1, DecisionTree::Traverse(tree_config2, 0, *++example_it)); + { + DecisionTreeConfig tree_config; + auto* split_node = tree_config.add_nodes() + ->mutable_sparse_float_binary_split_default_right() + ->mutable_split(); + split_node->set_feature_column(1); + split_node->set_threshold(4.0f); + split_node->set_left_id(1); + split_node->set_right_id(2); + tree_config.add_nodes()->mutable_leaf(); + tree_config.add_nodes()->mutable_leaf(); + + // Expect right child to be picked as default direction. + auto example_it = example_iterable.begin(); + EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it)); + + // Expect left child to be picked as (4 <= 4). + EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); + } + // Split on SparseFM. + // Test second sparse feature which is missing for the first example. + { + DecisionTreeConfig tree_config; + auto* split_node = tree_config.add_nodes() + ->mutable_sparse_float_binary_split_default_right() + ->mutable_split(); + split_node->set_feature_column(2); + + split_node->set_left_id(1); + split_node->set_right_id(2); + tree_config.add_nodes()->mutable_leaf(); + tree_config.add_nodes()->mutable_leaf(); + + // Split on first column + split_node->set_feature_id(0); + split_node->set_threshold(2.0f); + + // Both instances have this feature value. + auto example_it = example_iterable.begin(); + EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it)); + EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); + + // Split on second column + split_node->set_feature_id(1); + split_node->set_threshold(5.0f); + + // First instance does not have it (default right), second does have it. + example_it = example_iterable.begin(); + EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *example_it)); + EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *++example_it)); + + // Split on third column + split_node->set_feature_id(2); + split_node->set_threshold(3.0f); + example_it = example_iterable.begin(); + + // First instance has it, second does not (default right). + EXPECT_EQ(1, DecisionTree::Traverse(tree_config, 0, *example_it)); + EXPECT_EQ(2, DecisionTree::Traverse(tree_config, 0, *++example_it)); + } } TEST_F(DecisionTreeTest, TraverseCategoricalIdBinarySplit) { diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc index 12b377dda7852bb5a580c4ccc1d239709ef9bfc0..cf4f9a097a3368465fd4d9afb981bbaa68b4df49 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.cc @@ -94,10 +94,6 @@ Status BatchFeatures::Initialize( shape_flat(0) == batch_size_, errors::InvalidArgument( "Sparse float feature shape incompatible with batch size.")); - TF_CHECK_AND_RETURN_IF_ERROR( - shape_flat(1) <= 1, - errors::InvalidArgument( - "Sparse float features may not be multi-valent.")); auto tensor_shape = TensorShape({shape_flat(0), shape_flat(1)}); auto order_dims = sparse::SparseTensor::VarDimArray({0, 1}); sparse_float_feature_columns_.emplace_back(sparse_float_feature_indices, diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h index bb11dc9a0778c062c68433c001e7935388e0f45c..7a550d6f7328765d8815a947885e47fa0b0a8f8b 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features.h @@ -45,6 +45,22 @@ class BatchFeatures { std::vector sparse_int_feature_values_list, std::vector sparse_int_feature_shapes_list); + Status GetFeatureColumnSizes(int64* const num_dense_float_features, + int64* const num_sparse_float_features, + int64* const num_sparse_int_features) const { + QCHECK_NE(num_dense_float_features, nullptr); + QCHECK_NE(num_sparse_float_features, nullptr); + QCHECK_NE(num_sparse_int_features, nullptr); + *num_dense_float_features = dense_float_feature_columns_.size(); + *num_sparse_float_features = sparse_float_feature_columns_.size(); + *num_sparse_int_features = sparse_int_feature_columns_.size(); + if (*num_dense_float_features == 0 && *num_sparse_float_features == 0 && + *num_sparse_int_features == 0) { + return errors::FailedPrecondition("Not intialized yet."); + } + return Status::OK(); + } + // Creates an example iterable for the requested slice. ExamplesIterable examples_iterable(int64 example_start, int64 example_end) const { diff --git a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc index 7f523d527adeb60d179bfce4bc5ef32e75e34ca2..9de3e32b097a151b3bd6f5c30df2db0938b65e9c 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/batch_features_test.cc @@ -129,19 +129,6 @@ TEST_F(BatchFeaturesTest, SparseFloatFeatures_IncompatibleShape) { {sparse_float_feature_shape}, {}, {}, {})); } -TEST_F(BatchFeaturesTest, SparseFloatFeatures_Multivalent) { - BatchFeatures batch_features(2); - auto sparse_float_feature_indices = AsTensor({0, 0, 1, 0}, {2, 2}); - auto sparse_float_feature_values = AsTensor({3.0f, 7.0f}); - auto sparse_float_feature_shape = AsTensor({2, 2}); - auto expected_error = - InvalidArgument("Sparse float features may not be multi-valent."); - EXPECT_EQ(expected_error, batch_features.Initialize( - {}, {sparse_float_feature_indices}, - {sparse_float_feature_values}, - {sparse_float_feature_shape}, {}, {}, {})); -} - TEST_F(BatchFeaturesTest, SparseIntFeatures_WrongShapeIndices) { BatchFeatures batch_features(2); auto sparse_int_feature_indices = AsTensor({0, 0, 1, 0}); diff --git a/tensorflow/contrib/boosted_trees/lib/utils/example.h b/tensorflow/contrib/boosted_trees/lib/utils/example.h index 4681eb06aa2c11a33db4d6e8ff3f0148ffd82917..e388cf332c3ff327f79ea57e3a0bccbbaa1b5e45 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/example.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/example.h @@ -16,6 +16,7 @@ #ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_ #define THIRD_PARTY_TENSORFLOW_CONTRIB_BOOSTED_TREES_LIB_UTILS_EXAMPLE_H_ +#include #include #include #include "tensorflow/contrib/boosted_trees/lib/utils/optional_value.h" @@ -23,6 +24,86 @@ namespace tensorflow { namespace boosted_trees { namespace utils { +// Represents sparse vector that have a value for some feature indices within +// the feature column. +// Allows subscript access []. +template +class SparseMultidimensionalValues { + public: + void Add(const int32 feature_idx, const T value) { + values_.emplace_back(feature_idx, value); + } + + void Clear() { values_.clear(); } + + void Reserve(const int32 size) { values_.reserve(size); } + + OptionalValue operator[](int feature_idx) const { + auto value_iter = + std::find_if(values_.begin(), values_.end(), + [&feature_idx](const std::pair& element) { + return element.first == feature_idx; + }); + + if (value_iter == values_.end()) { + return OptionalValue(); + } + return OptionalValue(value_iter->second); + } + + private: + std::vector> values_; +}; + +// Represents storage for a sparse float feature column. Can store values either +// for one dimensional or a multivalent (multidimensional) sparse column. +// Allows subscript operator access [feature_id]. +template +class SparseFloatFeatureColumn { + public: + void Reserve(const int32 size) { + if (!single_dimensional_) { + mutlidimensional_values.Reserve(size); + } + } + + void SetDimension(const int32 dimension) { + single_dimensional_ = dimension <= 1; + } + + void Add(const int32 feature_idx, const float value) { + if (single_dimensional_) { + DCHECK_EQ(0, feature_idx); + single_value_ = value; + } else { + mutlidimensional_values.Add(feature_idx, value); + } + initialized_ = true; + } + + void Clear() { + single_dimensional_ = false; + initialized_ = false; + mutlidimensional_values.Clear(); + } + + OptionalValue operator[](int feature_idx) const { + if (!initialized_) { + return OptionalValue(); + } + if (single_dimensional_) { + return OptionalValue(single_value_); + } else { + return mutlidimensional_values[feature_idx]; + } + } + + private: + bool single_dimensional_; + bool initialized_; + T single_value_; + SparseMultidimensionalValues mutlidimensional_values; +}; // Holds data for one example and enables lookup by feature column. struct Example { @@ -35,7 +116,10 @@ struct Example { // Dense and sparse float features indexed by feature column. // TODO(salehay): figure out a design to support multivalent float features. std::vector dense_float_features; - std::vector> sparse_float_features; + + // Sparse float features columns (can be either single or multivalent + // (multidimensional). + std::vector> sparse_float_features; // Sparse integer features indexed by feature column. // Note that all integer features are assumed to be categorical, i.e. will diff --git a/tensorflow/contrib/boosted_trees/lib/utils/example_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/example_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..be9d63ee8ae426d2d2573e7c156c62e2a3b094e1 --- /dev/null +++ b/tensorflow/contrib/boosted_trees/lib/utils/example_test.cc @@ -0,0 +1,94 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#include "tensorflow/contrib/boosted_trees/lib/utils/example.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace boosted_trees { +namespace utils { +namespace { + +class ExampleTest : public ::testing::Test {}; + +TEST_F(ExampleTest, TestSparseMatrix) { + // Create the following matrix (FC is feature column): + // FC | f0 | f1 | f2 + // multidimensional + // 0 | | 0.4 | 0.3 + // 1 | 1 | | 2 + // 2 | 3 | 1 | 5 + // 3 | | | + // one dimensional columns + // 4 | -4 + // 5 | + std::vector> matrix; + matrix.resize(6); + matrix[0].SetDimension(3); + matrix[1].SetDimension(3); + matrix[2].SetDimension(3); + matrix[3].SetDimension(3); + matrix[4].SetDimension(1); + matrix[5].SetDimension(1); + + matrix[0].Add(1, 0.4f); + matrix[0].Add(2, 0.3f); + matrix[1].Add(0, 1.f); + matrix[1].Add(2, 2.f); + matrix[2].Add(0, 3.f); + matrix[2].Add(1, 1.f); + matrix[2].Add(2, 5.f); + matrix[4].Add(0, -4.f); + + // Row 0. + EXPECT_FALSE(matrix[0][0].has_value()); + EXPECT_TRUE(matrix[0][1].has_value()); + EXPECT_EQ(0.4f, matrix[0][1].get_value()); + EXPECT_TRUE(matrix[0][2].has_value()); + EXPECT_EQ(0.3f, matrix[0][2].get_value()); + + // Row 1. + EXPECT_TRUE(matrix[1][0].has_value()); + EXPECT_EQ(1.f, matrix[1][0].get_value()); + EXPECT_FALSE(matrix[1][1].has_value()); + EXPECT_TRUE(matrix[1][2].has_value()); + EXPECT_EQ(2.f, matrix[1][2].get_value()); + + // Row 2. + EXPECT_TRUE(matrix[2][0].has_value()); + EXPECT_EQ(3.f, matrix[2][0].get_value()); + EXPECT_TRUE(matrix[2][1].has_value()); + EXPECT_EQ(1.f, matrix[2][1].get_value()); + EXPECT_TRUE(matrix[2][2].has_value()); + EXPECT_EQ(5.f, matrix[2][2].get_value()); + + // Row 3. + EXPECT_FALSE(matrix[3][0].has_value()); + EXPECT_FALSE(matrix[3][1].has_value()); + EXPECT_FALSE(matrix[3][2].has_value()); + + // Row 4. + EXPECT_TRUE(matrix[4][0].has_value()); + EXPECT_EQ(-4.f, matrix[4][0].get_value()); + + // Row 5. + EXPECT_FALSE(matrix[5][0].has_value()); +} + +} // namespace +} // namespace utils +} // namespace boosted_trees +} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.cc b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.cc index c73dc8e15d42f2c80078cf628b5cd5773f5860ff..e7e0b568c6f3b100969c5a6263fd0c36c7803f9f 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.cc @@ -36,12 +36,14 @@ ExamplesIterable::ExamplesIterable( // Create sparse float column iterables and values. sparse_float_column_iterables_.reserve(sparse_float_feature_columns.size()); sparse_float_column_values_.reserve(sparse_float_feature_columns.size()); + sparse_float_dimensions_.reserve(sparse_float_feature_columns.size()); for (auto& sparse_float_column : sparse_float_feature_columns) { sparse_float_column_iterables_.emplace_back( sparse_float_column.indices().template matrix(), example_start, example_end); sparse_float_column_values_.emplace_back( sparse_float_column.values().template vec()); + sparse_float_dimensions_.push_back(sparse_float_column.shape()[1]); } // Create sparse int column iterables and values. @@ -73,9 +75,9 @@ Iterator::Iterator(ExamplesIterable* iter, int64 example_idx) // Pre-size example features. example_.dense_float_features.resize( iter_->dense_float_column_values_.size()); + example_.sparse_int_features.resize(iter_->sparse_int_column_values_.size()); example_.sparse_float_features.resize( iter_->sparse_float_column_values_.size()); - example_.sparse_int_features.resize(iter_->sparse_int_column_values_.size()); } } // namespace utils diff --git a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h index 67efb82a227a3d7e92cdf5c8307a6f04c45fb617..5b33c8158879ec65425ac77b5338ee98fbdf07db 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable.h @@ -87,19 +87,52 @@ class ExamplesIterable { // Get sparse float values per column. auto& sparse_float_features = example_.sparse_float_features; + // Iterate through each sparse float feature column. for (size_t sparse_float_idx = 0; - sparse_float_idx < sparse_float_features.size(); + sparse_float_idx < iter_->sparse_float_column_iterables_.size(); ++sparse_float_idx) { + // Clear info from a previous instance. + sparse_float_features[sparse_float_idx].Clear(); + + // Get range for values tensor. const auto& row_range = (*sparse_float_column_iterators_[sparse_float_idx]); DCHECK_EQ(example_idx_, row_range.example_idx); + + // If the example has this feature column. if (row_range.start < row_range.end) { - DCHECK_EQ(1, row_range.end - row_range.start); - sparse_float_features[sparse_float_idx] = OptionalValue( - iter_->sparse_float_column_values_[sparse_float_idx]( - row_range.start)); - } else { - sparse_float_features[sparse_float_idx] = OptionalValue(); + const int32 dimension = + iter_->sparse_float_dimensions_[sparse_float_idx]; + sparse_float_features[sparse_float_idx].SetDimension(dimension); + if (dimension <= 1) { + // single dimensional sparse feature column. + DCHECK_EQ(1, row_range.end - row_range.start); + sparse_float_features[sparse_float_idx].Add( + 0, iter_->sparse_float_column_values_[sparse_float_idx]( + row_range.start)); + } else { + // Retrieve original indices tensor. + const TTypes::ConstMatrix& indices = + iter_->sparse_float_column_iterables_[sparse_float_idx] + .sparse_indices(); + + sparse_float_features[sparse_float_idx].Reserve(row_range.end - + row_range.start); + + // For each value. + for (int64 row_idx = row_range.start; row_idx < row_range.end; + ++row_idx) { + // Get the feature id for the feature column and the value. + const int32 feature_id = indices(row_idx, 1); + DCHECK_EQ(example_idx_, indices(row_idx, 0)); + + // Save the value to our sparse matrix. + sparse_float_features[sparse_float_idx].Add( + feature_id, + iter_->sparse_float_column_values_[sparse_float_idx]( + row_idx)); + } + } } } @@ -158,6 +191,9 @@ class ExamplesIterable { // Sparse float column values. std::vector::ConstVec> sparse_float_column_values_; + // Dimensions for sparse float feature columns. + std::vector sparse_float_dimensions_; + // Sparse int column iterables. std::vector sparse_int_column_iterables_; diff --git a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc index d12618217ad54345b7c3975d97c70f2dc2a81733..d8a608864834b17886313a368221fbf94e31c98e 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/examples_iterable_test.cc @@ -26,17 +26,17 @@ class ExamplesIterableTest : public ::testing::Test {}; TEST_F(ExamplesIterableTest, Iterate) { // Create a batch of 8 examples having one dense float, two sparse float and - // two sparse int features. + // two sparse int features. Second sparse float feature is multivalent. // The data looks like the following: // Instance | DenseF1 | SparseF1 | SparseF2 | SparseI1 | SparseI2 | - // 0 | 7 | -3 | | 1, 8 | | - // 1 | -2 | | 4 | 0 | 7 | - // 2 | 8 | 0 | | | 13 | - // 3 | 1 | 5 | 7 | 2, 0 | 4 | - // 4 | 0 | 0 | | | 0 | - // 5 | -4 | | 9 | | | - // 6 | 7 | | | | | - // 7 | -2 | | -4 | 5 | | + // 0 | 7 | -3 | | 1 | 1, 8 | | + // 1 | -2 | | 4 | | 0 | 7 | + // 2 | 8 | 0 | | 3 | | 13 | + // 3 | 1 | 5 | 7 | | 2, 0 | 4 | + // 4 | 0 | 0 | | 4.3 | | 0 | + // 5 | -4 | | 9 | 0.8 | | | + // 6 | 7 | | | | | | + // 7 | -2 | | -4 | | 5 | | auto dense_float_tensor = test::AsTensor( {7.0f, -2.0f, 8.0f, 1.0f, 0.0f, -4.0f, 7.0f, -2.0f}, {8, 1}); auto sparse_float_indices1 = @@ -45,10 +45,11 @@ TEST_F(ExamplesIterableTest, Iterate) { auto sparse_float_shape1 = TensorShape({8, 1}); sparse::SparseTensor sparse_float_tensor1( sparse_float_indices1, sparse_float_values1, sparse_float_shape1); - auto sparse_float_indices2 = - test::AsTensor({1, 0, 3, 0, 5, 0, 7, 0}, {4, 2}); - auto sparse_float_values2 = test::AsTensor({4.0f, 7.0f, 9.0f, -4.0f}); - auto sparse_float_shape2 = TensorShape({8, 1}); + auto sparse_float_indices2 = test::AsTensor( + {0, 1, 1, 0, 2, 1, 3, 0, 4, 1, 5, 0, 5, 1, 7, 0}, {8, 2}); + auto sparse_float_values2 = + test::AsTensor({1.f, 4.0f, 3.f, 7.0f, 4.3f, 9.0f, 0.8f, -4.0f}); + auto sparse_float_shape2 = TensorShape({8, 2}); sparse::SparseTensor sparse_float_tensor2( sparse_float_indices2, sparse_float_values2, sparse_float_shape2); auto sparse_int_indices1 = @@ -67,15 +68,19 @@ TEST_F(ExamplesIterableTest, Iterate) { auto validate_example_features = [](int64 example_idx, const Example& example) { EXPECT_EQ(1, example.dense_float_features.size()); - EXPECT_EQ(2, example.sparse_float_features.size()); switch (example_idx) { case 0: { EXPECT_EQ(0, example.example_idx); EXPECT_EQ(7.0f, example.dense_float_features[0]); - EXPECT_TRUE(example.sparse_float_features[0].has_value()); - EXPECT_EQ(-3.0f, example.sparse_float_features[0].get_value()); - EXPECT_FALSE(example.sparse_float_features[1].has_value()); + // SparseF1. + EXPECT_TRUE(example.sparse_float_features[0][0].has_value()); + EXPECT_EQ(-3.0f, example.sparse_float_features[0][0].get_value()); + // SparseF2 - multivalent. + EXPECT_FALSE(example.sparse_float_features[1][0].has_value()); + EXPECT_TRUE(example.sparse_float_features[1][1].has_value()); + EXPECT_EQ(1.0f, example.sparse_float_features[1][1].get_value()); + EXPECT_EQ(2, example.sparse_int_features[0].size()); EXPECT_EQ(1, example.sparse_int_features[0].count(1)); EXPECT_EQ(1, example.sparse_int_features[0].count(8)); @@ -84,9 +89,13 @@ TEST_F(ExamplesIterableTest, Iterate) { case 1: { EXPECT_EQ(1, example.example_idx); EXPECT_EQ(-2.0f, example.dense_float_features[0]); - EXPECT_FALSE(example.sparse_float_features[0].has_value()); - EXPECT_TRUE(example.sparse_float_features[1].has_value()); - EXPECT_EQ(4.0f, example.sparse_float_features[1].get_value()); + // SparseF1. + EXPECT_FALSE(example.sparse_float_features[0][0].has_value()); + // SparseF2. + EXPECT_TRUE(example.sparse_float_features[1][0].has_value()); + EXPECT_EQ(4.0f, example.sparse_float_features[1][0].get_value()); + EXPECT_FALSE(example.sparse_float_features[1][1].has_value()); + EXPECT_EQ(1, example.sparse_int_features[0].size()); EXPECT_EQ(1, example.sparse_int_features[0].count(0)); EXPECT_EQ(1, example.sparse_int_features[1].size()); @@ -95,9 +104,14 @@ TEST_F(ExamplesIterableTest, Iterate) { case 2: { EXPECT_EQ(2, example.example_idx); EXPECT_EQ(8.0f, example.dense_float_features[0]); - EXPECT_TRUE(example.sparse_float_features[0].has_value()); - EXPECT_EQ(0.0f, example.sparse_float_features[0].get_value()); - EXPECT_FALSE(example.sparse_float_features[1].has_value()); + // SparseF1. + EXPECT_TRUE(example.sparse_float_features[0][0].has_value()); + EXPECT_EQ(0.0f, example.sparse_float_features[0][0].get_value()); + // SparseF2. + EXPECT_FALSE(example.sparse_float_features[1][0].has_value()); + EXPECT_TRUE(example.sparse_float_features[1][1].has_value()); + EXPECT_EQ(3.f, example.sparse_float_features[1][1].get_value()); + EXPECT_EQ(0, example.sparse_int_features[0].size()); EXPECT_EQ(1, example.sparse_int_features[1].size()); EXPECT_EQ(1, example.sparse_int_features[1].count(13)); @@ -105,10 +119,14 @@ TEST_F(ExamplesIterableTest, Iterate) { case 3: { EXPECT_EQ(3, example.example_idx); EXPECT_EQ(1.0f, example.dense_float_features[0]); - EXPECT_TRUE(example.sparse_float_features[0].has_value()); - EXPECT_EQ(5.0f, example.sparse_float_features[0].get_value()); - EXPECT_TRUE(example.sparse_float_features[1].has_value()); - EXPECT_EQ(7.0f, example.sparse_float_features[1].get_value()); + // SparseF1. + EXPECT_TRUE(example.sparse_float_features[0][0].has_value()); + EXPECT_EQ(5.0f, example.sparse_float_features[0][0].get_value()); + // SparseF2. + EXPECT_TRUE(example.sparse_float_features[1][0].has_value()); + EXPECT_EQ(7.0f, example.sparse_float_features[1][0].get_value()); + EXPECT_FALSE(example.sparse_float_features[1][1].has_value()); + EXPECT_EQ(2, example.sparse_int_features[0].size()); EXPECT_EQ(1, example.sparse_int_features[0].count(2)); EXPECT_EQ(1, example.sparse_int_features[0].count(0)); @@ -118,9 +136,14 @@ TEST_F(ExamplesIterableTest, Iterate) { case 4: { EXPECT_EQ(4, example.example_idx); EXPECT_EQ(0.0f, example.dense_float_features[0]); - EXPECT_TRUE(example.sparse_float_features[0].has_value()); - EXPECT_EQ(0.0f, example.sparse_float_features[0].get_value()); - EXPECT_FALSE(example.sparse_float_features[1].has_value()); + // SparseF1. + EXPECT_TRUE(example.sparse_float_features[0][0].has_value()); + EXPECT_EQ(0.0f, example.sparse_float_features[0][0].get_value()); + // SparseF2. + EXPECT_FALSE(example.sparse_float_features[1][0].has_value()); + EXPECT_TRUE(example.sparse_float_features[1][1].has_value()); + EXPECT_EQ(4.3f, example.sparse_float_features[1][1].get_value()); + EXPECT_EQ(0, example.sparse_int_features[0].size()); EXPECT_EQ(1, example.sparse_int_features[1].size()); EXPECT_EQ(1, example.sparse_int_features[1].count(0)); @@ -128,28 +151,41 @@ TEST_F(ExamplesIterableTest, Iterate) { case 5: { EXPECT_EQ(5, example.example_idx); EXPECT_EQ(-4.0f, example.dense_float_features[0]); - EXPECT_FALSE(example.sparse_float_features[0].has_value()); - EXPECT_TRUE(example.sparse_float_features[1].has_value()); - EXPECT_EQ(9.0f, example.sparse_float_features[1].get_value()); + // SparseF1. + EXPECT_FALSE(example.sparse_float_features[0][0].has_value()); + // SparseF2. + EXPECT_TRUE(example.sparse_float_features[1][0].has_value()); + EXPECT_EQ(9.0f, example.sparse_float_features[1][0].get_value()); + EXPECT_TRUE(example.sparse_float_features[1][1].has_value()); + EXPECT_EQ(0.8f, example.sparse_float_features[1][1].get_value()); + EXPECT_EQ(0, example.sparse_int_features[0].size()); } break; case 6: { EXPECT_EQ(6, example.example_idx); EXPECT_EQ(7.0f, example.dense_float_features[0]); - EXPECT_FALSE(example.sparse_float_features[0].has_value()); - EXPECT_FALSE(example.sparse_float_features[1].has_value()); + // SparseF1. + EXPECT_FALSE(example.sparse_float_features[0][0].has_value()); + // SparseF2. + EXPECT_FALSE(example.sparse_float_features[1][0].has_value()); + EXPECT_FALSE(example.sparse_float_features[1][1].has_value()); + EXPECT_EQ(0, example.sparse_int_features[0].size()); } break; case 7: { EXPECT_EQ(7, example.example_idx); EXPECT_EQ(-2.0f, example.dense_float_features[0]); - EXPECT_FALSE(example.sparse_float_features[0].has_value()); - EXPECT_TRUE(example.sparse_float_features[1].has_value()); - EXPECT_EQ(-4.0f, example.sparse_float_features[1].get_value()); + // SparseF1. + EXPECT_FALSE(example.sparse_float_features[0][0].has_value()); + // SparseF2. + EXPECT_TRUE(example.sparse_float_features[1][0].has_value()); + EXPECT_EQ(-4.0f, example.sparse_float_features[1][0].get_value()); + EXPECT_FALSE(example.sparse_float_features[1][1].has_value()); + EXPECT_EQ(1, example.sparse_int_features[0].size()); EXPECT_EQ(1, example.sparse_int_features[0].count(5)); } break; - default: { QCHECK(false) << "Invalid example index."; } break; + default: { LOG(QFATAL) << "Invalid example index."; } break; } }; @@ -158,6 +194,7 @@ TEST_F(ExamplesIterableTest, Iterate) { {dense_float_tensor}, {sparse_float_tensor1, sparse_float_tensor2}, {sparse_int_tensor1, sparse_int_tensor2}, 0, 8); int64 example_idx = 0; + for (const auto& example : full_iterable) { validate_example_features(example_idx, example); ++example_idx; diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h index 78a5752730cb793394c41c56ab83b084a6f76088..9664c9d1c6a0c0c8b1bbd1506944c54d2310c611 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.h @@ -112,6 +112,8 @@ class SparseColumnIterable { int64 example_start() const { return example_start_; } int64 example_end() const { return example_end_; } + const TTypes::ConstMatrix& sparse_indices() const { return ix_; } + private: // Sparse indices matrix. TTypes::ConstMatrix ix_; diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc index 7792bd8c66c53c0f11cff113c3e5526c6d50dbb8..0138aae3dbd3773241cb6644db625b99f9bf1372 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable_test.cc @@ -34,19 +34,19 @@ TEST_F(SparseColumnIterableTest, Empty) { } TEST_F(SparseColumnIterableTest, Iterate) { - // 8 examples having 7 sparse features with the third multi-valent. + // 8 examples having 7 sparse features with the 3rd and 7th multi-valent. // This can be visualized like the following: // Instance | Sparse | - // 0 | x | + // 0 | x | // 1 | | // 2 | | // 3 | xxx | - // 4 | x | + // 4 | x | // 5 | | // 6 | | - // 7 | xx | + // 7 | x x | const auto indices = - AsTensor({0, 0, 3, 0, 3, 1, 3, 2, 4, 0, 7, 0, 7, 1}, {7, 2}); + AsTensor({0, 0, 3, 0, 3, 1, 3, 2, 4, 0, 7, 0, 7, 2}, {7, 2}); auto validate_example_range = [](const ExampleRowRange& range) { switch (range.example_idx) { diff --git a/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.cc b/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.cc index be2f787fd819cc35a6a1ab8a79f3c1aceffc0a67..326e3943df722f6fb74b3a73c616dcf16af16f8d 100644 --- a/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.cc +++ b/tensorflow/contrib/boosted_trees/lib/utils/tensor_utils.cc @@ -95,7 +95,7 @@ int64 TensorUtils::InferBatchSize( if (sparse_int_feature_shapes_list.size() > 0) { return sparse_int_feature_shapes_list[0].flat()(0); } - QCHECK(false) << "Could not infer batch size due to empty feature set."; + LOG(QFATAL) << "Could not infer batch size due to empty feature set."; } } // namespace utils diff --git a/tensorflow/contrib/boosted_trees/ops/ensemble_optimizer_ops.cc b/tensorflow/contrib/boosted_trees/ops/ensemble_optimizer_ops.cc deleted file mode 100644 index b5ea5e7849dbc3aa0fe670878a8040357deda23b..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/ops/ensemble_optimizer_ops.cc +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2017 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// ============================================================================= -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" - -namespace tensorflow { - -REGISTER_OP("AddTreesToEnsemble") - .Input("tree_ensemble_handle: resource") - .Input("ensemble_to_add: string") - .Input("feature_column_usage_counts_handle: Ref(int64)") - .Input("feature_column_usage_counts_to_add: int64") - .Input("feature_column_gains_handle: Ref(float)") - .Input("feature_column_gains_to_add: float") - .Input("drop_out_tree_indices_weights: float") - .Input("learning_rate: float") - .SetShapeFn(shape_inference::NoOutputs) - .Doc(R"doc( -Synchronously adds a tree ensemble to a an existing tree ensemble variable. -tree_ensemble_handle: Handle to the ensemble variable. -ensemble_to_add: Serialized DecisionTreeConfig proto of the tree. -feature_column_usage_counts_handle: Handle to the feature column usage counts variable. -feature_column_usage_counts_to_add: Rank 1 Tensor holding feature column usage counts to add. -feature_column_gains_handle: Handle to the feature column gains variable. -feature_column_gains_to_add: Rank 1 Tensor holding feature column gains to add. -drop_out_tree_indices_weights: Rank 2 Tensor containing dropped trees indices -and original weights of those trees during prediction. -learning_rate: The learning rate that the tuner found for this iteration. -)doc"); - -} // namespace tensorflow diff --git a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc index 31635906240d582f8ebbb9c8d14f1b2431409bc3..82b8e8c1c272ca415b5841f5ba9433e00173f8fa 100644 --- a/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc +++ b/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc @@ -36,10 +36,7 @@ static Status ApplyGradientTreesPredictionShapeFn(InferenceContext* c) { c->set_output(0, {c->Matrix(InferenceContext::kUnknownDim, reduce_dim ? learner_config.num_classes() - 1 : learner_config.num_classes())}); - c->set_output(1, {c->Matrix(InferenceContext::kUnknownDim, - reduce_dim ? learner_config.num_classes() - 1 - : learner_config.num_classes())}); - c->set_output(2, {c->Vector(InferenceContext::kUnknownDim)}); + c->set_output(1, {c->Vector(InferenceContext::kUnknownDim)}); return Status::OK(); } @@ -63,7 +60,6 @@ REGISTER_OP("GradientTreesPrediction") .Input("sparse_int_feature_values: num_sparse_int_features * int64") .Input("sparse_int_feature_shapes: num_sparse_int_features * int64") .Output("predictions: float") - .Output("no_dropout_predictions: float") .Output("drop_out_tree_indices_weights: float") .SetShapeFn(ApplyGradientTreesPredictionShapeFn) .Doc(R"doc( @@ -90,8 +86,6 @@ sparse_int_feature_indices: Rank 2 Tensors containing sparse int indices. sparse_int_feature_values: Rank 1 Tensors containing sparse int values. sparse_int_feature_shapes: Rank 1 Tensors containing sparse int shapes. predictions: Rank 2 Tensor containing predictions per example per class. -no_dropout_predictions: The same as predictions, but using all trees (even -those that were dropped due to dropout). drop_out_tree_indices_weights: Tensor of Rank 2 containing dropped trees indices and original weights of those trees during prediction. )doc"); diff --git a/tensorflow/contrib/boosted_trees/proto/tree_config.proto b/tensorflow/contrib/boosted_trees/proto/tree_config.proto index 2e9d45efd71adef828a55e54f48d2740b8c1a12e..f14abf45a517ad7c4c6d7bb1ab88b7a1d47d6fb6 100644 --- a/tensorflow/contrib/boosted_trees/proto/tree_config.proto +++ b/tensorflow/contrib/boosted_trees/proto/tree_config.proto @@ -53,6 +53,9 @@ message DenseFloatBinarySplit { // Float feature column and split threshold describing // the rule feature <= threshold. int32 feature_column = 1; + // If feature column is multivalent, this holds the index of the feature for + // the split. Defaults to 0. + int32 feature_id = 5; float threshold = 2; // Node children indexing into a contiguous diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/ensemble_optimizer_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/ensemble_optimizer_ops_test.py deleted file mode 100644 index 842e0caeca9734e44333a9d0ccdc3f6c9d64cfc3..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/ensemble_optimizer_ops_test.py +++ /dev/null @@ -1,351 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for the GTFlow ensemble optimization ops. - -The tests cover: -- Adding a newly built tree to an existing ensemble -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.boosted_trees.proto import tree_config_pb2 -from tensorflow.contrib.boosted_trees.python.ops import ensemble_optimizer_ops -from tensorflow.contrib.boosted_trees.python.ops import model_ops -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops -from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import resources -from tensorflow.python.ops import variables -from tensorflow.python.platform import googletest - - -def _append_to_leaf(leaf, class_id, weight): - """Helper method for building tree leaves. - - Appends weight contributions for the given class index to a leaf node. - - Args: - leaf: leaf node to append to, int - class_id: class Id for the weight update, int - weight: weight contribution value, float - """ - leaf.sparse_vector.index.append(class_id) - leaf.sparse_vector.value.append(weight) - - -class EnsembleOptimizerOpsTest(test_util.TensorFlowTestCase): - - def setUp(self): - """Create an ensemble of 2 trees.""" - super(EnsembleOptimizerOpsTest, self).setUp() - self._tree_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig() - # First tree. - tree_1 = self._tree_ensemble.trees.add() - _append_to_leaf(tree_1.nodes.add().leaf, 0, 0.4) - _append_to_leaf(tree_1.nodes.add().leaf, 1, 0.6) - # Second tree. - tree_2 = self._tree_ensemble.trees.add() - _append_to_leaf(tree_2.nodes.add().leaf, 0, 1) - _append_to_leaf(tree_2.nodes.add().leaf, 1, 0) - - self._tree_ensemble.tree_weights.append(1.0) - self._tree_ensemble.tree_weights.append(1.0) - - meta_1 = self._tree_ensemble.tree_metadata.add() - meta_1.num_tree_weight_updates = 2 - meta_2 = self._tree_ensemble.tree_metadata.add() - meta_2.num_tree_weight_updates = 3 - - # Ensemble to be added. - self._ensemble_to_add = tree_config_pb2.DecisionTreeEnsembleConfig() - - self._tree_to_add = self._ensemble_to_add.trees.add() - _append_to_leaf(self._tree_to_add.nodes.add().leaf, 0, 0.3) - _append_to_leaf(self._tree_to_add.nodes.add().leaf, 1, 0.7) - - def testWithEmptyEnsemble(self): - with self.test_session(): - # Create an empty ensemble. - tree_ensemble_handle = model_ops.tree_ensemble_variable( - stamp_token=0, tree_ensemble_config="", name="empty") - - # Create zero feature importance. - feature_usage_counts = variables.Variable( - initial_value=array_ops.zeros([1], dtypes.int64), - name="feature_usage_counts", - trainable=False) - feature_gains = variables.Variable( - initial_value=array_ops.zeros([1], dtypes.float32), - name="feature_gains", - trainable=False) - - resources.initialize_resources(resources.shared_resources()).run() - variables.initialize_all_variables().run() - - with ops.control_dependencies([ - ensemble_optimizer_ops.add_trees_to_ensemble( - tree_ensemble_handle, - self._ensemble_to_add.SerializeToString(), - feature_usage_counts, [2], - feature_gains, [0.4], [[]], - learning_rate=1.0) - ]): - result = model_ops.tree_ensemble_serialize(tree_ensemble_handle)[1] - - # Output. - output_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig() - output_ensemble.ParseFromString(result.eval()) - self.assertProtoEquals(self._tree_to_add, output_ensemble.trees[0]) - self.assertEqual(1, len(output_ensemble.trees)) - - self.assertAllEqual([1.0], output_ensemble.tree_weights) - - self.assertEqual(1, - output_ensemble.tree_metadata[0].num_tree_weight_updates) - - self.assertAllEqual([2], feature_usage_counts.eval()) - self.assertArrayNear([0.4], feature_gains.eval(), 1e-6) - - def testWithExistingEnsemble(self): - with self.test_session(): - # Create existing tree ensemble. - tree_ensemble_handle = model_ops.tree_ensemble_variable( - stamp_token=0, - tree_ensemble_config=self._tree_ensemble.SerializeToString(), - name="existing") - # Create non-zero feature importance. - feature_usage_counts = variables.Variable( - initial_value=np.array([0, 4, 1], np.int64), - name="feature_usage_counts", - trainable=False) - feature_gains = variables.Variable( - initial_value=np.array([0.0, 0.3, 0.05], np.float32), - name="feature_gains", - trainable=False) - - resources.initialize_resources(resources.shared_resources()).run() - variables.initialize_all_variables().run() - output_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig() - with ops.control_dependencies([ - ensemble_optimizer_ops.add_trees_to_ensemble( - tree_ensemble_handle, - self._ensemble_to_add.SerializeToString(), - feature_usage_counts, [1, 2, 0], - feature_gains, [0.02, 0.1, 0.0], [[], []], - learning_rate=1) - ]): - output_ensemble.ParseFromString( - model_ops.tree_ensemble_serialize(tree_ensemble_handle)[1].eval()) - - # Output. - self.assertEqual(3, len(output_ensemble.trees)) - self.assertProtoEquals(self._tree_to_add, output_ensemble.trees[2]) - - self.assertAllEqual([1.0, 1.0, 1.0], output_ensemble.tree_weights) - - self.assertEqual(2, - output_ensemble.tree_metadata[0].num_tree_weight_updates) - self.assertEqual(3, - output_ensemble.tree_metadata[1].num_tree_weight_updates) - self.assertEqual(1, - output_ensemble.tree_metadata[2].num_tree_weight_updates) - self.assertAllEqual([1, 6, 1], feature_usage_counts.eval()) - self.assertArrayNear([0.02, 0.4, 0.05], feature_gains.eval(), 1e-6) - - def testWithExistingEnsembleAndDropout(self): - with self.test_session(): - tree_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig() - # Add 10 trees with some weights. - for i in range(0, 10): - tree = tree_ensemble.trees.add() - _append_to_leaf(tree.nodes.add().leaf, 0, -0.4) - tree_ensemble.tree_weights.append(i + 1) - meta = tree_ensemble.tree_metadata.add() - meta.num_tree_weight_updates = 1 - tree_ensemble_handle = model_ops.tree_ensemble_variable( - stamp_token=0, - tree_ensemble_config=tree_ensemble.SerializeToString(), - name="existing") - # Create non-zero feature importance. - feature_usage_counts = variables.Variable( - initial_value=np.array([2, 3], np.int64), - name="feature_usage_counts", - trainable=False) - feature_gains = variables.Variable( - initial_value=np.array([0.0, 0.3], np.float32), - name="feature_gains", - trainable=False) - - resources.initialize_resources(resources.shared_resources()).run() - variables.initialize_all_variables().run() - - dropped = [1, 6, 8] - dropped_original_weights = [2.0, 7.0, 9.0] - - output_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig() - with ops.control_dependencies([ - ensemble_optimizer_ops.add_trees_to_ensemble( - tree_ensemble_handle, - self._ensemble_to_add.SerializeToString(), - feature_usage_counts, [1, 2], - feature_gains, [0.5, 0.3], [dropped, dropped_original_weights], - learning_rate=0.1) - ]): - output_ensemble.ParseFromString( - model_ops.tree_ensemble_serialize(tree_ensemble_handle)[1].eval()) - - # Output. - self.assertEqual(11, len(output_ensemble.trees)) - self.assertProtoEquals(self._tree_to_add, output_ensemble.trees[10]) - self.assertAllClose(4.5, output_ensemble.tree_weights[10]) - - self.assertAllClose([1., 1.5, 3., 4., 5., 6., 5.25, 8., 6.75, 10., 4.5], - output_ensemble.tree_weights) - - self.assertEqual(1, - output_ensemble.tree_metadata[0].num_tree_weight_updates) - self.assertEqual(2, - output_ensemble.tree_metadata[1].num_tree_weight_updates) - self.assertEqual(1, - output_ensemble.tree_metadata[2].num_tree_weight_updates) - - self.assertEqual(1, - output_ensemble.tree_metadata[3].num_tree_weight_updates) - self.assertEqual(1, - output_ensemble.tree_metadata[4].num_tree_weight_updates) - self.assertEqual(1, - output_ensemble.tree_metadata[5].num_tree_weight_updates) - self.assertEqual(2, - output_ensemble.tree_metadata[6].num_tree_weight_updates) - self.assertEqual(1, - output_ensemble.tree_metadata[7].num_tree_weight_updates) - self.assertEqual(2, - output_ensemble.tree_metadata[8].num_tree_weight_updates) - self.assertEqual(1, - output_ensemble.tree_metadata[9].num_tree_weight_updates) - self.assertEqual( - 1, output_ensemble.tree_metadata[10].num_tree_weight_updates) - self.assertAllEqual([3, 5], feature_usage_counts.eval()) - self.assertArrayNear([0.05, 0.33], feature_gains.eval(), 1e-6) - - def testWithEmptyEnsembleAndShrinkage(self): - with self.test_session(): - # Add shrinkage config. - learning_rate = 0.0001 - tree_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig() - tree_ensemble_handle = model_ops.tree_ensemble_variable( - stamp_token=0, - tree_ensemble_config=tree_ensemble.SerializeToString(), - name="existing") - - # Create zero feature importance. - feature_usage_counts = variables.Variable( - initial_value=np.array([0, 0], np.int64), - name="feature_usage_counts", - trainable=False) - feature_gains = variables.Variable( - initial_value=np.array([0.0, 0.0], np.float32), - name="feature_gains", - trainable=False) - - resources.initialize_resources(resources.shared_resources()).run() - variables.initialize_all_variables().run() - - output_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig() - with ops.control_dependencies([ - ensemble_optimizer_ops.add_trees_to_ensemble( - tree_ensemble_handle, - self._ensemble_to_add.SerializeToString(), - feature_usage_counts, [1, 2], - feature_gains, [0.5, 0.3], [[], []], - learning_rate=learning_rate) - ]): - output_ensemble.ParseFromString( - model_ops.tree_ensemble_serialize(tree_ensemble_handle)[1].eval()) - - # New tree is added with shrinkage weight. - self.assertAllClose([learning_rate], output_ensemble.tree_weights) - self.assertEqual(1, - output_ensemble.tree_metadata[0].num_tree_weight_updates) - self.assertAllEqual([1, 2], feature_usage_counts.eval()) - self.assertArrayNear([0.5 * learning_rate, 0.3 * learning_rate], - feature_gains.eval(), 1e-6) - - def testWithExistingEnsembleAndShrinkage(self): - with self.test_session(): - # Add shrinkage config. - learning_rate = 0.0001 - tree_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig() - # Add 10 trees with some weights. - for i in range(0, 5): - tree = tree_ensemble.trees.add() - _append_to_leaf(tree.nodes.add().leaf, 0, -0.4) - tree_ensemble.tree_weights.append(i + 1) - meta = tree_ensemble.tree_metadata.add() - meta.num_tree_weight_updates = 1 - tree_ensemble_handle = model_ops.tree_ensemble_variable( - stamp_token=0, - tree_ensemble_config=tree_ensemble.SerializeToString(), - name="existing") - - # Create non-zero feature importance. - feature_usage_counts = variables.Variable( - initial_value=np.array([4, 7], np.int64), - name="feature_usage_counts", - trainable=False) - feature_gains = variables.Variable( - initial_value=np.array([0.2, 0.8], np.float32), - name="feature_gains", - trainable=False) - - resources.initialize_resources(resources.shared_resources()).run() - variables.initialize_all_variables().run() - - output_ensemble = tree_config_pb2.DecisionTreeEnsembleConfig() - with ops.control_dependencies([ - ensemble_optimizer_ops.add_trees_to_ensemble( - tree_ensemble_handle, - self._ensemble_to_add.SerializeToString(), - feature_usage_counts, [1, 2], - feature_gains, [0.5, 0.3], [[], []], - learning_rate=learning_rate) - ]): - output_ensemble.ParseFromString( - model_ops.tree_ensemble_serialize(tree_ensemble_handle)[1].eval()) - - # The weights of previous trees stayed the same, new tree (LAST) is added - # with shrinkage weight. - self.assertAllClose([1.0, 2.0, 3.0, 4.0, 5.0, learning_rate], - output_ensemble.tree_weights) - - # Check that all number of updates are equal to 1 (e,g, no old tree weight - # got adjusted. - for i in range(0, 6): - self.assertEqual( - 1, output_ensemble.tree_metadata[i].num_tree_weight_updates) - - # Ensure feature importance was aggregated correctly. - self.assertAllEqual([5, 9], feature_usage_counts.eval()) - self.assertArrayNear( - [0.2 + 0.5 * learning_rate, 0.8 + 0.3 * learning_rate], - feature_gains.eval(), 1e-6) - -if __name__ == "__main__": - googletest.main() diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py index 8e628568543ee5319476669b7576124364d3a5c0..27c288bbf78b3b593d0807e92ac7fd9afc4d2725 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/model_ops_test.py @@ -30,13 +30,10 @@ import numpy as np from tensorflow.contrib.boosted_trees.proto import learner_pb2 from tensorflow.contrib.boosted_trees.proto import tree_config_pb2 -from tensorflow.contrib.boosted_trees.python.ops import ensemble_optimizer_ops from tensorflow.contrib.boosted_trees.python.ops import model_ops from tensorflow.contrib.boosted_trees.python.ops import prediction_ops -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import test_util -from tensorflow.python.ops import array_ops from tensorflow.python.ops import resources from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -117,7 +114,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase): name="create_tree") resources.initialize_resources(resources.shared_resources()).run() - result, _, _ = prediction_ops.gradient_trees_prediction( + result, _ = prediction_ops.gradient_trees_prediction( tree_ensemble_handle, self._seed, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 @@ -178,7 +175,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 3 - result, _, _ = prediction_ops.gradient_trees_prediction( + result, _ = prediction_ops.gradient_trees_prediction( tree_ensemble_handle2, self._seed, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 @@ -215,53 +212,36 @@ class ModelOpsTest(test_util.TensorFlowTestCase): save_path = os.path.join(self.get_temp_dir(), "restore-test") with ops.Graph().as_default() as graph: with self.test_session(graph) as sess: - tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + # Prepare learner config. + learner_config = learner_pb2.LearnerConfig() + learner_config.num_classes = 2 + # Add the first tree and save. + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() tree = tree_ensemble_config.trees.add() tree_ensemble_config.tree_metadata.add().is_finalized = True tree_ensemble_config.tree_weights.append(1.0) _append_to_leaf(tree.nodes.add().leaf, 0, -0.1) - - tree_ensemble_config2 = tree_config_pb2.DecisionTreeEnsembleConfig() - tree2 = tree_ensemble_config2.trees.add() - tree_ensemble_config.tree_weights.append(1.0) - _append_to_leaf(tree2.nodes.add().leaf, 0, -1.0) - - tree_ensemble_config3 = tree_config_pb2.DecisionTreeEnsembleConfig() - tree3 = tree_ensemble_config3.trees.add() - tree_ensemble_config.tree_weights.append(1.0) - _append_to_leaf(tree3.nodes.add().leaf, 0, -10.0) - - # Prepare learner config. - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - tree_ensemble_handle = model_ops.tree_ensemble_variable( stamp_token=3, tree_ensemble_config=tree_ensemble_config.SerializeToString(), name="restore_tree") - feature_usage_counts = variables.Variable( - initial_value=array_ops.zeros([1], dtypes.int64), - name="feature_usage_counts", - trainable=False) - feature_gains = variables.Variable( - initial_value=array_ops.zeros([1], dtypes.float32), - name="feature_gains", - trainable=False) - resources.initialize_resources(resources.shared_resources()).run() variables.initialize_all_variables().run() my_saver = saver.Saver() + # Add the second tree and replace the ensemble of the handle. + tree2 = tree_ensemble_config.trees.add() + tree_ensemble_config.tree_weights.append(1.0) + _append_to_leaf(tree2.nodes.add().leaf, 0, -1.0) + # Predict to confirm. with ops.control_dependencies([ - ensemble_optimizer_ops.add_trees_to_ensemble( + model_ops.tree_ensemble_deserialize( tree_ensemble_handle, - tree_ensemble_config2.SerializeToString(), - feature_usage_counts, [0], - feature_gains, [0], [[]], - learning_rate=1) + stamp_token=3, + tree_ensemble_config=tree_ensemble_config.SerializeToString()) ]): - result, _, _ = prediction_ops.gradient_trees_prediction( + result, _ = prediction_ops.gradient_trees_prediction( tree_ensemble_handle, self._seed, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 @@ -280,15 +260,17 @@ class ModelOpsTest(test_util.TensorFlowTestCase): self.assertEqual(save_path, val) # Add more trees after saving. + tree3 = tree_ensemble_config.trees.add() + tree_ensemble_config.tree_weights.append(1.0) + _append_to_leaf(tree3.nodes.add().leaf, 0, -10.0) + # Predict to confirm. with ops.control_dependencies([ - ensemble_optimizer_ops.add_trees_to_ensemble( + model_ops.tree_ensemble_deserialize( tree_ensemble_handle, - tree_ensemble_config3.SerializeToString(), - feature_usage_counts, [0], - feature_gains, [0], [[]], - learning_rate=1) + stamp_token=3, + tree_ensemble_config=tree_ensemble_config.SerializeToString()) ]): - result, _, _ = prediction_ops.gradient_trees_prediction( + result, _ = prediction_ops.gradient_trees_prediction( tree_ensemble_handle, self._seed, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 @@ -311,7 +293,7 @@ class ModelOpsTest(test_util.TensorFlowTestCase): stamp_token=0, tree_ensemble_config="", name="restore_tree") my_saver = saver.Saver() my_saver.restore(sess, save_path) - result, _, _ = prediction_ops.gradient_trees_prediction( + result, _ = prediction_ops.gradient_trees_prediction( tree_ensemble_handle, self._seed, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 diff --git a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py index 37595f1c75deab4db810d6ae49b57f56f417c52f..79802922ca1b59789069a0249cee163cdd3f607a 100644 --- a/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py +++ b/tensorflow/contrib/boosted_trees/python/kernel_tests/prediction_ops_test.py @@ -136,6 +136,27 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self._sparse_int_shape1 = np.array([2, 2]) self._seed = 123 + def _get_predictions(self, + tree_ensemble_handle, + learner_config, + apply_dropout=False, + apply_averaging=False, + center_bias=False, + reduce_dim=False): + return prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + self._seed, [self._dense_float_tensor], + [self._sparse_float_indices1, self._sparse_float_indices2], + [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, self._sparse_float_shape2], + [self._sparse_int_indices1], [self._sparse_int_values1], + [self._sparse_int_shape1], + learner_config=learner_config, + apply_dropout=apply_dropout, + apply_averaging=apply_averaging, + center_bias=center_bias, + reduce_dim=reduce_dim) + def testEmptyEnsemble(self): with self.test_session(): # Empty tree ensenble. @@ -151,22 +172,11 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + result, dropout_info = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + reduce_dim=True) self.assertAllEqual([[0], [0]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -189,22 +199,11 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + result, dropout_info = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + reduce_dim=True) self.assertAllClose([[-0.4], [-0.4]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -230,22 +229,11 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 3 - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + result, dropout_info = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + reduce_dim=True) self.assertAllClose([[-0.4, 0.9], [-0.4, 0.9]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -285,27 +273,16 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + result, dropout_info = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + reduce_dim=True) # The first example will get bias -0.4 from first tree and # leaf 4 payload of -0.9 hence -1.3, the second example will # get the same bias -0.4 and leaf 3 payload (sparse feature missing) # of 1.2 hence 0.8. self.assertAllClose([[-1.3], [0.8]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -346,25 +323,14 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.num_classes = 2 learner_config.growing_mode = learner_pb2.LearnerConfig.WHOLE_TREE - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + result, dropout_info = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + reduce_dim=True) # All the examples should get only the bias since the second tree is # non-finalized self.assertAllClose([[-0.4], [-0.4]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -405,27 +371,16 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.num_classes = 2 learner_config.growing_mode = learner_pb2.LearnerConfig.LAYER_BY_LAYER - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + result, dropout_info = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + reduce_dim=True) # The first example will get bias -0.4 from first tree and # leaf 4 payload of -0.9 hence -1.3, the second example will # get the same bias -0.4 and leaf 3 payload (sparse feature missing) # of 1.2 hence 0.8. Note that the non-finalized tree is included. self.assertAllClose([[-1.3], [0.8]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -466,27 +421,16 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config = learner_pb2.LearnerConfig() learner_config.num_classes = 2 - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + result, dropout_info = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + reduce_dim=True) # The first example will get bias -0.4 from first tree and # leaf 4 payload of -0.9 hence -1.3, the second example will # get the same bias -0.4 and leaf 3 payload (sparse feature missing) # of 1.2 hence 0.8. self.assertAllClose([[-1.3], [0.8]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -526,26 +470,15 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.TREE_PER_CLASS) - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + result, dropout_info = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + reduce_dim=True) # The first example will get bias class 1 -0.2 from first tree and # leaf 2 payload (sparse feature missing) of 0.5 hence [0.5, -0.2], # the second example will get the same bias class 1 -0.2 and leaf 3 # payload of class 1 1.2 hence [0.0, 1.0]. self.assertAllClose([[0.5, -0.2], [0, 1.0]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -588,26 +521,15 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.FULL_HESSIAN) - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=False)) + result, dropout_info = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + reduce_dim=False) # The first example will get bias class 1 -0.2 from first tree and # leaf 2 payload (sparse feature missing) of 0.5 hence [0.5, -0.2], # the second example will get the same bias class 1 -0.2 and leaf 3 # payload of class 1 1.2 and class 2-0.7 hence [0.0, 1.0, -0.7]. self.assertAllClose([[0.5, -0.2, 0.0], [0, 1.0, -0.7]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) @@ -649,55 +571,24 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): learner_config.multi_class_strategy = ( learner_pb2.LearnerConfig.FULL_HESSIAN) - result, result_no_dropout, dropout_info = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=False, - apply_averaging=False, - center_bias=False, - reduce_dim=False)) + result, dropout_info = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + reduce_dim=False) # The first example will get bias class 1 -0.2 and -2 for class 2 from # first tree and leaf 2 payload (sparse feature missing) of 0.5 hence # 0.5, -0.2], the second example will get the same bias and leaf 3 payload # of class 1 1.2 and class 2-0.7 hence [0.0, 1.0, -2.7]. self.assertAllClose([[0.5, -0.2, -2.0], [0, 1.0, -2.7]], result.eval()) - self.assertAllEqual(result_no_dropout.eval(), result.eval()) # Empty dropout. self.assertAllEqual([[], []], dropout_info.eval()) - def _get_predictions(self, - tree_ensemble_handle, - learner_config, - apply_dropout=False, - apply_averaging=False, - center_bias=False): - return prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=apply_dropout, - apply_averaging=apply_averaging, - center_bias=center_bias, - reduce_dim=True) - def testDropout(self): with self.test_session(): # Empty tree ensenble. tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() - # Add 10 trees with some weights. + # Add 1000 trees with some weights. for i in range(0, 999): tree = tree_ensemble_config.trees.add() tree_ensemble_config.tree_metadata.add().is_finalized = True @@ -717,22 +608,19 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): name="existing") resources.initialize_resources(resources.shared_resources()).run() - result, result_no_dropout, dropout_info = self._get_predictions( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, - center_bias=False) + center_bias=False, + reduce_dim=True) # We expect approx 500 trees were dropped. dropout_info = dropout_info.eval() self.assertIn(dropout_info[0].size, range(400, 601)) self.assertEqual(dropout_info[0].size, dropout_info[1].size) - self.assertEqual(result.eval().size, result_no_dropout.eval().size) - for i in range(result.eval().size): - self.assertNotEqual(result.eval()[i], result_no_dropout.eval()[i]) - for i in range(dropout_info[0].size): dropped_index = dropout_info[0][i] dropped_weight = dropout_info[1][i] @@ -741,17 +629,20 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self.assertEqual(dropped_index + 1, dropped_weight) # Don't apply dropout. - result, result_no_dropout, dropout_info = self._get_predictions( + result_no_dropout, no_dropout_info = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=False, apply_averaging=False, - center_bias=False) + center_bias=False, + reduce_dim=True) - # We expect none of the trees were dropped. - self.assertAllEqual([[], []], dropout_info.eval()) + self.assertEqual(result.eval().size, result_no_dropout.eval().size) + for i in range(result.eval().size): + self.assertNotEqual(result.eval()[i], result_no_dropout.eval()[i]) - self.assertAllEqual(result.eval(), result_no_dropout.eval()) + # We expect none of the trees were dropped. + self.assertAllEqual([[], []], no_dropout_info.eval()) def testDropoutCenterBiasNoGrowingMeta(self): # This is for normal non-batch mode where ensemble does not contain the tree @@ -780,20 +671,21 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): name="existing") resources.initialize_resources(resources.shared_resources()).run() - result, result_no_dropout, dropout_info = self._get_predictions( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, - center_bias=False) + center_bias=False, + reduce_dim=True) - result_center, result_no_dropout_center, dropout_info_center = ( - self._get_predictions( - tree_ensemble_handle, - learner_config=learner_config, - apply_dropout=True, - apply_averaging=False, - center_bias=True)) + result_center, dropout_info_center = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + apply_dropout=True, + apply_averaging=False, + center_bias=True, + reduce_dim=True) dropout_info = dropout_info.eval() dropout_info_center = dropout_info_center.eval() @@ -820,9 +712,6 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self.assertEqual(num_trees - 1, dropout_info_center[0][num_dropped_center - 1]) - self.assertAllEqual(result_no_dropout.eval(), - result_no_dropout_center.eval()) - def testDropoutCenterBiasWithGrowingMeta(self): # This is batch mode where ensemble already contains the tree that we are # building. This tree should never be dropped. @@ -854,20 +743,21 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): name="existing") resources.initialize_resources(resources.shared_resources()).run() - result, result_no_dropout, dropout_info = self._get_predictions( + result, dropout_info = self._get_predictions( tree_ensemble_handle, - learner_config=learner_config, + learner_config=learner_config.SerializeToString(), apply_dropout=True, apply_averaging=False, - center_bias=False) + center_bias=False, + reduce_dim=True) - result_center, result_no_dropout_center, dropout_info_center = ( - self._get_predictions( - tree_ensemble_handle, - learner_config=learner_config, - apply_dropout=True, - apply_averaging=False, - center_bias=True)) + result_center, dropout_info_center = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + apply_dropout=True, + apply_averaging=False, + center_bias=True, + reduce_dim=True) dropout_info = dropout_info.eval() dropout_info_center = dropout_info_center.eval() @@ -893,9 +783,6 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self.assertNotEqual(num_trees - 1, dropout_info_center[0][num_dropped_center - 1]) - self.assertAllEqual(result_no_dropout.eval(), - result_no_dropout_center.eval()) - def testDropoutSeed(self): with self.test_session(): tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() @@ -918,67 +805,45 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): name="empty") resources.initialize_resources(resources.shared_resources()).run() - _, result_no_dropout_1, dropout_info_1 = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=True, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) - - _, result_no_dropout_2, dropout_info_2 = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=True, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + _, dropout_info_1 = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + apply_dropout=True, + apply_averaging=False, + center_bias=False, + reduce_dim=True) + + _, dropout_info_2 = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + apply_dropout=True, + apply_averaging=False, + center_bias=False, + reduce_dim=True) # Different seed. - _, result_no_dropout_3, dropout_info_3 = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - 112314, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=True, - apply_averaging=False, - center_bias=False, - reduce_dim=True)) + _, dropout_info_3 = prediction_ops.gradient_trees_prediction( + tree_ensemble_handle, + 112314, [self._dense_float_tensor], + [self._sparse_float_indices1, self._sparse_float_indices2], + [self._sparse_float_values1, self._sparse_float_values2], + [self._sparse_float_shape1, self._sparse_float_shape2], + [self._sparse_int_indices1], [self._sparse_int_values1], + [self._sparse_int_shape1], + learner_config=learner_config.SerializeToString(), + apply_dropout=True, + apply_averaging=False, + center_bias=False, + reduce_dim=True) # First seed with centering bias. - _, result_no_dropout_4, dropout_info_4 = ( - prediction_ops.gradient_trees_prediction( - tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], - learner_config=learner_config.SerializeToString(), - apply_dropout=True, - apply_averaging=False, - center_bias=True, - reduce_dim=True)) + _, dropout_info_4 = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + apply_dropout=True, + apply_averaging=False, + center_bias=True, + reduce_dim=True) # The same seed returns the same results. self.assertAllEqual(dropout_info_1.eval(), dropout_info_2.eval()) @@ -991,31 +856,48 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): self.assertEqual( len(dropout_info_4.eval()[0]) + 1, len(dropout_info_1.eval()[0])) - # Predictions without dropout are all the same. - result, result_no_dropout, _ = prediction_ops.gradient_trees_prediction( + def testDropOutZeroProb(self): + with self.test_session(): + # Empty tree ensenble. + tree_ensemble_config = tree_config_pb2.DecisionTreeEnsembleConfig() + # Add 1000 trees with some weights. + for i in range(0, 999): + tree = tree_ensemble_config.trees.add() + tree_ensemble_config.tree_metadata.add().is_finalized = True + _append_to_leaf(tree.nodes.add().leaf, 0, -0.4) + tree_ensemble_config.tree_weights.append(i + 1) + + # Dropout with 0 probability. + learner_config = learner_pb2.LearnerConfig() + learner_config.learning_rate_tuner.dropout.dropout_probability = 0.0 + learner_config.learning_rate_tuner.dropout.learning_rate = 1.0 + learner_config.num_classes = 2 + + # Apply dropout, but expect nothing dropped. + tree_ensemble_handle = model_ops.tree_ensemble_variable( + stamp_token=0, + tree_ensemble_config=tree_ensemble_config.SerializeToString(), + name="existing") + resources.initialize_resources(resources.shared_resources()).run() + + result, dropout_info = self._get_predictions( + tree_ensemble_handle, + learner_config=learner_config.SerializeToString(), + apply_dropout=True, + apply_averaging=False, + center_bias=False, + reduce_dim=True) + + result_no_dropout, _ = self._get_predictions( tree_ensemble_handle, - self._seed, [self._dense_float_tensor], [ - self._sparse_float_indices1, self._sparse_float_indices2 - ], [self._sparse_float_values1, self._sparse_float_values2], - [self._sparse_float_shape1, - self._sparse_float_shape2], [self._sparse_int_indices1], - [self._sparse_int_values1], [self._sparse_int_shape1], learner_config=learner_config.SerializeToString(), apply_dropout=False, apply_averaging=False, center_bias=False, reduce_dim=True) - self.assertAllCloseAccordingToType(result.eval(), - result_no_dropout.eval()) - self.assertAllCloseAccordingToType(result.eval(), - result_no_dropout_1.eval()) - self.assertAllCloseAccordingToType(result.eval(), - result_no_dropout_2.eval()) - self.assertAllCloseAccordingToType(result.eval(), - result_no_dropout_3.eval()) - self.assertAllCloseAccordingToType(result.eval(), - result_no_dropout_4.eval()) + self.assertAllEqual([[], []], dropout_info.eval()) + self.assertAllClose(result.eval(), result_no_dropout.eval()) def testAveragingAllTrees(self): with self.test_session(): @@ -1066,17 +948,18 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): resources.initialize_resources(resources.shared_resources()).run() # Do averaging. - result, result_no_dropout, dropout_info = self._get_predictions( - tree_ensemble_handle, learner_config, apply_averaging=True) + result, dropout_info = self._get_predictions( + tree_ensemble_handle, + learner_config.SerializeToString(), + apply_averaging=True, + reduce_dim=True) - pattern_result, pattern_result_no_dropout, pattern_dropout_info = ( - self._get_predictions( - adjusted_tree_ensemble_handle, - learner_config_no_averaging, - apply_averaging=False)) + pattern_result, pattern_dropout_info = self._get_predictions( + adjusted_tree_ensemble_handle, + learner_config_no_averaging.SerializeToString(), + apply_averaging=False, + reduce_dim=True) - self.assertAllEqual(result_no_dropout.eval(), - pattern_result_no_dropout.eval()) self.assertAllEqual(result.eval(), pattern_result.eval()) self.assertAllEqual(dropout_info.eval(), pattern_dropout_info.eval()) @@ -1137,22 +1020,23 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): resources.initialize_resources(resources.shared_resources()).run() - result_1, result_no_dropout_1, dropout_info_1 = self._get_predictions( - tree_ensemble_handle, learner_config_1, apply_averaging=True) - - result_2, result_no_dropout_2, dropout_info_2 = self._get_predictions( - tree_ensemble_handle, learner_config_2, apply_averaging=True) + result_1, dropout_info_1 = self._get_predictions( + tree_ensemble_handle, + learner_config_1.SerializeToString(), + apply_averaging=True, + reduce_dim=True) - pattern_result, pattern_result_no_dropout, pattern_dropout_info = ( - self._get_predictions( - adjusted_tree_ensemble_handle, - learner_config_no_averaging, - apply_averaging=False)) + result_2, dropout_info_2 = self._get_predictions( + tree_ensemble_handle, + learner_config_2.SerializeToString(), + apply_averaging=True, + reduce_dim=True) - self.assertAllEqual(result_no_dropout_1.eval(), - pattern_result_no_dropout.eval()) - self.assertAllEqual(result_no_dropout_2.eval(), - pattern_result_no_dropout.eval()) + pattern_result, pattern_dropout_info = self._get_predictions( + adjusted_tree_ensemble_handle, + learner_config_no_averaging.SerializeToString(), + apply_averaging=False, + reduce_dim=True) self.assertAllEqual(result_1.eval(), pattern_result.eval()) self.assertAllEqual(result_2.eval(), pattern_result.eval()) @@ -1206,17 +1090,18 @@ class PredictionOpsTest(test_util.TensorFlowTestCase): resources.initialize_resources(resources.shared_resources()).run() - result, result_no_dropout, dropout_info = self._get_predictions( - tree_ensemble_handle, learner_config, apply_averaging=True) + result, dropout_info = self._get_predictions( + tree_ensemble_handle, + learner_config.SerializeToString(), + apply_averaging=True, + reduce_dim=True) - pattern_result, pattern_result_no_dropout, pattern_dropout_info = ( - self._get_predictions( - adjusted_tree_ensemble_handle, - learner_config_no_averaging, - apply_averaging=False)) + pattern_result, pattern_dropout_info = self._get_predictions( + adjusted_tree_ensemble_handle, + learner_config_no_averaging.SerializeToString(), + apply_averaging=False, + reduce_dim=True) - self.assertAllEqual(result_no_dropout.eval(), - pattern_result_no_dropout.eval()) self.assertAllEqual(result.eval(), pattern_result.eval()) self.assertAllEqual(dropout_info.eval(), pattern_dropout_info.eval()) @@ -1255,10 +1140,6 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase): name="full_ensemble") resources.initialize_resources(resources.shared_resources()).run() - # Prepare learner config. - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - result = prediction_ops.gradient_trees_partition_examples( tree_ensemble_handle, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 @@ -1294,10 +1175,6 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase): name="full_ensemble") resources.initialize_resources(resources.shared_resources()).run() - # Prepare learner config. - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - result = prediction_ops.gradient_trees_partition_examples( tree_ensemble_handle, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 @@ -1333,10 +1210,6 @@ class PartitionExamplesOpsTest(test_util.TensorFlowTestCase): name="full_ensemble") resources.initialize_resources(resources.shared_resources()).run() - # Prepare learner config. - learner_config = learner_pb2.LearnerConfig() - learner_config.num_classes = 2 - result = prediction_ops.gradient_trees_partition_examples( tree_ensemble_handle, [self._dense_float_tensor], [ self._sparse_float_indices1, self._sparse_float_indices2 diff --git a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py index d1e6d98efbc588df3db7a8d8186c1135e09bbe57..58f0d36b0f78eeed6abcec1c4fa696f4ccffa615 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py +++ b/tensorflow/contrib/boosted_trees/python/ops/prediction_ops.py @@ -19,7 +19,6 @@ from __future__ import print_function # pylint: disable=unused-import from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader +from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_partition_examples +from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import gradient_trees_prediction # pylint: enable=unused-import -# pylint: disable=wildcard-import -from tensorflow.contrib.boosted_trees.python.ops.gen_prediction_ops import * -# pylint: enable=wildcard-import diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py index 2d28e0a9f160373b4565d83e9b57de401a052bd6..cebe3474ca9251971c23bde9e82564189c1ee624 100644 --- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py +++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py @@ -56,7 +56,6 @@ PREDICTIONS = "predictions" PARTITION_IDS = "partition_ids" NUM_LAYERS_ATTEMPTED = "num_layers" NUM_TREES_ATTEMPTED = "num_trees" -PREDICTIONS_NO_DROPOUT = "predictions_no_dropout" _FEATURE_NAME_TEMPLATE = "%s_%d" @@ -70,15 +69,13 @@ def _get_column_by_index(tensor, indices): return array_ops.reshape(array_ops.gather(p_flat, i_flat), [shape[0], -1]) -def _make_predictions_dict(stamp, logits, logits_no_dropout, partition_ids, - ensemble_stats): +def _make_predictions_dict(stamp, logits, partition_ids, ensemble_stats): """Returns predictions for the given logits and n_classes. Args: stamp: The ensemble stamp. logits: A rank 2 `Tensor` with shape [batch_size, n_classes - 1]. - logits_no_dropout: A rank 2 `Tensor` with shape [batch_size, n_classes - 1] - that contains predictions when no dropout was applied. + that contains predictions when no dropout was applied. partition_ids: A rank 1 `Tensor` with shape [batch_size]. ensemble_stats: A TreeEnsembleStatsOp result tuple. @@ -88,9 +85,7 @@ def _make_predictions_dict(stamp, logits, logits_no_dropout, partition_ids, result = {} result[ENSEMBLE_STAMP] = stamp result[PREDICTIONS] = logits - result[PREDICTIONS_NO_DROPOUT] = logits_no_dropout result[PARTITION_IDS] = partition_ids - result[NUM_LAYERS_ATTEMPTED] = ensemble_stats.attempted_layers result[NUM_TREES_ATTEMPTED] = ensemble_stats.attempted_trees return result @@ -213,7 +208,7 @@ def extract_features(features, feature_columns): if tensor.dtype == dtypes.float32: if len(tensor.shape) > 1 and tensor.shape[1] > 1: unstacked = array_ops.unstack(tensor, axis=1) - for i in xrange(len(unstacked)): + for i in range(len(unstacked)): dense_float_names.append(_FEATURE_NAME_TEMPLATE % (key, i)) dense_floats.append(array_ops.reshape(unstacked[i], [-1, 1])) else: @@ -348,6 +343,57 @@ class GradientBoostedDecisionTreeModel(object): learner_pb2.LearnerConfig.TREE_PER_CLASS and learner_config.num_classes == 2) + def _predict_and_return_dict(self, ensemble_handle, ensemble_stamp, mode): + """Runs prediction and returns a dictionary of the prediction results. + + Args: + ensemble_handle: ensemble resource handle. + ensemble_stamp: stamp of ensemble resource. + mode: learn.ModeKeys.TRAIN or EVAL or INFER. + + Returns: + a dictionary of prediction results - + ENSEMBLE_STAMP, PREDICTION, PARTITION_IDS, + NUM_LAYER_ATTEMPTED, NUM_TREES_ATTEMPED. + """ + ensemble_stats = training_ops.tree_ensemble_stats(ensemble_handle, + ensemble_stamp) + # We don't need dropout info - we can always restore it based on the + # seed. + apply_dropout, seed = _dropout_params(mode, ensemble_stats) + # Make sure ensemble stats run. This will check that the ensemble has + # the right stamp. + with ops.control_dependencies(ensemble_stats): + predictions, _ = prediction_ops.gradient_trees_prediction( + ensemble_handle, + seed, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + learner_config=self._learner_config_serialized, + apply_dropout=apply_dropout, + apply_averaging=mode != learn.ModeKeys.TRAIN, + use_locking=True, + center_bias=self._center_bias, + reduce_dim=self._reduce_dim) + partition_ids = prediction_ops.gradient_trees_partition_examples( + ensemble_handle, + self._dense_floats, + self._sparse_float_indices, + self._sparse_float_values, + self._sparse_float_shapes, + self._sparse_int_indices, + self._sparse_int_values, + self._sparse_int_shapes, + use_locking=True) + + return _make_predictions_dict(ensemble_stamp, predictions, partition_ids, + ensemble_stats) + def predict(self, mode): """Returns predictions given the features and mode. @@ -360,7 +406,6 @@ class GradientBoostedDecisionTreeModel(object): Raises: ValueError: if features is not valid. """ - apply_averaging = mode != learn.ModeKeys.TRAIN # Use the current ensemble to predict on the current batch of input. # For faster prediction we check if the inputs are on the same device @@ -409,83 +454,13 @@ class GradientBoostedDecisionTreeModel(object): # Once updated, use the local model for prediction. with ops.control_dependencies([refresh_local_ensemble]): - ensemble_stats = training_ops.tree_ensemble_stats( - local_ensemble_handle, ensemble_stamp) - # We don't need dropout info - we can always restore it based on the - # seed. - apply_dropout, seed = _dropout_params(mode, ensemble_stats) - # Make sure ensemble stats run. This will check that the ensemble has - # the right stamp. - with ops.control_dependencies(ensemble_stats): - predictions, predictions_no_dropout, _ = ( - prediction_ops.gradient_trees_prediction( - local_ensemble_handle, - seed, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - learner_config=self._learner_config_serialized, - apply_dropout=apply_dropout, - apply_averaging=apply_averaging, - use_locking=True, - center_bias=self._center_bias, - reduce_dim=self._reduce_dim)) - partition_ids = prediction_ops.gradient_trees_partition_examples( - local_ensemble_handle, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - use_locking=True) - + return self._predict_and_return_dict(local_ensemble_handle, + ensemble_stamp, mode) else: + # Use ensemble_handle directly, if colocated. with ops.device(self._ensemble_handle.device): - ensemble_stats = training_ops.tree_ensemble_stats( - self._ensemble_handle, ensemble_stamp) - # We don't need dropout info - we can always restore it based on the - # seed. - apply_dropout, seed = _dropout_params(mode, ensemble_stats) - # Make sure ensemble stats run. This will check that the ensemble has - # the right stamp. - with ops.control_dependencies(ensemble_stats): - predictions, predictions_no_dropout, _ = ( - prediction_ops.gradient_trees_prediction( - self._ensemble_handle, - seed, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - learner_config=self._learner_config_serialized, - apply_dropout=apply_dropout, - apply_averaging=apply_averaging, - use_locking=True, - center_bias=self._center_bias, - reduce_dim=self._reduce_dim)) - partition_ids = prediction_ops.gradient_trees_partition_examples( - self._ensemble_handle, - self._dense_floats, - self._sparse_float_indices, - self._sparse_float_values, - self._sparse_float_shapes, - self._sparse_int_indices, - self._sparse_int_values, - self._sparse_int_shapes, - use_locking=True) - - return _make_predictions_dict(ensemble_stamp, predictions, - predictions_no_dropout, partition_ids, - ensemble_stats) + return self._predict_and_return_dict(self._ensemble_handle, + ensemble_stamp, mode) def train(self, loss, predictions_dict, labels): """Grows a new tree and adds it to the ensemble. @@ -519,7 +494,6 @@ class GradientBoostedDecisionTreeModel(object): gate_gradients=0, aggregation_method=None)[0] strategy = self._learner_config.multi_class_strategy - num_classes = self._learner_config.num_classes class_id = -1 # Handle different multiclass strategies. @@ -528,7 +502,7 @@ class GradientBoostedDecisionTreeModel(object): gradient_shape = tensor_shape.scalar() hessian_shape = tensor_shape.scalar() - if num_classes == 2: + if self._logits_dimension == 1: # We have only 1 score, gradients is of shape [batch, 1]. hessians = gradients_impl.gradients( gradients, @@ -546,8 +520,8 @@ class GradientBoostedDecisionTreeModel(object): hessians = array_ops.stack(hessian_list, axis=1) # Choose the class for which the tree is built (one vs rest). - class_id = predictions_dict[NUM_TREES_ATTEMPTED] % num_classes - class_id = math_ops.to_int32(class_id) + class_id = math_ops.to_int32( + predictions_dict[NUM_TREES_ATTEMPTED] % self._logits_dimension) # Use class id tensor to get the column with that index from gradients # and hessians. @@ -557,14 +531,15 @@ class GradientBoostedDecisionTreeModel(object): _get_column_by_index(hessians, class_id)) else: # Other multiclass strategies. - gradient_shape = tensor_shape.TensorShape([num_classes]) + gradient_shape = tensor_shape.TensorShape([self._logits_dimension]) if strategy == learner_pb2.LearnerConfig.FULL_HESSIAN: - hessian_shape = tensor_shape.TensorShape(([num_classes, num_classes])) + hessian_shape = tensor_shape.TensorShape( + ([self._logits_dimension, self._logits_dimension])) hessian_list = self._full_hessian(gradients, predictions) else: # Diagonal hessian strategy. - hessian_shape = tensor_shape.TensorShape(([num_classes])) + hessian_shape = tensor_shape.TensorShape(([self._logits_dimension])) hessian_list = self._diagonal_hessian(gradients, predictions) squeezed_gradients = gradients @@ -711,7 +686,7 @@ class GradientBoostedDecisionTreeModel(object): handler_results = batch_ops_utils.run_handler_scheduled_ops( handler_reads, ensemble_stamp, worker_device) per_handler_updates = {} - # Two values per handler. First one is if the the handler is active for the + # Two values per handler. First one is if the handler is active for the # current layer. The second one is if the handler is going to be active # for the next layer. subsampling_type = self._learner_config.WhichOneof("feature_fraction") @@ -803,7 +778,10 @@ class GradientBoostedDecisionTreeModel(object): active_tree, active_layer, dropout_seed, class_id), control_flow_ops.no_op)) - # Calculate the loss to be reported - use the predictions without dropout. + # Calculate the loss to be reported. + # Note, the loss is calculated from the prediction considering dropouts, so + # that the value might look staggering over steps when the dropout ratio is + # high. eval_loss might be referred instead in the aspect of convergence. return control_flow_ops.group(*ensemble_update_ops) def _get_weights(self, hessian_shape, hessians): @@ -826,10 +804,10 @@ class GradientBoostedDecisionTreeModel(object): # compute the full hessian with a single call to gradients, but instead # must compute it row-by-row. gradients_list = array_ops.unstack( - grads, num=self._learner_config.num_classes, axis=1) + grads, num=self._logits_dimension, axis=1) hessian_rows = [] - for row in range(self._learner_config.num_classes): + for row in range(self._logits_dimension): # If current row is i, K is number of classes,each row returns a tensor of # size batch_size x K representing for each example dx_i dx_1, dx_i dx_2 # etc dx_i dx_K @@ -852,7 +830,7 @@ class GradientBoostedDecisionTreeModel(object): diag_hessian_list = [] gradients_list = array_ops.unstack( - grads, num=self._learner_config.num_classes, axis=1) + grads, num=self._logits_dimension, axis=1) for row, row_grads in enumerate(gradients_list): # If current row is i, K is number of classes,each row returns a tensor of @@ -913,7 +891,7 @@ class GradientBoostedDecisionTreeModel(object): hess_sum = math_ops.reduce_sum(hess, 0) # Accumulate gradients and hessians. - partition_ids = math_ops.range(predictions.get_shape()[1]) + partition_ids = math_ops.range(self._logits_dimension) feature_ids = array_ops.zeros_like(partition_ids, dtype=dtypes.int64) add_stats_op = bias_stats_accumulator.add( ensemble_stamp, partition_ids, feature_ids, grads_sum, hess_sum) diff --git a/tensorflow/contrib/boosted_trees/python/utils/losses.py b/tensorflow/contrib/boosted_trees/python/utils/losses.py index 4f128b230180d8e8070f63c369bc7fc2f3d24376..1e8b3ac08a74a94a0e5729e42ace91398a7b5c94 100644 --- a/tensorflow/contrib/boosted_trees/python/utils/losses.py +++ b/tensorflow/contrib/boosted_trees/python/utils/losses.py @@ -101,7 +101,10 @@ def per_example_maxent_loss(labels, weights, logits, num_classes, eps=1e-15): unweighted_loss = array_ops.expand_dims(-math_ops.log(probs_for_real_class), 1) - return unweighted_loss * weights, control_flow_ops.no_op() + if weights is None: + return unweighted_loss, control_flow_ops.no_op() + else: + return unweighted_loss * weights, control_flow_ops.no_op() def per_example_squared_loss(labels, weights, predictions): diff --git a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h index 45c3bbadfc8d6300841cbc256c894e3bb14cb44e..284ad5cdb9abf374650940ade7bb36663d72c0dd 100644 --- a/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h +++ b/tensorflow/contrib/boosted_trees/resources/decision_tree_ensemble_resource.h @@ -44,14 +44,90 @@ class DecisionTreeEnsembleResource : public StampedResource { return *decision_tree_ensemble_; } - boosted_trees::trees::DecisionTreeEnsembleConfig* - mutable_decision_tree_ensemble() { - return decision_tree_ensemble_; + int32 num_trees() const { return decision_tree_ensemble_->trees_size(); } + + bool InitFromSerialized(const string& serialized, const int64 stamp_token) { + CHECK_EQ(stamp(), -1) << "Must Reset before Init."; + if (ParseProtoUnlimited(decision_tree_ensemble_, serialized)) { + set_stamp(stamp_token); + return true; + } + return false; + } + + string SerializeAsString() const { + return decision_tree_ensemble_->SerializeAsString(); + } + + // Increment num_layers_attempted and num_trees_attempted in growing_metadata + // if the tree is finalized. + void IncrementAttempts() { + boosted_trees::trees::GrowingMetadata* const growing_metadata = + decision_tree_ensemble_->mutable_growing_metadata(); + growing_metadata->set_num_layers_attempted( + growing_metadata->num_layers_attempted() + 1); + const int num_trees = decision_tree_ensemble_->trees_size(); + if (num_trees <= 0 || LastTreeMetadata()->is_finalized()) { + growing_metadata->set_num_trees_attempted( + growing_metadata->num_trees_attempted() + 1); + } + } + + boosted_trees::trees::DecisionTreeConfig* AddNewTree(const float weight) { + // Adding a tree as well as a weight and a tree_metadata. + decision_tree_ensemble_->add_tree_weights(weight); + boosted_trees::trees::DecisionTreeMetadata* const metadata = + decision_tree_ensemble_->add_tree_metadata(); + metadata->set_num_layers_grown(1); + return decision_tree_ensemble_->add_trees(); + } + + void RemoveLastTree() { + QCHECK_GT(decision_tree_ensemble_->trees_size(), 0); + decision_tree_ensemble_->mutable_trees()->RemoveLast(); + decision_tree_ensemble_->mutable_tree_weights()->RemoveLast(); + decision_tree_ensemble_->mutable_tree_metadata()->RemoveLast(); + } + + boosted_trees::trees::DecisionTreeConfig* LastTree() { + const int32 tree_size = decision_tree_ensemble_->trees_size(); + QCHECK_GT(tree_size, 0); + return decision_tree_ensemble_->mutable_trees(tree_size - 1); + } + + boosted_trees::trees::DecisionTreeMetadata* LastTreeMetadata() { + const int32 metadata_size = decision_tree_ensemble_->tree_metadata_size(); + QCHECK_GT(metadata_size, 0); + return decision_tree_ensemble_->mutable_tree_metadata(metadata_size - 1); + } + + // Retrieves tree weights and returns as a vector. + std::vector GetTreeWeights() const { + return {decision_tree_ensemble_->tree_weights().begin(), + decision_tree_ensemble_->tree_weights().end()}; + } + + float GetTreeWeight(const int32 index) const { + return decision_tree_ensemble_->tree_weights(index); + } + + // Sets the weight of i'th tree, and increment num_updates in tree_metadata. + void SetTreeWeight(const int32 index, const float weight, + const int32 increment_num_updates) { + QCHECK_GE(index, 0); + QCHECK_LT(index, num_trees()); + decision_tree_ensemble_->set_tree_weights(index, weight); + if (increment_num_updates != 0) { + const int32 num_updates = decision_tree_ensemble_->tree_metadata(index) + .num_tree_weight_updates(); + decision_tree_ensemble_->mutable_tree_metadata(index) + ->set_num_tree_weight_updates(num_updates + increment_num_updates); + } } // Resets the resource and frees the protos in arena. // Caller needs to hold the mutex lock while calling this. - void Reset() { + virtual void Reset() { // Reset stamp. set_stamp(-1); @@ -64,7 +140,7 @@ class DecisionTreeEnsembleResource : public StampedResource { mutex* get_mutex() { return &mu_; } - private: + protected: protobuf::Arena arena_; mutex mu_; boosted_trees::trees::DecisionTreeEnsembleConfig* decision_tree_ensemble_; diff --git a/tensorflow/contrib/cloud/BUILD b/tensorflow/contrib/cloud/BUILD index eec2beddc487d67171ea43b0e46e7c8f7c11a4f3..aa8f5ed12bc6f779e3c1a923b9225ec283189747 100644 --- a/tensorflow/contrib/cloud/BUILD +++ b/tensorflow/contrib/cloud/BUILD @@ -63,11 +63,15 @@ tf_py_test( ":bigquery_reader_ops_op_lib", ":cloud_py", "//tensorflow/contrib/cloud/kernels:bigquery_reader_ops", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:data_flow_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:io_ops", "//tensorflow/python:parsing_ops", + "//tensorflow/python:platform", + "//tensorflow/python:util", ], tags = ["manual"], ) diff --git a/tensorflow/contrib/cloud/kernels/BUILD b/tensorflow/contrib/cloud/kernels/BUILD index 35bab9abfbfc34c5faa9fd1661c0151cce3374bd..56f930a9a8d32c5c3a025163ef56c9562f17d864 100644 --- a/tensorflow/contrib/cloud/kernels/BUILD +++ b/tensorflow/contrib/cloud/kernels/BUILD @@ -23,7 +23,9 @@ load( filegroup( name = "all_files", srcs = glob( - ["**/*"], + include = [ + "**/*", + ], exclude = [ "**/METADATA", "**/OWNERS", @@ -34,9 +36,7 @@ filegroup( tf_kernel_library( name = "bigquery_reader_ops", - srcs = [ - "bigquery_reader_ops.cc", - ], + srcs = ["bigquery_reader_ops.cc"], visibility = ["//visibility:public"], deps = [ ":bigquery_table_accessor", @@ -50,20 +50,16 @@ tf_kernel_library( cc_library( name = "bigquery_table_accessor", - srcs = [ - "bigquery_table_accessor.cc", - ], - hdrs = [ - "bigquery_table_accessor.h", - ], + srcs = ["bigquery_table_accessor.cc"], + hdrs = ["bigquery_table_accessor.h"], copts = tf_copts(), linkstatic = 1, deps = [ ":bigquery_table_partition_proto_cc", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform/cloud:curl_http_request", "//tensorflow/core/platform/cloud:google_auth_provider", - "//tensorflow/core/platform/cloud:http_request", ], alwayslink = 1, ) @@ -87,8 +83,6 @@ tf_cc_test( tf_proto_library( name = "bigquery_table_partition_proto", - srcs = [ - "bigquery_table_partition.proto", - ], + srcs = ["bigquery_table_partition.proto"], cc_api_version = 2, ) diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc index 5e95db55b62acd3b477c6c25845c19522fca87e8..51821f6653550afd2d2e8a49b7337ff8ba0b5489 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.cc @@ -142,7 +142,8 @@ BigQueryTableAccessor::BigQueryTableAccessor( project_id, dataset_id, table_id, timestamp_millis, row_buffer_size, end_point, columns, partition, std::unique_ptr(new GoogleAuthProvider()), - std::unique_ptr(new HttpRequest::Factory())) { + std::unique_ptr( + new CurlHttpRequest::Factory())) { row_buffer_.resize(row_buffer_size); } @@ -392,7 +393,7 @@ Status BigQueryTableAccessor::AppendValueToExample( } string BigQueryTableAccessor::BigQueryTableAccessor::BigQueryUriPrefix() { - HttpRequest request; + CurlHttpRequest request; return strings::StrCat(bigquery_end_point_, "/projects/", request.EscapeString(project_id_), "/datasets/", request.EscapeString(dataset_id_), "/tables/", diff --git a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h index 1cd0482186d74689b29bf577be553668fbd6565d..7d0eee59ae2f47503c4f8994ef356ce0dc336733 100644 --- a/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h +++ b/tensorflow/contrib/cloud/kernels/bigquery_table_accessor.h @@ -23,8 +23,8 @@ limitations under the License. #include "tensorflow/contrib/cloud/kernels/bigquery_table_partition.pb.h" #include "tensorflow/core/example/example.pb.h" #include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/platform/cloud/curl_http_request.h" #include "tensorflow/core/platform/cloud/google_auth_provider.h" -#include "tensorflow/core/platform/cloud/http_request.h" namespace tensorflow { diff --git a/tensorflow/contrib/cluster_resolver/BUILD b/tensorflow/contrib/cluster_resolver/BUILD index 9501c332454238e0c4eb36d25e97f06dde9abed5..15abd2be0385eb776ff4f76484133efb6e34f076 100644 --- a/tensorflow/contrib/cluster_resolver/BUILD +++ b/tensorflow/contrib/cluster_resolver/BUILD @@ -13,7 +13,9 @@ licenses(["notice"]) # Apache 2.0 filegroup( name = "all_files", srcs = glob( - ["**/*"], + include = [ + "**/*", + ], exclude = [ "**/METADATA", "**/OWNERS", @@ -37,9 +39,7 @@ py_library( py_library( name = "cluster_resolver_py", - srcs = [ - "python/training/cluster_resolver.py", - ], + srcs = ["python/training/cluster_resolver.py"], srcs_version = "PY2AND3", deps = [ "//tensorflow/python:training", @@ -48,9 +48,7 @@ py_library( py_library( name = "gce_cluster_resolver_py", - srcs = [ - "python/training/gce_cluster_resolver.py", - ], + srcs = ["python/training/gce_cluster_resolver.py"], srcs_version = "PY2AND3", deps = [ ":cluster_resolver_py", @@ -60,9 +58,7 @@ py_library( py_library( name = "tpu_cluster_resolver_py", - srcs = [ - "python/training/tpu_cluster_resolver.py", - ], + srcs = ["python/training/tpu_cluster_resolver.py"], srcs_version = "PY2AND3", deps = [ ":cluster_resolver_py", @@ -79,6 +75,7 @@ tf_py_test( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", + "//tensorflow/python:training", ], main = "python/training/cluster_resolver_test.py", ) @@ -88,11 +85,13 @@ tf_py_test( size = "small", srcs = ["python/training/gce_cluster_resolver_test.py"], additional_deps = [ + ":cluster_resolver_py", ":gce_cluster_resolver_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", + "//tensorflow/python:training", ], main = "python/training/gce_cluster_resolver_test.py", ) @@ -107,6 +106,7 @@ tf_py_test( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", + "//tensorflow/python:training", ], main = "python/training/tpu_cluster_resolver_test.py", ) diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py index ceb583abe0796ec9748e752f112ce9e368bdd8c0..f0144e9faa26801b6491b242b04fda8905f15306 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver.py @@ -39,7 +39,6 @@ class TPUClusterResolver(ClusterResolver): """ def __init__(self, - api_definition, project, zone, tpu_names, @@ -52,8 +51,6 @@ class TPUClusterResolver(ClusterResolver): for the IP addresses and ports of each Cloud TPU listed. Args: - api_definition: (Alpha only) A copy of the JSON API definitions for - Cloud TPUs. This will be removed once Cloud TPU enters beta. project: Name of the GCP project containing Cloud TPUs zone: Zone where the TPUs are located tpu_names: A list of names of the target Cloud TPUs. @@ -83,14 +80,35 @@ class TPUClusterResolver(ClusterResolver): raise ImportError('googleapiclient must be installed before using the ' 'TPU cluster resolver') - # TODO(frankchn): Remove once Cloud TPU API Definitions are public and - # replace with discovery.build('tpu', 'v1') - self._service = discovery.build_from_document( - api_definition, - credentials=self._credentials) + # TODO(b/67375680): Remove custom URL once TPU APIs are finalized + self._service = discovery.build( + 'tpu', + 'v1', + credentials=self._credentials, + discoveryServiceUrl='https://storage.googleapis.com' + '/tpu-api-definition/v1alpha1.json') else: self._service = service + def get_master(self): + """Get the ClusterSpec grpc master path. + + This returns the grpc path (grpc://1.2.3.4:8470) of first instance in the + ClusterSpec returned by the cluster_spec function. This is suitable for use + for the `master` argument in tf.Session() when you are using one TPU. + + Returns: + string, the grpc path of the first instance in the ClusterSpec. + + Raises: + ValueError: If none of the TPUs specified exists. + """ + job_tasks = self.cluster_spec().job_tasks(self._job_name) + if not job_tasks: + raise ValueError('No TPUs exists with the specified names exist.') + + return 'grpc://' + job_tasks[0] + def cluster_spec(self): """Returns a ClusterSpec object based on the latest TPU information. diff --git a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py index 5bd5cd1a8702840bd3eeb264ff19810fefa1fb62..db7419be06b58e1c5737f69f2c7fd9fee44b9d95 100644 --- a/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py +++ b/tensorflow/contrib/cluster_resolver/python/training/tpu_cluster_resolver_test.py @@ -26,6 +26,28 @@ from tensorflow.python.training import server_lib mock = test.mock +class MockRequestClass(object): + + def __init__(self, name, tpu_map): + self._name = name + self._tpu_map = tpu_map + + def execute(self): + if self._name in self._tpu_map: + return self._tpu_map[self._name] + else: + raise KeyError('Resource %s was not found' % self._name) + + +class MockNodeClass(object): + + def __init__(self, tpu_map): + self._tpu_map = tpu_map + + def get(self, name): + return MockRequestClass(name, self._tpu_map) + + class TPUClusterResolverTest(test.TestCase): def _verifyClusterSpecEquality(self, cluster_spec, expected_proto): @@ -56,11 +78,15 @@ class TPUClusterResolverTest(test.TestCase): if tpu_map is None: tpu_map = {} - def get_side_effect(name): - return tpu_map[name] + mock_locations = mock.MagicMock() + mock_locations.nodes.return_value = MockNodeClass(tpu_map) + + mock_project = mock.MagicMock() + mock_project.locations.return_value = mock_locations mock_client = mock.MagicMock() - mock_client.projects.locations.nodes.get.side_effect = get_side_effect + mock_client.projects.return_value = mock_project + return mock_client def testSimpleSuccessfulRetrieval(self): @@ -109,3 +135,38 @@ class TPUClusterResolverTest(test.TestCase): tasks { key: 1 value: '10.1.2.3:8470' } } """ self._verifyClusterSpecEquality(actual_cluster_spec, expected_proto) + + def testGetMasterMultipleEntries(self): + tpu_map = { + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-1': { + 'ipAddress': '10.1.2.3', + 'port': '8470' + }, + 'projects/test-project/locations/us-central1-c/nodes/test-tpu-2': { + 'ipAddress': '10.4.5.6', + 'port': '8470' + } + } + + tpu_cluster_resolver = TPUClusterResolver( + project='test-project', + zone='us-central1-c', + tpu_names=['test-tpu-2', 'test-tpu-1'], + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + self.assertEqual('grpc://10.4.5.6:8470', tpu_cluster_resolver.get_master()) + + def testGetMasterNoEntries(self): + tpu_map = {} + + tpu_cluster_resolver = TPUClusterResolver( + project='test-project', + zone='us-central1-c', + tpu_names=[], + credentials=None, + service=self.mock_service_client(tpu_map=tpu_map)) + with self.assertRaises(ValueError): + tpu_cluster_resolver.get_master() + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt index c249a2855622581534534a94af9991d12b73f5e9..f6b76d8af699d5497ee6374913052fd21f2c2e85 100644 --- a/tensorflow/contrib/cmake/CMakeLists.txt +++ b/tensorflow/contrib/cmake/CMakeLists.txt @@ -34,6 +34,12 @@ option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF) option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON) option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions") option(tensorflow_ENABLE_SNAPPY_SUPPORT "Enable SNAPPY compression support" ON) +if(HAIKU) + option(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE "Enable PIE support" OFF) +else() + option(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE "Enable PIE support" ON) +endif() + if (NOT WIN32) # Threads: defines CMAKE_THREAD_LIBS_INIT and adds -pthread compile option @@ -58,7 +64,12 @@ set (DOWNLOAD_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/downloads" CACHE PATH "Location where external projects will be downloaded.") mark_as_advanced(DOWNLOAD_LOCATION) -set(CMAKE_POSITION_INDEPENDENT_CODE ON) +if (tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + set(CMAKE_POSITION_INDEPENDENT_CODE ON) +else() + set(CMAKE_POSITION_INDEPENDENT_CODE OFF) +endif() + add_definitions(-DEIGEN_AVOID_STL_ARRAY) if(WIN32) add_definitions(-DNOMINMAX -D_WIN32_WINNT=0x0A00 -DLANG_CXX11 -DCOMPILER_MSVC) @@ -217,6 +228,9 @@ endif() if(UNIX) list(APPEND tensorflow_EXTERNAL_LIBRARIES ${CMAKE_THREAD_LIBS_INIT} ${CMAKE_DL_LIBS}) endif() +if(HAIKU) + list(APPEND tensorflow_EXTERNAL_LIBRARIES network) +endif() if (tensorflow_ENABLE_GPU) if (WIN32) @@ -245,7 +259,7 @@ if (tensorflow_ENABLE_GPU) "#define CUDA_CUDA_CONFIG_H_\n" "#define TF_CUDA_CAPABILITIES CudaVersion(\"3.0\"),CudaVersion(\"3.5\"),CudaVersion(\"5.2\")\n" "#define TF_CUDA_VERSION \"64_80\"\n" - "#define TF_CUDNN_VERSION \"64_5\"\n" + "#define TF_CUDNN_VERSION \"64_6\"\n" "#define TF_CUDA_TOOLKIT_PATH \"${CUDA_TOOLKIT_ROOT_DIR}\"\n" "#endif // CUDA_CUDA_CONFIG_H_\n" ) @@ -264,8 +278,23 @@ if (tensorflow_ENABLE_GPU) include_directories(${tensorflow_source_dir}/third_party/gpus) # add cuda libraries to tensorflow_EXTERNAL_LIBRARIES list(APPEND tensorflow_EXTERNAL_LIBRARIES ${CUDA_LIBRARIES}) - endif() -endif() + + # NOTE(mrry): Update these flags when the version of CUDA or cuDNN used + # in the default build is upgraded. + set(tensorflow_BUILD_INFO_FLAGS --build_config cuda --key_value + msvcp_dll_name=msvcp140.dll + cudart_dll_name=cudart64_80.dll + cuda_version_number=8.0 + nvcuda_dll_name=nvcuda.dll + cudnn_dll_name=cudnn64_6.dll + cudnn_version_number=6) + else(WIN32) + message(FATAL_ERROR "CMake GPU build is currently only supported on Windows.") + endif(WIN32) +else(tensorflow_ENABLE_GPU) + set(tensorflow_BUILD_INFO_FLAGS --build_config cpu --key_value + msvcp_dll_name=msvcp140.dll) +endif(tensorflow_ENABLE_GPU) # Find python executable include(FindPythonInterp) diff --git a/tensorflow/contrib/cmake/external/boringssl.cmake b/tensorflow/contrib/cmake/external/boringssl.cmake index dc27eadaca14361ffeffa6eadf6d4d97524de310..cca8444e2ae9952ea7c69a9392580ead715d363b 100644 --- a/tensorflow/contrib/cmake/external/boringssl.cmake +++ b/tensorflow/contrib/cmake/external/boringssl.cmake @@ -39,8 +39,12 @@ ExternalProject_Add(boringssl # BUILD_IN_SOURCE 1 INSTALL_COMMAND "" CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON ) diff --git a/tensorflow/contrib/cmake/external/cub.cmake b/tensorflow/contrib/cmake/external/cub.cmake index d98579d2077f0a3bc58e6466ee830e53f44f40cb..836889895567f679d9960e29ece1600d1a7a58eb 100644 --- a/tensorflow/contrib/cmake/external/cub.cmake +++ b/tensorflow/contrib/cmake/external/cub.cmake @@ -14,8 +14,8 @@ # ============================================================================== include (ExternalProject) -set(cub_URL http://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.3.zip) -set(cub_HASH SHA256=b7ead9e291d34ffa8074243541c1380d63be63f88de23de8ee548db573b72ebe) +set(cub_URL https://mirror.bazel.build/github.com/NVlabs/cub/archive/1.7.4.zip) +set(cub_HASH SHA256=20a1a39fd97e5da7f40f5f2e7fd73fd2ea59f9dc4bb8a6c5f228aa543e727e31) set(cub_BUILD ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/cub/src/cub) set(cub_ARCHIVE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/cub_archive) diff --git a/tensorflow/contrib/cmake/external/gif.cmake b/tensorflow/contrib/cmake/external/gif.cmake index 5cb719b8787781084335779960887613df90217d..3d53c51fffcec1602a3b5553cdf3b225e3b0ae46 100644 --- a/tensorflow/contrib/cmake/external/gif.cmake +++ b/tensorflow/contrib/cmake/external/gif.cmake @@ -15,7 +15,7 @@ include (ExternalProject) set(gif_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/gif_archive/giflib-5.1.4/) -set(gif_URL http://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz) +set(gif_URL https://mirror.bazel.build/ufpr.dl.sourceforge.net/project/giflib/giflib-5.1.4.tar.gz) set(gif_HASH SHA256=34a7377ba834397db019e8eb122e551a49c98f49df75ec3fcc92b9a794a4f6d1) set(gif_INSTALL ${CMAKE_BINARY_DIR}/gif/install) set(gif_BUILD ${CMAKE_BINARY_DIR}/gif/src/gif) diff --git a/tensorflow/contrib/cmake/external/jpeg.cmake b/tensorflow/contrib/cmake/external/jpeg.cmake index ff17b975b9c67139f90a0778055fd7ea98dd11bf..d9a165e856c588880ebdf996666d70c9e7f53da8 100644 --- a/tensorflow/contrib/cmake/external/jpeg.cmake +++ b/tensorflow/contrib/cmake/external/jpeg.cmake @@ -15,7 +15,7 @@ include (ExternalProject) set(jpeg_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/jpeg_archive) -set(jpeg_URL http://www.ijg.org/files/jpegsrc.v9a.tar.gz) +set(jpeg_URL https://mirror.bazel.build/www.ijg.org/files/jpegsrc.v9a.tar.gz) set(jpeg_HASH SHA256=3a753ea48d917945dd54a2d97de388aa06ca2eb1066cbfdc6652036349fe05a7) set(jpeg_BUILD ${CMAKE_CURRENT_BINARY_DIR}/jpeg/src/jpeg) set(jpeg_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/jpeg/install) diff --git a/tensorflow/contrib/cmake/external/jsoncpp.cmake b/tensorflow/contrib/cmake/external/jsoncpp.cmake index 5127d7e8f79abdda4516eb9f006e243b7438bc65..d2ae4c76e8cd175cdc3ba41fdf4e4009f8237309 100644 --- a/tensorflow/contrib/cmake/external/jsoncpp.cmake +++ b/tensorflow/contrib/cmake/external/jsoncpp.cmake @@ -42,8 +42,12 @@ ExternalProject_Add(jsoncpp BUILD_IN_SOURCE 1 INSTALL_COMMAND "" CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON ) diff --git a/tensorflow/contrib/cmake/external/lmdb.cmake b/tensorflow/contrib/cmake/external/lmdb.cmake index 28ec833babe8f8e600c7c0179dff511ce4d26105..e41384f023ca9fc4cba697917b491af5a9db92bc 100644 --- a/tensorflow/contrib/cmake/external/lmdb.cmake +++ b/tensorflow/contrib/cmake/external/lmdb.cmake @@ -15,7 +15,7 @@ include (ExternalProject) set(lmdb_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/lmdb) -set(lmdb_URL http://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz) +set(lmdb_URL https://mirror.bazel.build/github.com/LMDB/lmdb/archive/LMDB_0.9.19.tar.gz) set(lmdb_HASH SHA256=108532fb94c6f227558d45be3f3347b52539f0f58290a7bb31ec06c462d05326) set(lmdb_BUILD ${CMAKE_BINARY_DIR}/lmdb/src/lmdb) set(lmdb_INSTALL ${CMAKE_BINARY_DIR}/lmdb/install) @@ -29,10 +29,14 @@ ExternalProject_Add(lmdb INSTALL_DIR ${lmdb_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${lmdb_INSTALL} - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON ) if(WIN32) diff --git a/tensorflow/contrib/cmake/external/png.cmake b/tensorflow/contrib/cmake/external/png.cmake index 2b2bd47d1c95ca886469c525191c27f22d416c29..aad6618f52f909096fd2388e867ef3a965d033cb 100644 --- a/tensorflow/contrib/cmake/external/png.cmake +++ b/tensorflow/contrib/cmake/external/png.cmake @@ -41,10 +41,14 @@ ExternalProject_Add(png INSTALL_DIR ${png_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DCMAKE_INSTALL_PREFIX:STRING=${png_INSTALL} - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DZLIB_ROOT:STRING=${ZLIB_INSTALL} ) diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake index d600d8c3c0d30ec517d0abc4bac94c588b5268d4..b53857a47bfbf797af02fe7f69474263119161cd 100644 --- a/tensorflow/contrib/cmake/external/protobuf.cmake +++ b/tensorflow/contrib/cmake/external/protobuf.cmake @@ -15,8 +15,8 @@ include (ExternalProject) set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src) -set(PROTOBUF_URL https://github.com/mrry/protobuf.git) # Includes MSVC fix. -set(PROTOBUF_TAG 1d2c7b6c7376f396c8c7dd9b6afd2d4f83f3cb05) +set(PROTOBUF_URL https://github.com/google/protobuf.git) +set(PROTOBUF_TAG b04e5cba356212e4e8c66c61bbe0c3a20537c5b9) if(WIN32) set(protobuf_STATIC_LIBRARIES @@ -44,8 +44,12 @@ ExternalProject_Add(protobuf ${PROTOBUF_ADDITIONAL_CMAKE_OPTIONS} INSTALL_COMMAND "" CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DZLIB_ROOT:STRING=${ZLIB_INSTALL} ) diff --git a/tensorflow/contrib/cmake/external/re2.cmake b/tensorflow/contrib/cmake/external/re2.cmake index cb4ec9c2de3388ef918c75d842dab6e1f4ffee9b..b56f4b089813247f3ab1c751538ba4b05cacb5b6 100644 --- a/tensorflow/contrib/cmake/external/re2.cmake +++ b/tensorflow/contrib/cmake/external/re2.cmake @@ -38,7 +38,11 @@ ExternalProject_Add(re2 BUILD_IN_SOURCE 1 DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_INSTALL_PREFIX:STRING=${re2_INSTALL} - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -) \ No newline at end of file +) diff --git a/tensorflow/contrib/cmake/external/snappy.cmake b/tensorflow/contrib/cmake/external/snappy.cmake index a35d8654fb6fa5f5b5d230ffbc061d050e5aeb5e..926c271fd9ea6e2a30251aa408bd49859ae95070 100644 --- a/tensorflow/contrib/cmake/external/snappy.cmake +++ b/tensorflow/contrib/cmake/external/snappy.cmake @@ -40,11 +40,15 @@ ExternalProject_Add(snappy LOG_CONFIGURE ON LOG_BUILD ON CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF -DSNAPPY_BUILD_TESTS:BOOL=OFF - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON ) # actually enables snappy in the source code -add_definitions(-DSNAPPY) \ No newline at end of file +add_definitions(-DTF_USE_SNAPPY) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/external/sqlite.cmake b/tensorflow/contrib/cmake/external/sqlite.cmake index 6fa3a576998acef529942ccfab3a6a544795d712..6d06193824b32557c1d2195c940ff9c698be1bdf 100644 --- a/tensorflow/contrib/cmake/external/sqlite.cmake +++ b/tensorflow/contrib/cmake/external/sqlite.cmake @@ -53,9 +53,13 @@ else() INSTALL_DIR ${sqlite_INSTALL} DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_INSTALL_PREFIX:STRING=${sqlite_INSTALL} ) diff --git a/tensorflow/contrib/cmake/external/zlib.cmake b/tensorflow/contrib/cmake/external/zlib.cmake index c8af611e1eaefdf135551940a66985a4d50b26ed..f10f84336e8b1c0a2c7de7ea1f8b8af7c21f8b51 100644 --- a/tensorflow/contrib/cmake/external/zlib.cmake +++ b/tensorflow/contrib/cmake/external/zlib.cmake @@ -42,9 +42,13 @@ ExternalProject_Add(zlib BUILD_IN_SOURCE 1 DOWNLOAD_DIR "${DOWNLOAD_LOCATION}" CMAKE_CACHE_ARGS + if(tensorflow_ENABLE_POSITION_INDEPENDENT_CODE) + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON + else() + -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=OFF + endif() -DCMAKE_BUILD_TYPE:STRING=Release -DCMAKE_INSTALL_PREFIX:STRING=${ZLIB_INSTALL} - -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON ) # put zlib includes in the directory where they are expected diff --git a/tensorflow/contrib/cmake/tf_c.cmake b/tensorflow/contrib/cmake/tf_c.cmake index c5a101812710f0e6eb0aa8816acd2b395e7f7472..f3882e8cf76c6dad31371fc340de959c05411a2f 100644 --- a/tensorflow/contrib/cmake/tf_c.cmake +++ b/tensorflow/contrib/cmake/tf_c.cmake @@ -21,6 +21,8 @@ set(tf_c_srcs "${tensorflow_source_dir}/tensorflow/c/c_api_function.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.cc" "${tensorflow_source_dir}/tensorflow/c/eager/c_api.h" + "${tensorflow_source_dir}/tensorflow/c/eager/tape.cc" + "${tensorflow_source_dir}/tensorflow/c/eager/tape.h" "${tensorflow_source_dir}/tensorflow/c/eager/runtime.cc" "${tensorflow_source_dir}/tensorflow/c/eager/runtime.h" "${tensorflow_source_dir}/tensorflow/c/checkpoint_reader.cc" diff --git a/tensorflow/contrib/cmake/tf_cc_ops.cmake b/tensorflow/contrib/cmake/tf_cc_ops.cmake index 6632433087c608a65d9425e5a1efdfccc95af339..a5f5ae5478f3ca82f428d494f2822d0c69064b98 100644 --- a/tensorflow/contrib/cmake/tf_cc_ops.cmake +++ b/tensorflow/contrib/cmake/tf_cc_ops.cmake @@ -135,6 +135,8 @@ set(tf_cc_srcs "${tensorflow_source_dir}/tensorflow/cc/framework/gradient_checker.cc" "${tensorflow_source_dir}/tensorflow/cc/framework/gradients.h" "${tensorflow_source_dir}/tensorflow/cc/framework/gradients.cc" + "${tensorflow_source_dir}/tensorflow/cc/framework/while_gradients.h" + "${tensorflow_source_dir}/tensorflow/cc/framework/while_gradients.cc" ) file(GLOB_RECURSE tf_cc_test_srcs diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake index 1b64a52ecef062f9b7ef28c2b427e95b98279d08..c3dc8531bb9f0164f06841d9715f227202fdb7c9 100644 --- a/tensorflow/contrib/cmake/tf_core_framework.cmake +++ b/tensorflow/contrib/cmake/tf_core_framework.cmake @@ -191,6 +191,10 @@ file(GLOB_RECURSE tf_core_lib_srcs "${tensorflow_source_dir}/tensorflow/core/lib/*.h" "${tensorflow_source_dir}/tensorflow/core/lib/*.cc" "${tensorflow_source_dir}/tensorflow/core/public/*.h" + # TODO(@jart): Move StatusOr into core. + "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.cc" + "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor.h" + "${tensorflow_source_dir}/tensorflow/compiler/xla/statusor_internals.h" ) file(GLOB tf_core_platform_srcs diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake index bb0d90213ad7daa6cec9879b333d74c91a0dc464..f978c8ccd5a454ca4a89de0ab5d757b566295c60 100644 --- a/tensorflow/contrib/cmake/tf_core_kernels.cmake +++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake @@ -33,6 +33,8 @@ else(tensorflow_BUILD_ALL_KERNELS) "${tensorflow_source_dir}/tensorflow/core/kernels/matmul_op.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/no_op.h" "${tensorflow_source_dir}/tensorflow/core/kernels/no_op.cc" + "${tensorflow_source_dir}/tensorflow/core/kernels/ops_util.h" + "${tensorflow_source_dir}/tensorflow/core/kernels/ops_util.cc" "${tensorflow_source_dir}/tensorflow/core/kernels/sendrecv_ops.h" "${tensorflow_source_dir}/tensorflow/core/kernels/sendrecv_ops.cc" ) @@ -40,7 +42,6 @@ endif(tensorflow_BUILD_ALL_KERNELS) if(tensorflow_BUILD_CONTRIB_KERNELS) set(tf_contrib_kernels_srcs - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/kernels/ensemble_optimizer_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/kernels/model_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/kernels/prediction_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/kernels/quantile_ops.cc" @@ -60,13 +61,16 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/learner/stochastic/handlers/sparse-quantized-feature-column-handler.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/models/multiple_additive_trees.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/lib/trees/decision_tree.cc" - "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/ensemble_optimizer_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/model_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/training_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/kernels/prefetching_kernels.cc" + "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/clustering_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/masked_matmul_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/factorization/kernels/wals_solver_ops.cc" @@ -76,6 +80,13 @@ if(tensorflow_BUILD_CONTRIB_KERNELS) #"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/encode_audio_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/zero_initializer_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/bipartite_match_op.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/image_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/single_image_random_dot_stereograms_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/ops/distort_image_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/ops/image_ops.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc" "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc" "${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc" @@ -169,6 +180,7 @@ endif(WIN32) file(GLOB_RECURSE tf_core_gpu_kernels_srcs "${tensorflow_source_dir}/tensorflow/core/kernels/*.cu.cc" "${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/zero_initializer_op_gpu.cu.cc" + "${tensorflow_source_dir}/tensorflow/contrib/image/kernels/*.cu.cc" "${tensorflow_source_dir}/tensorflow/contrib/rnn/kernels/*.cu.cc" "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/kernels/*.cu.cc" ) diff --git a/tensorflow/contrib/cmake/tf_core_ops.cmake b/tensorflow/contrib/cmake/tf_core_ops.cmake index f27b2aed36ffaa57fa12cc0b47e910bc224a5396..4a61ed7a3548b1992ddc71acb8a7761e252296ea 100644 --- a/tensorflow/contrib/cmake/tf_core_ops.cmake +++ b/tensorflow/contrib/cmake/tf_core_ops.cmake @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== set(tf_op_lib_names + "audio_ops" "array_ops" "bitwise_ops" "candidate_sampling_ops" @@ -43,6 +44,7 @@ set(tf_op_lib_names "state_ops" "stateless_random_ops" "string_ops" + "summary_ops" "training_ops" ) @@ -77,14 +79,16 @@ GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_split_handler "${tensorflow_source_dir GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_training "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/training_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_prediction "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/prediction_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_quantiles "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/quantile_ops.cc") -GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_ensemble_optimzier "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/ensemble_optimizer_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(boosted_trees_stats_accumulator "${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(cudnn_rnn "${tensorflow_source_dir}/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(data_prefetching "${tensorflow_source_dir}/tensorflow/contrib/data/ops/prefetching_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_clustering "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/clustering_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(factorization_factorization "${tensorflow_source_dir}/tensorflow/contrib/factorization/ops/factorization_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(framework_variable "${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(input_pipeline "${tensorflow_source_dir}/tensorflow/contrib/input_pipeline/ops/input_pipeline_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(image "${tensorflow_source_dir}/tensorflow/contrib/image/ops/image_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(image_distort_image "${tensorflow_source_dir}/tensorflow/contrib/image/ops/distort_image_ops.cc") +GENERATE_CONTRIB_OP_LIBRARY(image_sirds "${tensorflow_source_dir}/tensorflow/contrib/image/ops/single_image_random_dot_stereograms_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(layers_sparse_feature_cross "${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc") GENERATE_CONTRIB_OP_LIBRARY(memory_stats "${tensorflow_source_dir}/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc") GENERATE_CONTRIB_OP_LIBRARY(nccl "${tensorflow_source_dir}/tensorflow/contrib/nccl/ops/nccl_ops.cc") diff --git a/tensorflow/contrib/cmake/tf_label_image_example.cmake b/tensorflow/contrib/cmake/tf_label_image_example.cmake index 0d3a4699ebb102257e8a4a816652c90ffff42d92..7f2f60b0897f62d335416f4fcffd91c1e629cf28 100644 --- a/tensorflow/contrib/cmake/tf_label_image_example.cmake +++ b/tensorflow/contrib/cmake/tf_label_image_example.cmake @@ -34,3 +34,8 @@ target_link_libraries(tf_label_image_example PUBLIC ${tf_core_gpu_kernels_lib} ${tensorflow_EXTERNAL_LIBRARIES} ) + +install(TARGETS tf_label_image_example + RUNTIME DESTINATION bin + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib) \ No newline at end of file diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 60974aca75ad262e80dc5e2c3d498ebcf447cd06..4b60460cb22b3937065b9cb7f71061019d9f0a4e 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -224,6 +224,7 @@ add_python_module("tensorflow/python/grappler") add_python_module("tensorflow/python/keras") add_python_module("tensorflow/python/keras/activations") add_python_module("tensorflow/python/keras/applications") +add_python_module("tensorflow/python/keras/applications/inception_resnet_v2") add_python_module("tensorflow/python/keras/applications/inception_v3") add_python_module("tensorflow/python/keras/applications/mobilenet") add_python_module("tensorflow/python/keras/applications/resnet50") @@ -266,12 +267,14 @@ add_python_module("tensorflow/python/keras/_impl/keras/utils") add_python_module("tensorflow/python/keras/_impl/keras/wrappers") add_python_module("tensorflow/python/kernel_tests") add_python_module("tensorflow/python/kernel_tests/distributions") +add_python_module("tensorflow/python/kernel_tests/linalg") add_python_module("tensorflow/python/layers") add_python_module("tensorflow/python/lib") add_python_module("tensorflow/python/lib/core") add_python_module("tensorflow/python/lib/io") add_python_module("tensorflow/python/ops") add_python_module("tensorflow/python/ops/distributions") +add_python_module("tensorflow/python/ops/linalg") add_python_module("tensorflow/python/ops/losses") add_python_module("tensorflow/python/platform") add_python_module("tensorflow/python/platform/default") @@ -332,6 +335,7 @@ add_python_module("tensorflow/contrib/cudnn_rnn/kernels") add_python_module("tensorflow/contrib/cudnn_rnn/ops") add_python_module("tensorflow/contrib/cudnn_rnn/python") add_python_module("tensorflow/contrib/cudnn_rnn/python/kernel_tests") +add_python_module("tensorflow/contrib/cudnn_rnn/python/layers") add_python_module("tensorflow/contrib/cudnn_rnn/python/ops") add_python_module("tensorflow/contrib/data") add_python_module("tensorflow/contrib/data/python") @@ -345,6 +349,8 @@ add_python_module("tensorflow/contrib/distributions/python") add_python_module("tensorflow/contrib/distributions/python/kernel_tests") add_python_module("tensorflow/contrib/distributions/python/ops") add_python_module("tensorflow/contrib/distributions/python/ops/bijectors") +add_python_module("tensorflow/contrib/eager") +add_python_module("tensorflow/contrib/eager/python") add_python_module("tensorflow/contrib/estimator") add_python_module("tensorflow/contrib/estimator/python") add_python_module("tensorflow/contrib/estimator/python/estimator") @@ -370,6 +376,8 @@ add_python_module("tensorflow/contrib/gan/python/eval") add_python_module("tensorflow/contrib/gan/python/eval/python") add_python_module("tensorflow/contrib/gan/python/features") add_python_module("tensorflow/contrib/gan/python/features/python") +add_python_module("tensorflow/contrib/gan/python/estimator") +add_python_module("tensorflow/contrib/gan/python/estimator/python") add_python_module("tensorflow/contrib/gan/python/losses") add_python_module("tensorflow/contrib/gan/python/losses/python") add_python_module("tensorflow/contrib/graph_editor") @@ -448,6 +456,10 @@ add_python_module("tensorflow/contrib/keras/python/keras/wrappers") add_python_module("tensorflow/contrib/kernel_methods") add_python_module("tensorflow/contrib/kernel_methods/python") add_python_module("tensorflow/contrib/kernel_methods/python/mappers") +add_python_module("tensorflow/contrib/kfac") +add_python_module("tensorflow/contrib/kfac/examples") +add_python_module("tensorflow/contrib/kfac/python") +add_python_module("tensorflow/contrib/kfac/python/ops") add_python_module("tensorflow/contrib/labeled_tensor") add_python_module("tensorflow/contrib/labeled_tensor/python") add_python_module("tensorflow/contrib/labeled_tensor/python/ops") @@ -491,6 +503,7 @@ add_python_module("tensorflow/contrib/lookup") add_python_module("tensorflow/contrib/losses") add_python_module("tensorflow/contrib/losses/python") add_python_module("tensorflow/contrib/losses/python/losses") +add_python_module("tensorflow/contrib/losses/python/metric_learning") add_python_module("tensorflow/contrib/makefile") add_python_module("tensorflow/contrib/makefile/test") add_python_module("tensorflow/contrib/memory_stats") @@ -507,6 +520,11 @@ add_python_module("tensorflow/contrib/metrics/python") add_python_module("tensorflow/contrib/metrics/python/kernel_tests") add_python_module("tensorflow/contrib/metrics/python/metrics") add_python_module("tensorflow/contrib/metrics/python/ops") +add_python_module("tensorflow/contrib/model_pruning") +add_python_module("tensorflow/contrib/model_pruning/examples") +add_python_module("tensorflow/contrib/model_pruning/examples/cifar10") +add_python_module("tensorflow/contrib/model_pruning/python") +add_python_module("tensorflow/contrib/model_pruning/python/layers") add_python_module("tensorflow/contrib/ndlstm") add_python_module("tensorflow/contrib/ndlstm/python") add_python_module("tensorflow/contrib/nn") @@ -532,6 +550,8 @@ add_python_module("tensorflow/contrib/pi_examples/label_image/data") add_python_module("tensorflow/contrib/predictor") add_python_module("tensorflow/contrib/quantization") add_python_module("tensorflow/contrib/quantization/python") +add_python_module("tensorflow/contrib/quantize") +add_python_module("tensorflow/contrib/quantize/python") add_python_module("tensorflow/contrib/remote_fused_graph/pylib") add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python") add_python_module("tensorflow/contrib/remote_fused_graph/pylib/python/ops") @@ -629,16 +649,12 @@ add_python_module("tensorflow/contrib/reduce_slice_ops/ops") add_python_module("tensorflow/contrib/reduce_slice_ops/python") add_python_module("tensorflow/contrib/reduce_slice_ops/python/kernel_tests") add_python_module("tensorflow/contrib/reduce_slice_ops/python/ops") +add_python_module("tensorflow/contrib/summary") # Generate the tensorflow.python.platform.build_info module. set(BUILD_INFO_PY "${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/platform/build_info.py") -if(tensorflow_ENABLE_GPU) - set(BUILD_CONFIG_STRING "cuda") -else(tensorflow_ENABLE_GPU) - set(BUILD_CONFIG_STRING "cpu") -endif(tensorflow_ENABLE_GPU) add_custom_command(TARGET tf_python_copy_scripts_to_destination PRE_BUILD - COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/build_info/gen_build_info.py --build_config ${BUILD_CONFIG_STRING} --raw_generate ${BUILD_INFO_PY}) + COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/build_info/gen_build_info.py --raw_generate ${BUILD_INFO_PY} ${tensorflow_BUILD_INFO_FLAGS}) ######################################################## @@ -713,6 +729,7 @@ function(GENERATE_PYTHON_OP_LIB tf_python_op_lib_name) ${GENERATE_PYTHON_OP_LIB_DESTINATION} PARENT_SCOPE) endfunction() +GENERATE_PYTHON_OP_LIB("audio_ops") GENERATE_PYTHON_OP_LIB("array_ops") GENERATE_PYTHON_OP_LIB("bitwise_ops") GENERATE_PYTHON_OP_LIB("math_ops") @@ -756,12 +773,12 @@ GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_prediction_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_prediction_ops.py) GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_quantiles_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_quantile_ops.py) -GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_ensemble_optimzier_ops" - DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_ensemble_optimizer_ops.py) GENERATE_PYTHON_OP_LIB("contrib_boosted_trees_stats_accumulator_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/boosted_trees/python/ops/gen_stats_accumulator_ops.py) GENERATE_PYTHON_OP_LIB("contrib_cudnn_rnn_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/cudnn_rnn/ops/gen_cudnn_rnn_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_data_prefetching_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/data/python/ops/gen_prefetching_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_clustering_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/factorization/python/ops/gen_clustering_ops.py) GENERATE_PYTHON_OP_LIB("contrib_factorization_factorization_ops" @@ -772,6 +789,10 @@ GENERATE_PYTHON_OP_LIB("contrib_input_pipeline_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/input_pipeline/ops/gen_input_pipeline_ops.py) GENERATE_PYTHON_OP_LIB("contrib_image_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/image/ops/gen_image_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_image_distort_image_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/image/ops/gen_distort_image_ops.py) +GENERATE_PYTHON_OP_LIB("contrib_image_sirds_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/image/ops/gen_single_image_random_dot_stereograms_ops.py) GENERATE_PYTHON_OP_LIB("contrib_layers_sparse_feature_cross_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/layers/ops/gen_sparse_feature_cross_op.py) GENERATE_PYTHON_OP_LIB("contrib_memory_stats_ops" @@ -804,6 +825,8 @@ GENERATE_PYTHON_OP_LIB("stateless_random_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/stateless/gen_stateless_random_ops.py) GENERATE_PYTHON_OP_LIB("debug_ops" DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/python/debug/ops/gen_debug_ops.py) +GENERATE_PYTHON_OP_LIB("summary_ops" + DESTINATION ${CMAKE_CURRENT_BINARY_DIR}/tf_python/tensorflow/contrib/summary/gen_summary_ops.py) add_custom_target(tf_python_ops SOURCES ${tf_python_ops_generated_files} ${PYTHON_PROTO_GENFILES}) add_dependencies(tf_python_ops tf_python_op_gen_main) @@ -838,6 +861,7 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.h" "${tensorflow_source_dir}/tensorflow/core/profiler/internal/print_model_analysis.cc" "${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe.h" + "${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tensor.cc" "${tensorflow_source_dir}/tensorflow/python/eager/pywrap_tfe_src.cc" "${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.h" "${tensorflow_source_dir}/tensorflow/python/client/tf_session_helper.cc" @@ -865,6 +889,8 @@ set (pywrap_tensorflow_internal_src "${tensorflow_source_dir}/tensorflow/python/lib/io/py_record_writer.cc" "${tensorflow_source_dir}/tensorflow/python/util/kernel_registry.h" "${tensorflow_source_dir}/tensorflow/python/util/kernel_registry.cc" + "${tensorflow_source_dir}/tensorflow/python/util/util.h" + "${tensorflow_source_dir}/tensorflow/python/util/util.cc" "${tensorflow_source_dir}/tensorflow/cc/framework/ops.cc" "${tensorflow_source_dir}/tensorflow/cc/framework/scope.cc" "${CMAKE_CURRENT_BINARY_DIR}/pywrap_tensorflow_internal.cc" @@ -1192,4 +1218,3 @@ else() WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tf_python) endif(${tensorflow_ENABLE_GPU}) endif(${tensorflow_TF_NIGHTLY}) - diff --git a/tensorflow/contrib/cmake/tf_shared_lib.cmake b/tensorflow/contrib/cmake/tf_shared_lib.cmake index 9385ac52e903e1f0f2436066f573af5359c46770..5b685c0a39bb4413a2ade6c9bb77fb5e5e313c66 100644 --- a/tensorflow/contrib/cmake/tf_shared_lib.cmake +++ b/tensorflow/contrib/cmake/tf_shared_lib.cmake @@ -27,6 +27,7 @@ if(WIN32) $ $ $ + $ $ $ $ @@ -63,6 +64,7 @@ add_library(tensorflow SHARED $ $ $ + $ $ $ $ @@ -92,3 +94,46 @@ endif() if(WIN32) add_dependencies(tensorflow tensorflow_static) endif(WIN32) + +install(TARGETS tensorflow + RUNTIME DESTINATION bin + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib) + +# install necessary headers +# tensorflow headers +install(DIRECTORY ${tensorflow_source_dir}/tensorflow/cc/ + DESTINATION include/tensorflow/cc + FILES_MATCHING PATTERN "*.h") +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tensorflow/cc/ + DESTINATION include/tensorflow/cc + FILES_MATCHING PATTERN "*.h") +install(DIRECTORY ${tensorflow_source_dir}/tensorflow/core/ + DESTINATION include/tensorflow/core + FILES_MATCHING PATTERN "*.h") +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/tensorflow/core/ + DESTINATION include/tensorflow/core + FILES_MATCHING PATTERN "*.h") +install(DIRECTORY ${tensorflow_source_dir}/tensorflow/stream_executor/ + DESTINATION include/tensorflow/stream_executor + FILES_MATCHING PATTERN "*.h") +# google protobuf headers +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src/google/ + DESTINATION include/google + FILES_MATCHING PATTERN "*.h") +# nsync headers +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/ + DESTINATION include/external/nsync + FILES_MATCHING PATTERN "*.h") +# Eigen directory +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/Eigen/ + DESTINATION include/Eigen) +# external directory +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/external/eigen_archive/ + DESTINATION include/external/eigen_archive) +# third_party eigen directory +install(DIRECTORY ${tensorflow_source_dir}/third_party/eigen3/ + DESTINATION include/third_party/eigen3) +# unsupported Eigen directory +install(DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/eigen/src/eigen/unsupported/Eigen/ + DESTINATION include/unsupported/Eigen) diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index d836428d9e8885dee83bd195d5686a220ea66e3f..5d6ba9ca8d85e9a2d19b7f3e488822a8f21c6821 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -152,6 +152,7 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/python/training/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/factorization/*_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/image/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/integration_test.py" "${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/python/kernel_tests/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/seq2seq/python/kernel_tests/*_test.py" @@ -178,28 +179,21 @@ if (tensorflow_BUILD_PYTHON_TESTS) # exclude the ones we don't want set(tf_test_src_py_exclude - # Python source line inspection tests are flaky on Windows (b/36375074). - "${tensorflow_source_dir}/tensorflow/python/debug/cli/analyzer_cli_test.py" - "${tensorflow_source_dir}/tensorflow/python/debug/cli/profile_analyzer_cli_test.py" - # Windows does not have the curses library and uses readline. - "${tensorflow_source_dir}/tensorflow/python/debug/cli/curses_ui_test.py" - # TFDBG grpc:// mode is not yet available on Windows. - "${tensorflow_source_dir}/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py" - "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" - # generally not working + # Not a test. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/__init__.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/resource_variable_ops_test.py" + # Flaky because of port collisions. + "${tensorflow_source_dir}/tensorflow/python/training/localhost_cluster_performance_test.py" + # generally not working "${tensorflow_source_dir}/tensorflow/python/profiler/pprof_profiler_test.py" # flaky test "${tensorflow_source_dir}/tensorflow/python/profiler/internal/run_metadata_test.py" + # Fails because uses data dependencies with bazel "${tensorflow_source_dir}/tensorflow/python/saved_model/saved_model_test.py" # requires scipy "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler_test.py" - # flaky tests + # Takes very long to run without sharding (defined in bazel build file). "${tensorflow_source_dir}/tensorflow/python/kernel_tests/cwise_ops_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/tfprof/python/tools/tfprof/internal/run_metadata_test.py" # Loading resources in contrib doesn't seem to work on Windows "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/client/random_forest_test.py" "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py" @@ -212,41 +206,65 @@ if (tensorflow_BUILD_PYTHON_TESTS) if (WIN32) set(tf_test_src_py_exclude ${tf_test_src_py_exclude} - # generally excluded - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/__init__.py" - # TODO: failing tests. # Nothing critical in here but should get this list down to [] # The failing list is grouped by failure source + # Python source line inspection tests are flaky on Windows (b/36375074). + "${tensorflow_source_dir}/tensorflow/python/debug/cli/analyzer_cli_test.py" + "${tensorflow_source_dir}/tensorflow/python/debug/cli/profile_analyzer_cli_test.py" + # Windows does not have the curses library and uses readline. + "${tensorflow_source_dir}/tensorflow/python/debug/cli/curses_ui_test.py" + # TFDBG grpc:// mode is not yet available on Windows. + "${tensorflow_source_dir}/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py" + "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" # stl on windows handles overflows different "${tensorflow_source_dir}/tensorflow/python/kernel_tests/as_string_op_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/cast_op_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/string_to_number_op_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/clip_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/tensor_array_ops_test.py" # Needs portpicker. - # Matrix_set_diag failing on GPU on windows. - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/cholesky_op_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/diag_op_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/linalg_ops_test.py" - # misc - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/variable_scope_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/reshape_op_test.py" - "${tensorflow_source_dir}/tensorflow/python/training/evaluation_test.py" - "${tensorflow_source_dir}/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py" # Depends on gemmlowp -> pthread. + # Numerical issues, calculations off. + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/concat_op_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/wals_test.py" + # Float division by zero + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py" + # Flaky, for unknown reasons. Cannot reproduce in terminal. Revisit once we can get stack traces. + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/batch_matmul_op_test.py" + # Flaky because of local cluster creation. + "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" + "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_grpc_test.py" + "${tensorflow_source_dir}tensorflow/python/training/localhost_cluster_performance_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_cluster_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" + # Type error in testRemoteIteratorUsingRemoteCallOpDirectSessionGPUCPU. + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/self_adjoint_eig_op_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py" + # IteratorGetMax OutOfRangeError + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py" + # Depends on gemmlowp -> pthread + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/neon_depthwise_conv_op_test.py" # int32/int64 mixup + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/cast_op_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/variable_scope_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/functional_ops_test.py" "${tensorflow_source_dir}/tensorflow/python/kernel_tests/py_func_test.py" + # Windows file management related issues. + "${tensorflow_source_dir}/tensorflow/python/training/evaluation_test.py" # training tests "${tensorflow_source_dir}/tensorflow/python/training/basic_session_run_hooks_test.py" # Needs tf.contrib fix. - "${tensorflow_source_dir}/tensorflow/python/training/localhost_cluster_performance_test.py" # Needs portpicker. "${tensorflow_source_dir}/tensorflow/python/training/quantize_training_test.py" # Needs quantization ops to be included in windows. "${tensorflow_source_dir}/tensorflow/python/training/supervisor_test.py" # Flaky I/O error on rename. - "${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" # Needs portpicker. "${tensorflow_source_dir}/tensorflow/python/training/server_lib_test.py" # Test occasionally deadlocks. - + "${tensorflow_source_dir}/tensorflow/python/debug/lib/session_debug_multi_gpu_test.py" # Fails on multiple GPUs. + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/concat_op_test.py" # numerical issues + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/linalg_grad_test.py" # cudaSolver handle creation fails. "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops + # Dataset tests + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on windows + "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows. + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/iterator_ops_cluster_test.py" # Broken tensorboard test due to cmake issues. - "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py" # Needs portpicker "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py" # b/65430561 # tensor_forest tests (also note that we exclude the hybrid tests for now) @@ -255,8 +273,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/kernel_tests/scatter_add_ndim_op_test.py" # Bad placement. "${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/topn_test.py" # Results inaccurate "${tensorflow_source_dir}/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py" # No libcurl support - # Newly running on Windows since TensorBoard backend move. Fail on Windows and need debug. - "${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows. # Dask.Dataframe bugs on Window Build "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/tests/dataframe/tensorflow_dataframe_test.py" "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py" @@ -265,39 +281,19 @@ if (tensorflow_BUILD_PYTHON_TESTS) # Need extra build "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/conditional_distribution_test.py" "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/estimator_test.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/depthtospace_op_test.py" # QuantizeV2 + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/spacetodepth_op_test.py" # QuantizeV2 # Windows Path "${tensorflow_source_dir}/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py" #TODO: Fix path - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/models_test.py" - # Related to Windows Multiprocessing https://github.com/fchollet/keras/issues/5071 - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/engine/training_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/utils/data_utils_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/callbacks_test.py" - # Scipy needed - "${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/image_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/logistic_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/relaxed_bernoulli_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py" + "${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/kmeans_test.py" "${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py" - # Failing with TF 1.3 (TODO) - "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/estimator_test.py" + # Numpy upgrade needed? "${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py" # Test should only be run manually "${tensorflow_source_dir}/tensorflow/python/kernel_tests/reduction_ops_test_big.py" + "${tensorflow_source_dir}/tensorflow/python/kernel_tests/svd_op_test.py" ) endif() list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude}) diff --git a/tensorflow/contrib/cmake/tf_tools.cmake b/tensorflow/contrib/cmake/tf_tools.cmake index 6ef95989630a39eaedaddda68f7da709e7d9ab03..dc1c3b757b5d261c1ae6eaa53651483ff949949c 100644 --- a/tensorflow/contrib/cmake/tf_tools.cmake +++ b/tensorflow/contrib/cmake/tf_tools.cmake @@ -147,3 +147,8 @@ target_link_libraries(${benchmark_model} PUBLIC ${tf_core_gpu_kernels_lib} ${tensorflow_EXTERNAL_LIBRARIES} ) + +install(TARGETS ${transform_graph} ${summarize_graph} ${compare_graphs} ${benchmark_model} + RUNTIME DESTINATION bin + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib) diff --git a/tensorflow/contrib/cmake/tf_tutorials.cmake b/tensorflow/contrib/cmake/tf_tutorials.cmake index 858e7dda92e9e9f456d5fc56b563b2e3ec998520..e63fccc1810b348e543159681a73e7a9c1422c01 100644 --- a/tensorflow/contrib/cmake/tf_tutorials.cmake +++ b/tensorflow/contrib/cmake/tf_tutorials.cmake @@ -34,3 +34,8 @@ target_link_libraries(tf_tutorials_example_trainer PUBLIC ${tf_core_gpu_kernels_lib} ${tensorflow_EXTERNAL_LIBRARIES} ) + +install(TARGETS tf_tutorials_example_trainer + RUNTIME DESTINATION bin + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib) diff --git a/tensorflow/contrib/cmake/tools/create_def_file.py b/tensorflow/contrib/cmake/tools/create_def_file.py index 240a3fd4422ea8ba59dacc022aca7239f2ed9da1..f67698eb99a38eae307b52e55de748a67b798cbd 100644 --- a/tensorflow/contrib/cmake/tools/create_def_file.py +++ b/tensorflow/contrib/cmake/tools/create_def_file.py @@ -69,7 +69,7 @@ INCLUDE_RE = re.compile(r"^(TF_\w*)$|" # We want to identify data members explicitly in the DEF file, so that no one # can implicitly link against the DLL if they use one of the variables exported # from the DLL and the header they use does not decorate the symbol with -# __declspec(dllimport). It is easier to detect what a data symbol does +# __declspec(dllimport). It is easier to detect what a data symbol does # NOT look like, so doing it with the below regex. DATA_EXCLUDE_RE = re.compile(r"[)(]|" r"vftable|" @@ -77,7 +77,7 @@ DATA_EXCLUDE_RE = re.compile(r"[)(]|" r"vcall|" r"RTTI|" r"protobuf::internal::ExplicitlyConstructed") - + def get_args(): """Parse command line.""" filename_list = lambda x: x.split(";") diff --git a/tensorflow/contrib/compiler/jit_test.py b/tensorflow/contrib/compiler/jit_test.py index 94aff13a49f5380d5804e190b33613fd42dcaebc..2108e42bce4eba1eed158fe85888f1699a69ba7e 100644 --- a/tensorflow/contrib/compiler/jit_test.py +++ b/tensorflow/contrib/compiler/jit_test.py @@ -173,12 +173,12 @@ class CompilationEnabledInGradientTest(test.TestCase): def testCompilationInGradient(self): with self.test_session(): - x = constant_op.constant(3) - y_nc = math_ops.add(x, x, name="not_compiled") + x = constant_op.constant([[3]]) + y_nc = math_ops.matmul(x, x, name="not_compiled") with jit.experimental_jit_scope(): - y_c = math_ops.add(y_nc, y_nc, name="compiled") + y_c = math_ops.matmul(y_nc, y_nc, name="compiled") x_grads = gradients.gradients([y_c], [x])[0] - operations = x_grads.graph.get_operations() + operations = x.graph.get_operations() c_grad_ops = [ op for op in operations if "gradients/compiled" in op.name] nc_grad_ops = [ @@ -191,19 +191,19 @@ class CompilationEnabledInGradientTest(test.TestCase): with self.assertRaisesRegexp(ValueError, "No attr named"): ncg.get_attr("_XlaCompile") - # d/dx (4 * x) - self.assertAllClose(4, x_grads.eval()) + # d/dx (x ** 4) = 4 * (x ** 3) + self.assertAllClose([[108]], x_grads.eval()) def testCompilationGradientScopeNames(self): with self.test_session(graph=ops.Graph()): with jit.experimental_jit_scope(): # XlaScope 0 - a1 = constant_op.constant(1) - a1t = a1 + a1 + a1 = constant_op.constant([[1]]) + a1t = math_ops.matmul(a1, a1) with jit.experimental_jit_scope(): # XlaScope 1 - a2 = constant_op.constant(1) - a2t = a2 + a2 + a2 = constant_op.constant([[1]]) + a2t = math_ops.matmul(a2, a2) self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope")) self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope")) @@ -220,12 +220,12 @@ class CompilationEnabledInGradientTest(test.TestCase): with self.test_session(graph=ops.Graph()): with jit.experimental_jit_scope(True, separate_compiled_gradients=True): # XlaScope 0 - a1 = constant_op.constant(1) - a1t = a1 + a1 + a1 = constant_op.constant([[1]]) + a1t = math_ops.matmul(a1, a1) with jit.experimental_jit_scope(True, separate_compiled_gradients=True): # XlaScope 1 - a2 = constant_op.constant(1) - a2t = a2 + a2 + a2 = constant_op.constant([[1]]) + a2t = math_ops.matmul(a2, a2) self.assertEqual(b"jit_scope_0", a1.op.get_attr("_XlaScope")) self.assertEqual(b"jit_scope_1", a2.op.get_attr("_XlaScope")) diff --git a/tensorflow/contrib/cudnn_rnn/BUILD b/tensorflow/contrib/cudnn_rnn/BUILD index d4214587cd1a0fa684710d37083028f9af0425d9..d6d53d521b2024abf50cfbfec96a6e0dc538ed03 100644 --- a/tensorflow/contrib/cudnn_rnn/BUILD +++ b/tensorflow/contrib/cudnn_rnn/BUILD @@ -36,6 +36,7 @@ tf_kernel_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:stream_executor", "//tensorflow/core/kernels:bounds_check_lib", "//third_party/eigen3", ], @@ -54,7 +55,7 @@ tf_gen_op_wrapper_py( ) tf_custom_op_py_library( - name = "cudnn_rnn_py", + name = "cudnn_rnn_ops_py", srcs = [ "__init__.py", "python/ops/cudnn_rnn_ops.py", @@ -70,14 +71,57 @@ tf_custom_op_py_library( visibility = ["//visibility:public"], deps = [ ":cudnn_rnn_ops", + "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", + "//tensorflow/python:common_shapes", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers_base", + "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:random_seed", + "//tensorflow/python:rnn_cell", "//tensorflow/python:state_ops", "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + ], +) + +tf_custom_op_py_library( + name = "cudnn_rnn_py", + srcs = [ + "__init__.py", + "python/layers/__init__.py", + "python/layers/cudnn_rnn.py", + ], + dso = [ + ":python/ops/_cudnn_rnn_ops.so", + ], + kernels = [ + ":cudnn_rnn_kernels", + ":cudnn_rnn_ops_op_lib", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":cudnn_rnn_ops", + ":cudnn_rnn_ops_py", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers_base", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", ], ) @@ -85,6 +129,34 @@ cuda_py_test( name = "cudnn_rnn_ops_test", size = "large", srcs = ["python/kernel_tests/cudnn_rnn_ops_test.py"], + additional_deps = [ + ":cudnn_rnn_ops_py", + "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/python/ops/losses:losses", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + ], + shard_count = 6, + tags = [ + "manual", + "requires_cudnn5", + ], +) + +cuda_py_test( + name = "cudnn_rnn_test", + size = "enormous", + srcs = ["python/kernel_tests/cudnn_rnn_test.py"], additional_deps = [ ":cudnn_rnn_py", "//tensorflow/core:protos_all_py", @@ -114,7 +186,7 @@ cuda_py_test( size = "large", srcs = ["python/kernel_tests/cudnn_rnn_ops_benchmark.py"], additional_deps = [ - ":cudnn_rnn_py", + ":cudnn_rnn_ops_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", "//tensorflow/python:client", diff --git a/tensorflow/contrib/cudnn_rnn/__init__.py b/tensorflow/contrib/cudnn_rnn/__init__.py index 87ba834770d8f707c5364ed7bb8db4aaaa21f286..1f7efad71fb04cd754eae8ce170e696baa4d7fc3 100644 --- a/tensorflow/contrib/cudnn_rnn/__init__.py +++ b/tensorflow/contrib/cudnn_rnn/__init__.py @@ -29,14 +29,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import sys + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.cudnn_rnn.python.layers import * +# pylint: enable=unused-import,wildcard-import from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleGRUCell from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnCompatibleLSTMCell -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRU from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnGRUSaveable -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTM from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnLSTMSaveable -from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNRelu from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNReluSaveable from tensorflow.contrib.cudnn_rnn.python.ops.cudnn_rnn_ops import CudnnRNNTanhSaveable @@ -56,4 +58,4 @@ _allowed_symbols = [ "CudnnRNNTanhSaveable", ] -remove_undocumented(__name__) +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc index 55fce0a916c9b057234d11d475b56322ce1e29d2..5d5f593d016a3bb9f7b5ea8f5cd40c29268dc4f5 100644 --- a/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc +++ b/tensorflow/contrib/cudnn_rnn/kernels/cudnn_rnn_ops.cc @@ -577,6 +577,7 @@ class CudnnRNNParamsSizeOp : public CudnnRNNKernelCommon { .TypeConstraint("S"), \ CudnnRNNParamsSizeOp); +TF_CALL_half(REGISTER_GPU); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU @@ -711,6 +712,7 @@ class CudnnRNNParamsToCanonical : public CudnnRNNKernelCommon { .HostMemory("input_size") \ .TypeConstraint("T"), \ CudnnRNNParamsToCanonical); +TF_CALL_half(REGISTER_GPU); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU @@ -757,7 +759,9 @@ class CudnnRNNCanonicalToParams : public CudnnRNNKernelCommon { .HostMemory("input_size") \ .TypeConstraint("T"), \ CudnnRNNCanonicalToParams); -TF_CALL_float(REGISTER_GPU) TF_CALL_double(REGISTER_GPU); +TF_CALL_half(REGISTER_GPU); +TF_CALL_float(REGISTER_GPU); +TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU // Run the forward operation of the RNN model. @@ -906,6 +910,7 @@ class CudnnRNNForwardOp : public CudnnRNNKernelCommon { Name("CudnnRNN").Device(DEVICE_GPU).TypeConstraint("T"), \ CudnnRNNForwardOp); +TF_CALL_half(REGISTER_GPU); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU @@ -1125,6 +1130,7 @@ class CudnnRNNBackwardOp : public CudnnRNNKernelCommon { Name("CudnnRNNBackprop").Device(DEVICE_GPU).TypeConstraint("T"), \ CudnnRNNBackwardOp); +TF_CALL_half(REGISTER_GPU); TF_CALL_float(REGISTER_GPU); TF_CALL_double(REGISTER_GPU); #undef REGISTER_GPU diff --git a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc index 2b297282b264a3777e0a981a1ecccabb0a3a2c4e..9e41e67857101534e8bfef8d5d0b8a45ed8f1f76 100644 --- a/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc +++ b/tensorflow/contrib/cudnn_rnn/ops/cudnn_rnn_ops.cc @@ -75,7 +75,7 @@ REGISTER_OP("CudnnRNNParamsSize") .Input("num_layers: int32") .Input("num_units: int32") .Input("input_size: int32") - .Attr("T: {float32, float64}") + .Attr("T: {float16, float32, float64}") .Attr("S: {int32, int64}") .Attr(kRNNModeAttrs) .Attr(kRNNInputModeAttrs) @@ -130,7 +130,7 @@ REGISTER_OP("CudnnRNN") .Output("output_h: T") .Output("output_c: T") .Output("reserve_space: T") - .Attr("T: {float32, float64}") + .Attr("T: {float16, float32, float64}") .Attr(kRNNModeAttrs) .Attr(kRNNInputModeAttrs) .Attr(kRNNDirectionAttrs) @@ -190,7 +190,7 @@ REGISTER_OP("CudnnRNNBackprop") .Output("input_h_backprop: T") .Output("input_c_backprop: T") .Output("params_backprop: T") - .Attr("T: {float32, float64}") + .Attr("T: {float16, float32, float64}") .Attr(kRNNModeAttrs) .Attr(kRNNInputModeAttrs) .Attr(kRNNDirectionAttrs) @@ -236,7 +236,7 @@ REGISTER_OP("CudnnRNNParamsToCanonical") .Input("params: T") .Output("weights: num_params * T") .Output("biases: num_params * T") - .Attr("T: {float32, float64}") + .Attr("T: {float16, float32, float64}") .Attr("num_params: int") .Attr(kRNNModeAttrs) .Attr(kRNNInputModeAttrs) @@ -279,7 +279,7 @@ REGISTER_OP("CudnnRNNCanonicalToParams") .Input("weights: num_params * T") .Input("biases: num_params * T") .Output("params: T") - .Attr("T: {float32, float64}") + .Attr("T: {float16, float32, float64}") .Attr("num_params: int") .Attr(kRNNModeAttrs) .Attr(kRNNInputModeAttrs) diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e65394cba07574ed49398981f1cbd8bcb402e24f --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py @@ -0,0 +1,1223 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Cudnn RNN models.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import collections +import itertools +import os +import sys +import unittest + +import numpy as np + +from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn +from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops +from tensorflow.contrib.rnn.python.ops import rnn as contrib_rnn_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.framework.test_util import TensorFlowTestCase +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_nn_ops +from tensorflow.python.ops import gradients_impl as gradients +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn as rnn_lib +from tensorflow.python.ops import rnn_cell_impl +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import googletest +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import saver as saver_lib + + +CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM +CUDNN_GRU = cudnn_rnn_ops.CUDNN_GRU +CUDNN_RNN_RELU = cudnn_rnn_ops.CUDNN_RNN_RELU +CUDNN_RNN_TANH = cudnn_rnn_ops.CUDNN_RNN_TANH +CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION +CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION + +CUDNN_LSTM_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_LSTM_PARAMS_PER_LAYER +CUDNN_GRU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_GRU_PARAMS_PER_LAYER +CUDNN_RNN_TANH_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_TANH_PARAMS_PER_LAYER +CUDNN_RNN_RELU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_RELU_PARAMS_PER_LAYER + + +class CudnnTestModel(object): + """Model with convenient APIs for easier building and running test graph. + + The graph built is used by all tests below to avoid repeatedly building + similar test graphs. + """ + + def __init__(self, + rnn_mode, + num_layers, + num_units, + input_size, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0., + dtype=dtypes.float32, + training=False, + seed=None, + kernel_initializer=None, + bias_initializer=None): + if dtype not in (dtypes.float16, dtypes.float32, dtypes.float64): + raise ValueError("Invalid dtype: %s" % dtype) + self._dtype = dtype + + self._inputs = array_ops.placeholder( + dtype=dtype, shape=[None, None, input_size], name="inputs") + h = array_ops.placeholder( + dtype=dtype, shape=[None, None, num_units], name="h") + c = array_ops.placeholder( + dtype=dtype, shape=[None, None, num_units], name="c") + if rnn_mode == CUDNN_LSTM: + model_fn = cudnn_rnn.CudnnLSTM + self._initial_state = (h, c) + elif rnn_mode == CUDNN_GRU: + model_fn = cudnn_rnn.CudnnGRU + self._initial_state = (h,) + elif rnn_mode == CUDNN_RNN_TANH: + model_fn = cudnn_rnn.CudnnRNNTanh + self._initial_state = (h,) + elif rnn_mode == CUDNN_RNN_RELU: + model_fn = cudnn_rnn.CudnnRNNRelu + self._initial_state = (h,) + else: + raise ValueError("Invalid rnn_mode: %s" % rnn_mode) + self._rnn = model_fn( + num_layers, + num_units, + direction=direction, + dropout=dropout, + dtype=dtype, + seed=seed, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer) + self._rnn.build([None, None, input_size]) + + self._outputs, self._output_state = self._rnn( + self._inputs, initial_state=self._initial_state, training=training) + + def _AddUp(self, outputs, output_state): + total = math_ops.reduce_sum(outputs) + for s in output_state: + total += math_ops.reduce_sum(s) + return total + + @property + def inputs(self): + return self._inputs + + @property + def initial_state(self): + return self._initial_state + + @property + def outputs(self): + return self._outputs + + @property + def output_state(self): + return self._output_state + + @property + def rnn(self): + return self._rnn + + @property + def total_sum(self): + return self._AddUp(self.outputs, self.output_state) + + def SynthesizeInput(self, seq_length, batch_size, seed=1234): + """Synthesizes input and initial state values for testing.""" + np.random.seed(seed) + num_layers = self._rnn.num_layers + dir_count = self._rnn.num_dirs + num_units = self._rnn.num_units + input_size = self._rnn.input_size + + np_dtype = np.float32 if self._dtype == dtypes.float32 else np.float64 + inputs = np.random.randn(seq_length, batch_size, + input_size).astype(np_dtype) + input_h = np.random.randn(num_layers * dir_count, batch_size, + num_units).astype(np_dtype) + if self._rnn.rnn_mode == CUDNN_LSTM: + input_c = np.random.randn(num_layers * dir_count, batch_size, + num_units).astype(np_dtype) + initial_state = (input_h, input_c) + else: + initial_state = (input_h,) + return inputs, initial_state + + def ZeroState(self, batch_size): + num_layers = self._rnn.num_layers + dir_count = self._rnn.num_dirs + num_units = self._rnn.num_units + + np_dtype = np.float32 if self._dtype == dtypes.float32 else np.float64 + input_h = np.zeros((num_layers * dir_count, batch_size, + num_units)).astype(np_dtype) + if self._rnn.rnn_mode == CUDNN_LSTM: + input_c = np.zeros((num_layers * dir_count, batch_size, + num_units)).astype(np_dtype) + initial_state = (input_h, input_c) + else: + initial_state = (input_h,) + return initial_state + + def FProp(self, inputs_t, initial_state_t, training): + """Builds additional subgraph with given inputs and state. + + Args: + inputs_t: a tensor. + initial_state_t: a tensor. + training: boolean, true if training mode. + Returns: + A tensor of the forward pass output of the model. + """ + outputs, output_state = self._rnn( + inputs_t, initial_state=initial_state_t, training=training) + return self._AddUp(outputs, output_state) + + def Feed(self, sess, inputs, initial_state=None, return_sum=True): + """Runs graph with given inputs and initial state.""" + batch_size = inputs.shape[1] + if initial_state is None: + initial_state = self.ZeroState(batch_size) + if return_sum: + return sess.run( + self.total_sum, + feed_dict={self.inputs: inputs, + self.initial_state: initial_state}) + else: + return sess.run( + [self.outputs, self.output_state], + feed_dict={self.inputs: inputs, + self.initial_state: initial_state}) + + +def _CreateCudnnCompatibleCanonicalRNN(rnn, inputs, is_bidi=False, scope=None): + mode = rnn.rnn_mode + num_units = rnn.num_units + num_layers = rnn.num_layers + + # To reuse cuDNN-trained models, must use cudnn compatible rnn cells. + if mode == CUDNN_LSTM: + single_cell = lambda: cudnn_rnn_ops.CudnnCompatibleLSTMCell(num_units) + elif mode == CUDNN_GRU: + single_cell = lambda: cudnn_rnn_ops.CudnnCompatibleGRUCell(num_units) + elif mode == CUDNN_RNN_TANH: + single_cell = (lambda: rnn_cell_impl.BasicRNNCell(num_units, math_ops.tanh)) + elif mode == CUDNN_RNN_RELU: + single_cell = ( + lambda: rnn_cell_impl.BasicRNNCell(num_units, gen_nn_ops.relu)) + else: + raise ValueError("%s is not supported!" % mode) + + if not is_bidi: + cell = rnn_cell_impl.MultiRNNCell( + [single_cell() for _ in range(num_layers)]) + return rnn_lib.dynamic_rnn( + cell, inputs, dtype=dtypes.float32, time_major=True, scope=scope) + else: + cells_fw = [single_cell() for _ in range(num_layers)] + cells_bw = [single_cell() for _ in range(num_layers)] + + (outputs, output_state_fw, + output_state_bw) = contrib_rnn_lib.stack_bidirectional_dynamic_rnn( + cells_fw, + cells_bw, + inputs, + dtype=dtypes.float32, + time_major=True, + scope=scope) + return outputs, (output_state_fw, output_state_bw) + + +class CudnnRNNTestBasic(TensorFlowTestCase): + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testLayerBasic(self): + num_layers = 4 + num_units = 2 + batch_size = 8 + direction = CUDNN_RNN_UNIDIRECTION + dir_count = 1 + + with vs.variable_scope("main"): + kernel_initializer = init_ops.constant_initializer(0.) + bias_initializer = init_ops.constant_initializer(0.) + inputs = random_ops.random_uniform([ + num_layers * dir_count, batch_size, num_units], dtype=dtypes.float32) + + lstm = cudnn_rnn.CudnnLSTM(num_layers, num_units, + direction=direction, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + name="awesome_lstm") + + # Build the layer + outputs1, _ = lstm(inputs) + # Reuse the layer + outputs2, _ = lstm(inputs) + + total_sum1 = math_ops.reduce_sum(outputs1) + total_sum2 = math_ops.reduce_sum(outputs2) + + with vs.variable_scope("main", reuse=True): + lstm = cudnn_rnn.CudnnLSTM(num_layers, num_units, + direction=direction, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + name="awesome_lstm") + + # Reuse the layer + outputs3, _ = lstm(inputs) + total_sum3 = math_ops.reduce_sum(outputs3) + + self.assertEqual(1, len(variables.trainable_variables())) + self.assertEqual(1, len(ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS))) + self.assertEqual("main/awesome_lstm/opaque_kernel", + variables.trainable_variables()[0].op.name) + + with self.test_session(use_gpu=True) as sess: + sess.run(variables.global_variables_initializer()) + (total_sum1_v, total_sum2_v, total_sum3_v) = sess.run( + [total_sum1, total_sum2, total_sum3]) + self.assertEqual(0, total_sum1_v) + self.assertEqual(0, total_sum2_v) + self.assertEqual(0, total_sum3_v) + + +# TODO(jamesqin): Transform to parameterized test after it is included in the +# TF open source codebase. +class CudnnRNNTestSaveRestore(TensorFlowTestCase): + + def _CompareWeights(self, lhs, rhs): + self.assertEqual(len(lhs), len(rhs)) + for lw, rw in zip(lhs, rhs): + self.assertAllEqual(lw, rw) + + def _CompareBiases(self, lhs, rhs, rnn_mode, num_layers, direction): + self.assertEqual(len(lhs), len(rhs)) + if rnn_mode == CUDNN_LSTM: + num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER + elif rnn_mode == CUDNN_GRU: + num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER + elif rnn_mode == CUDNN_RNN_TANH: + num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER + else: + num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER + num_dirs = 1 if direction == CUDNN_RNN_UNIDIRECTION else 2 + num_params_per_layer *= num_dirs + self.assertEqual(num_params_per_layer * num_layers, len(lhs)) + + for i in range(num_layers): + layer_lhs = lhs[i * num_params_per_layer: (i+1) * num_params_per_layer] + layer_rhs = rhs[i * num_params_per_layer: (i+1) * num_params_per_layer] + if direction == CUDNN_RNN_UNIDIRECTION: + self._CompareSingleLayerBiases(layer_lhs, layer_rhs) + else: + size = len(layer_lhs) + fw_lhs, bw_lhs = layer_lhs[:size//2], layer_lhs[size//2:] + fw_rhs, bw_rhs = layer_rhs[:size//2], layer_rhs[size//2:] + self._CompareSingleLayerBiases(fw_lhs, fw_rhs) + self._CompareSingleLayerBiases(bw_lhs, bw_rhs) + + def _CompareSingleLayerBiases(self, lhs, rhs): + self.assertEqual(len(lhs), len(rhs)) + + lf_lhs, rt_lhs = lhs[:len(lhs)//2], lhs[len(lhs)//2:] + lf_rhs, rt_rhs = rhs[:len(rhs)//2], rhs[len(rhs)//2:] + self.assertEqual(len(lf_lhs), len(rt_lhs)) + self.assertEqual(len(lf_rhs), len(rt_rhs)) + + sum_lhs, sum_rhs = [], [] + for lf, rt in zip(lf_lhs, rt_lhs): + sum_lhs.append(lf + rt) + for lf, rt in zip(lf_rhs, rt_rhs): + sum_rhs.append(lf + rt) + self.assertEqual(len(sum_lhs), len(sum_rhs)) + for lf, rt in zip(sum_lhs, sum_rhs): + self.assertAllEqual(lf, rt) + + def _TestSaveRestoreVariable(self, rnn_mode, direction, dtype): + input_size = 3 + num_layers = 2 + num_units = 7 + with ops.Graph().as_default() as g: + random_seed.set_random_seed(1234) + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype) + rnn = model.rnn + save_path = os.path.join(self.get_temp_dir(), + "save-restore-variable-test") + saver = saver_lib.Saver() + weights, biases = model.rnn.saveable._OpaqueParamsToCanonical() + opaque_params = rnn.trainable_variables[0] + # CudnnTestModel() creates CudnnOpaqueParamsSaveable that helps saver save + # Cudnn vars in canonical format. + reset_op = state_ops.assign( + opaque_params, + array_ops.zeros(array_ops.shape(opaque_params), dtype=dtype)) + # Passing graph explicitly, otherwise an old sess would be reused. + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + val = saver.save(sess, save_path) + self.assertEqual(save_path, val) + weights_v, biases_v = sess.run([weights, biases]) + + # Reset opaque param + sess.run(reset_op) + saver.restore(sess, save_path) + weights_v_restored, biases_v_restored = sess.run([weights, biases]) + + self._CompareWeights(weights_v, weights_v_restored) + self._CompareBiases(biases_v, biases_v_restored, rnn_mode, num_layers, + direction) + + def _TestSaveRestoreTwoVariables(self, rnn_mode, direction, dtype): + input_size = 3 + num_layers = 2 + num_units = 7 + with ops.Graph().as_default() as g: + random_seed.set_random_seed(1234) + with vs.variable_scope("m1"): + model1 = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype) + with vs.variable_scope("m2"): + model2 = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype) + opaque_params = (model1.rnn.trainable_variables[0], + model2.rnn.trainable_variables[0]) + weights1, biases1 = model1.rnn.saveable._OpaqueParamsToCanonical() + weights2, biases2 = model2.rnn.saveable._OpaqueParamsToCanonical() + reset_params = [ + state_ops.assign(params, + array_ops.zeros_like(params, dtype=dtype)) + for params in opaque_params + ] + reset_op = control_flow_ops.group(*reset_params) + save_path = os.path.join(self.get_temp_dir(), + "save-restore-variable-test2") + saver = saver_lib.Saver() + # Passing graph explicitly, otherwise an old sess would be reused. + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + val = saver.save(sess, save_path) + self.assertEqual(save_path, val) + + weights1_v, biases1_v = sess.run([weights1, biases1]) + weights2_v, biases2_v = sess.run([weights2, biases2]) + + sess.run(reset_op) + saver.restore(sess, save_path) + weights1_v_restored, biases1_v_restored = sess.run([weights1, biases1]) + weights2_v_restored, biases2_v_restored = sess.run([weights2, biases2]) + + self._CompareWeights(weights1_v, weights1_v_restored) + self._CompareWeights(weights2_v, weights2_v_restored) + self._CompareBiases(biases1_v, biases1_v_restored, rnn_mode, num_layers, + direction) + self._CompareBiases(biases2_v, biases2_v_restored, rnn_mode, num_layers, + direction) + + def _TestSaveRestoreOutput(self, rnn_mode, direction, dtype): + with ops.Graph().as_default() as g: + num_layers = 2 + num_units = 7 + input_size = 7 + seq_length = 8 + batch_size = 4 + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype, + training=False) + rnn = model.rnn + + save_path = os.path.join(self.get_temp_dir(), "save-restore-output-test") + saver = saver_lib.Saver() + + # Only one opaque var in a cudnn layer. + assert len(rnn.trainable_variables) == 1 + reset_params = state_ops.assign( + rnn.trainable_variables[0], + array_ops.zeros( + array_ops.shape(rnn.trainable_variables[0]), dtype=dtype)) + + # Passing graph explicitly, otherwise an old sess would be reused. + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + inputs, initial_state = model.SynthesizeInput(seq_length, batch_size) + total_sum_v = model.Feed(sess, inputs, initial_state) + val = saver.save(sess, save_path) + self.assertEqual(save_path, val) + + sess.run(reset_params) + saver.restore(sess, save_path) + total_sum_v_restored = model.Feed(sess, inputs, initial_state) + self.assertAllClose(total_sum_v, total_sum_v_restored, atol=1e-5) + + def _TestSaveRestoreHelper(self, rnn_mode): + directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + dtype_list = [dtypes.float16, dtypes.float32, dtypes.float64] + for direction, dtype in itertools.product(directions, dtype_list): + self._TestSaveRestoreVariable(rnn_mode, direction, dtype) + self._TestSaveRestoreTwoVariables(rnn_mode, direction, dtype) + self._TestSaveRestoreOutput(rnn_mode, direction, dtype) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreRepeatedlyCreateCustomSaveable(self): + input_size = 3 + num_layers = 2 + num_units = 7 + with ops.Graph().as_default(): + random_seed.set_random_seed(1234) + model = CudnnTestModel( + CUDNN_LSTM, + num_layers, + num_units, + input_size, + direction=CUDNN_RNN_UNIDIRECTION, + dtype=dtypes.float32) + with self.assertRaisesRegexp(RuntimeError, + "Cudnn saveable already created"): + model.rnn._create_saveable() + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreLSTM(self): + self._TestSaveRestoreHelper(CUDNN_LSTM) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreGRU(self): + self._TestSaveRestoreHelper(CUDNN_GRU) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreRNNTanh(self): + self._TestSaveRestoreHelper(CUDNN_RNN_TANH) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSaveRestoreRNNRelu(self): + self._TestSaveRestoreHelper(CUDNN_RNN_RELU) + + +# TODO(jamesqin): Transform to parameterized test after it is included in the +# TF open source codebase. +class CudnnRNNTestCompatibleRNNCells(TensorFlowTestCase): + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCudnnCompatibleLSTM(self): + self._TestCudnnCompatibleRnnCellsHelper(CUDNN_LSTM) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCudnnCompatibleGRU(self): + self._TestCudnnCompatibleRnnCellsHelper(CUDNN_GRU) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCudnnCompatibleRNNTanh(self): + self._TestCudnnCompatibleRnnCellsHelper(CUDNN_RNN_TANH) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testCudnnCompatibleRNNRelu(self): + self._TestCudnnCompatibleRnnCellsHelper(CUDNN_RNN_RELU) + + def _TestCudnnCompatibleRnnCellsHelper(self, rnn_mode): + configs = [ + { + "num_layers": 1, + "seq_length": 3, + "num_units": 4, + "input_size": 5, + "batch_size": 6, + }, + { + "num_layers": 2, + "seq_length": 8, + "num_units": 4, + "input_size": 8, + "batch_size": 16, + }, + { + "num_layers": 2, + "seq_length": 3, + "num_units": 4, + "input_size": 5, + "batch_size": 6, + }, + { + "num_layers": 1, + "seq_length": 2, + "num_units": 2, + "input_size": 4, + "batch_size": 1, + }, + ] + directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + for cfg, direction in zip(configs, directions): + self._TestCudnnCompatibleRnnCells(cfg["num_layers"], cfg["seq_length"], + cfg["num_units"], cfg["input_size"], + cfg["batch_size"], rnn_mode, direction) + + def _TestCudnnCompatibleRnnCells(self, num_layers, seq_length, num_units, + input_size, batch_size, rnn_mode, direction): + dtype = dtypes.float32 + # Train graph + with ops.Graph().as_default() as g: + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype, + training=True) + target_output = array_ops.placeholder(dtype=dtype) + loss_op = losses.log_loss( + labels=target_output, predictions=model.total_sum) + optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=1e-2) + train_op = optimizer.minimize(loss_op) + + saver = saver_lib.Saver() + + # Train Cudnn model + seed = 0 + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + # Train 128 steps + num_steps = 128 + for _ in range(num_steps): + inputs, _ = model.SynthesizeInput(seq_length, batch_size, seed) + targets = np.random.rand() + sess.run( + train_op, + feed_dict={ + model.inputs: inputs, + model.initial_state: model.ZeroState(batch_size), + target_output: targets + }) + seed += 1 + + save_path = os.path.join(self.get_temp_dir(), + ("cudnn-rnn-%s-test" % rnn_mode)) + save_v = saver.save(sess, save_path) + self.assertEqual(save_path, save_v) + + # Cudnn inference graph + with ops.Graph().as_default() as g: + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dtype=dtype, + training=False) + rnn = model.rnn + saver = saver_lib.Saver() + + inference_input = np.random.rand(seq_length, batch_size, + input_size).astype(np.float32) + with self.test_session(use_gpu=True, graph=g) as sess: + sess.run(variables.global_variables_initializer()) + saver.restore(sess, save_path) + + # Cudnn inference + cudnn_outputs_v, cudnn_output_states_v = model.Feed( + sess, inference_input, return_sum=False) + + # Canonical RNN inference graph + with ops.Graph().as_default() as g: + cell_inputs = array_ops.placeholder( + dtype, shape=[seq_length, batch_size, input_size]) + if direction == CUDNN_RNN_UNIDIRECTION: + # outputs is one tensor, states are num_layer tuples, each 2 tensors + (outputs, states) = _CreateCudnnCompatibleCanonicalRNN(rnn, cell_inputs) + if rnn_mode == CUDNN_LSTM: + output_h = array_ops.stack([s.h for s in states]) + output_c = array_ops.stack([s.c for s in states]) + else: + output_state = array_ops.stack([s for s in states]) + else: + # outputs is one tensor. + # states is a tuple of 2 tuples: + # each sub tuple is num_layer tuples, each with 2 tensors. + (outputs, states) = _CreateCudnnCompatibleCanonicalRNN( + rnn, cell_inputs, is_bidi=True) + output_state_fw, output_state_bw = states + if rnn_mode == CUDNN_LSTM: + output_h, output_c = [], [] + for s_fw, s_bw in zip(output_state_fw, output_state_bw): + output_h.append(array_ops.stack([s_fw.h, s_bw.h])) + output_c.append(array_ops.stack([s_fw.c, s_bw.c])) + output_h = array_ops.concat(output_h, axis=0) + output_c = array_ops.concat(output_c, axis=0) + else: + output_state = [] + for s_fw, s_bw in zip(output_state_fw, output_state_bw): + output_state.append(array_ops.stack([s_fw, s_bw])) + output_state = array_ops.concat(output_state, axis=0) + saver = saver_lib.Saver() + + with self.test_session(use_gpu=True, graph=g) as sess: + saver.restore(sess, save_path) + + # BlockCell inference + if rnn_mode == CUDNN_LSTM: + outputs_v, output_h_v, output_c_v = sess.run( + [outputs, output_h, output_c], + feed_dict={cell_inputs: inference_input}) + self.assertAllClose(cudnn_outputs_v, outputs_v) + cudnn_output_h_v, cudnn_output_c_v = cudnn_output_states_v + self.assertAllClose(cudnn_output_h_v, output_h_v) + self.assertAllClose(cudnn_output_c_v, output_c_v) + else: + outputs_v, output_state_v = sess.run( + [outputs, output_state], + feed_dict={cell_inputs: inference_input}) + self.assertAllClose(cudnn_outputs_v, outputs_v, atol=2e-5, rtol=2e-5) + (cudnn_output_h_v,) = cudnn_output_states_v + self.assertAllClose(cudnn_output_h_v, output_state_v, atol=2e-5, + rtol=2e-5) + + +class CudnnRNNTestParamsSize(TensorFlowTestCase): + + def _TestOpaqueParamsSize(self, rnn_mode, num_layers, num_units, input_size, + dtype, direction): + logging.info("Testing one lstm param size with config: %s", locals()) + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + dtype=dtype, + direction=direction) + rnn = model.rnn + + # Min param size estimate = sum(weights.size) + sum(biases.size) + min_params_size = ( + np.sum(map(np.prod, rnn.canonical_weight_shapes)) + + np.sum([sp[0] for sp in rnn.canonical_bias_shapes])) + + opaque_params = rnn.trainable_variables[0] + with self.test_session(use_gpu=True, graph=ops.get_default_graph()): + variables.global_variables_initializer().run() + opaque_params_size_v = opaque_params.eval().size + self.assertLessEqual(min_params_size, opaque_params_size_v) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testOpaqueParamsSize(self): + test_configs = [ + [4, 200, 200], + [4, 200, 300], + [4, 200, 100], + [1, 100, 200], + [2, 200, 100], + [3, 200, 400], + ] + directions = [CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION] + dtype_list = [dtypes.float16, dtypes.float32, dtypes.float64] + rnns = [CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_RELU, CUDNN_RNN_TANH] + for (rnn, config, dtype, direction) in itertools.product( + rnns, test_configs, dtype_list, directions): + num_layers, num_units, input_size = config + with ops.Graph().as_default(): + self._TestOpaqueParamsSize(rnn, num_layers, num_units, input_size, + dtype, direction) + + +class CudnnRNNTestTraining(TensorFlowTestCase): + + def _ComputeNumericGrad(self, sess, y, x, delta=1e-4, step=1): + """Compute the numeric gradient of y wrt to x. + + Args: + sess: The TF session constructed with a graph containing x and y. + y: A scalar TF Tensor in the graph constructed in sess. + x: A TF Tensor in the graph constructed in sess. + delta: Gradient checker's small perturbation of x[i]. + step: Only compute numerical gradients for a subset of x values. + I.e. dy/dx[i] is computed if i % step == 0. + Returns: + A Tensor of the same shape and dtype as x. If x[i] is not chosen + to compute the numerical gradient dy/x[i], the corresponding + value is set to 0. + """ + + x_data = sess.run(x) + x_size = x_data.size + x_shape = x_data.shape + + numeric_grad = np.zeros(x_size, dtype=x_data.dtype) + + for i in range(0, x_size, step): + x_pos = x_data.copy() + if x_size == 1: + x_pos += delta + else: + x_pos.flat[i] += delta + y_pos_feed_dict = dict([(x.name, x_pos)]) + y_pos = sess.run(y, feed_dict=y_pos_feed_dict) + + x_neg = x_data.copy() + if x_size == 1: + x_neg -= delta + else: + x_neg.flat[i] -= delta + y_neg_feed_dict = dict([(x.name, x_neg)]) + y_neg = sess.run(y, feed_dict=y_neg_feed_dict) + numeric_grad[i] = (y_pos - y_neg) / (2 * delta) + return numeric_grad.reshape(x_shape) + + def _GetShape(self, sess, inputs): + if not isinstance(inputs, collections.Iterable): + return sess.run(array_ops.shape(inputs)) + else: + return sess.run([array_ops.shape(x) for x in inputs]) + + def _GradientCheckFp16(self, sess, y, xs, num_samples, + tolerance=1e-6, delta=1e-4): + """Gradient check for Fp16. + + Fp16 numerical gradients end up being zeros. Use a new way to check + gradients: + + Given multi-variant function: + y = f(x1, x2, ... xn) + delta_y = f(x1 + delta_x1, x2+delta_x2, ..., xn+delta_xn) - + f(x1, x2, ..., xn) + = f'(x1) * delta_x1 + f'(x2) * delta_x2 + .. + f'(xn) * delta_xn + where: + delta_xi are very small disturbance. + f'(xi) is the gradient of y w.r.t xi. + + The gradient check verifies the expected delta_y calculated by the above + equation is close to the actual delta_y. + Args: + sess: tf.Session object. + y: output tensor. + xs: a tensor or a list of input tensors. + num_samples: number of test samples to run. + tolerance: error tolerance. + delta: the order of magnititued of input disturbance to apply to calculate + the output change w.r.t inputs. + """ + sym_grads = self._ComputeSymGrads(sess, y, xs) + xs_shapes = self._GetShape(sess, xs) + + x_vals = [sess.run(x) for x in xs] + for _ in range(num_samples): + delta_xs = [delta * np.random.rand(*shape.tolist()) + for shape in xs_shapes] + + feed_dict = {} + for x, x_val, delta_x in zip(xs, x_vals, delta_xs): + feed_dict[x] = x_val + delta_x + actual_delta_y = (float(sess.run(y, feed_dict=feed_dict)) - + float(sess.run(y))) + + expected_delta_y = 0. + for sym_grad, delta_x in zip(sym_grads, delta_xs): + expected_delta_y += np.dot( + sym_grad.astype(np.float32).flatten(), + delta_x.astype(np.float32).flatten()) + self.assertAllClose(expected_delta_y, actual_delta_y, + atol=tolerance, rtol=tolerance) + + def _GradientCheck(self, sess, y, xs, tolerance=1e-6, delta=1e-4): + sym_grads = self._ComputeSymGrads(sess, y, xs) + + num_grads = [self._ComputeNumericGrad(sess, y, x, delta) for x in xs] + self.assertEqual(len(sym_grads), len(num_grads)) + for sym, num in zip(sym_grads, num_grads): + self.assertFalse(np.any(np.isnan(sym))) + self.assertFalse(np.any(np.isnan(num))) + self.assertAllClose(sym, num, atol=tolerance, rtol=tolerance) + + def _ComputeSymGrads(self, sess, y, xs): + sym_grads_t = gradients.gradients(y, xs) + return sess.run(sym_grads_t) + + def _TestOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size, + batch_size, seq_length, dir_count, dropout, dtype, + delta, tolerance): + # Gradient checking runs two forward ops with almost the same input. Need to + # make sure the drop patterns across the two runs are the same. + logging.info("Training test with config: %s", locals()) + old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False)) + os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True) + + np.random.seed(1234) + random_seed.set_random_seed(5678) + has_input_c = (rnn_mode == CUDNN_LSTM) + direction = (CUDNN_RNN_UNIDIRECTION + if dir_count == 1 else CUDNN_RNN_BIDIRECTION) + model = CudnnTestModel( + rnn_mode, + num_layers, + num_units, + input_size, + direction=direction, + dropout=dropout, + dtype=dtype, + training=True, + bias_initializer=init_ops.random_normal_initializer( + mean=1., dtype=dtype)) + rnn = model.rnn + params = rnn.trainable_variables[0] + + inputs = variables.Variable( + random_ops.random_uniform( + [seq_length, batch_size, input_size], dtype=dtype), + dtype=dtype) + input_h = variables.Variable( + random_ops.random_uniform( + [num_layers * dir_count, batch_size, num_units], dtype=dtype), + dtype=dtype) + if has_input_c: + input_c = variables.Variable( + random_ops.random_uniform( + [num_layers * dir_count, batch_size, num_units], dtype=dtype), + dtype=dtype) + initial_state = (input_h, input_c) + else: + initial_state = (input_h,) + total_sum = model.FProp(inputs, initial_state, training=True) + + with self.test_session(use_gpu=True, graph=ops.get_default_graph()) as sess: + sess.run(variables.global_variables_initializer()) + all_inputs = [inputs, params] + for s in initial_state: + all_inputs.append(s) + if dtype == dtypes.float16: + self._GradientCheckFp16( + sess, total_sum, all_inputs, + num_samples=FLAGS.grad_check_num_samples, + tolerance=tolerance, delta=delta) + else: + for _ in range(FLAGS.grad_check_num_samples): + # Each time choose a different set of inputs. + sess.run(variables.global_variables_initializer()) + self._GradientCheck( + sess, total_sum, all_inputs, + tolerance=tolerance, delta=delta) + os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state + + def _TestSimpleTrainingHelper(self, rnn_mode, test_configs): + dropouts = [0, 0.5, 1.] + for config, dropout in itertools.product(test_configs, dropouts): + dtype = config.get("dtype", dtypes.float32) + delta = config.get("delta", 1e-4) + tolerance = config.get("tolerance", 1e-6) + dir_count = config.get("dir_count", 1) + shape = config["shape"] + with ops.Graph().as_default(): + self._TestOneSimpleTraining(rnn_mode, shape["num_layers"], + shape["num_units"], shape["input_size"], + shape["batch_size"], shape["seq_length"], + dir_count, dropout, dtype, delta, + tolerance) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingLSTMFp64(self): + test_configs = [ + { + "dtype": dtypes.float64, + "tolerance": 5e-6, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_LSTM, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingLSTMFp32(self): + test_configs = [ + { + "dtype": dtypes.float32, + "delta": 1e-4, + "tolerance": 9e-2, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_LSTM, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingLSTMFp16(self): + test_configs = [ + { + "dtype": dtypes.float16, + "delta": 1e-3, + "tolerance": 9e-2, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + { + "dtype": dtypes.float16, + "delta": 1e-2, + "tolerance": 9e-2, + "shape": { + "num_layers": 2, + "num_units": 6, + "input_size": 8, + "batch_size": 6, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_LSTM, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingGRUFp64(self): + test_configs = [ + { + "dtype": dtypes.float64, + "tolerance": 5e-6, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + } + }, + ] + self._TestSimpleTrainingHelper(CUDNN_GRU, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingGRUFp32(self): + test_configs = [ + { + "dtype": dtypes.float32, + "delta": 1e-3, + "tolerance": 4e-3, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_GRU, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingGRUFp16(self): + test_configs = [ + { + "dtype": dtypes.float16, + "delta": 2e-3, + "tolerance": 6e-2, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_GRU, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNTanhFp64(self): + test_configs = [ + { + "dtype": dtypes.float64, + "tolerance": 5e-6, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_TANH, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNTanhFp32(self): + test_configs = [ + { + "dtype": dtypes.float32, + "delta": 1e-3, + "tolerance": 5e-3, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_TANH, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNTanhFp16(self): + test_configs = [ + { + "dtype": dtypes.float16, + "delta": 1e-3, + "tolerance": 5e-2, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_TANH, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNReluFp64(self): + test_configs = [ + { + "dtype": dtypes.float64, + "tolerance": 5e-6, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_RELU, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNReluFp32(self): + test_configs = [ + { + "dtype": dtypes.float32, + "delta": 1e-4, + "tolerance": 3e-1, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_RELU, test_configs) + + @unittest.skipUnless(test.is_built_with_cuda(), + "Test only applicable when running on GPUs") + def testSimpleTrainingRNNReluFp16(self): + test_configs = [ + { + "dtype": dtypes.float16, + "delta": 1e-3, + "tolerance": 7e-2, + "shape": { + "num_layers": 2, + "num_units": 3, + "input_size": 4, + "batch_size": 3, + "seq_length": 4, + }, + }, + ] + self._TestSimpleTrainingHelper(CUDNN_RNN_RELU, test_configs) + + +if __name__ == "__main__": + argv0 = sys.argv[0] + parser = argparse.ArgumentParser() + parser.add_argument( + "--grad_check_num_samples", + type=int, + default=5, + help="Number of samples to run for gradient check.") + FLAGS, unparsed = parser.parse_known_args() + sys.argv = [argv0] + unparsed + googletest.main() diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5feee3d10d14020d63eec0541e5caa37e79f9f57 --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/python/layers/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""layers module with higher level CudnnRNN primitives.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import sys + +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.cudnn_rnn.python.layers.cudnn_rnn import * +# pylint: enable=unused-import,wildcard-import diff --git a/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py new file mode 100644 index 0000000000000000000000000000000000000000..37c61a71a3bdac4fadef58ba8c24b853fb3638ef --- /dev/null +++ b/tensorflow/contrib/cudnn_rnn/python/layers/cudnn_rnn.py @@ -0,0 +1,557 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Cudnn RNN operators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base as base_layer +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.platform import tf_logging as logging + + +CUDNN_RNN_UNIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_UNIDIRECTION +CUDNN_RNN_BIDIRECTION = cudnn_rnn_ops.CUDNN_RNN_BIDIRECTION +CUDNN_LSTM = cudnn_rnn_ops.CUDNN_LSTM +CUDNN_GRU = cudnn_rnn_ops.CUDNN_GRU +CUDNN_RNN_RELU = cudnn_rnn_ops.CUDNN_RNN_RELU +CUDNN_RNN_TANH = cudnn_rnn_ops.CUDNN_RNN_TANH + +# Half for cell input, half for hidden states. +CUDNN_LSTM_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_LSTM_PARAMS_PER_LAYER +CUDNN_GRU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_GRU_PARAMS_PER_LAYER +CUDNN_RNN_TANH_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_TANH_PARAMS_PER_LAYER +CUDNN_RNN_RELU_PARAMS_PER_LAYER = cudnn_rnn_ops.CUDNN_RNN_RELU_PARAMS_PER_LAYER + +CUDNN_INPUT_LINEAR_MODE = cudnn_rnn_ops.CUDNN_INPUT_LINEAR_MODE +CUDNN_INPUT_SKIP_MODE = cudnn_rnn_ops.CUDNN_INPUT_SKIP_MODE +CUDNN_INPUT_AUTO_MODE = cudnn_rnn_ops.CUDNN_INPUT_AUTO_MODE + + +__all__ = ["CudnnLSTM", "CudnnGRU", "CudnnRNNTanh", "CudnnRNNRelu"] + + +class _CudnnRNN(base_layer.Layer): + # pylint:disable=line-too-long + """Abstract class for RNN layers with Cudnn implementation. + + Cudnn RNNs have two major differences from other platform-independent RNNs tf + provides: + * Cudnn LSTM and GRU are mathematically different from their tf counterparts. + (e.g. @{tf.contrib.rnn.LSTMBlockCell} and @{tf.nn.rnn_cell.GRUCell}. + * Cudnn-trained checkpoints are not directly compatible with tf RNNs: + * They use a single opaque parameter buffer for the entire (possibly) + multi-layer multi-directional RNN; Whereas tf RNN weights are per-cell and + layer. + * The size and layout of the parameter buffers may change between + CUDA/CuDNN/GPU generations. Because of that, the opaque parameter variable + does not have a static shape and is not partitionable. Instead of using + partitioning to alleviate the PS's traffic load, try building a + multi-tower model and do gradient aggregation locally within the host + before updating the PS. See https://www.tensorflow.org/performance/performance_models#parameter_server_variables + for a detailed performance guide. + + Consequently, if one plans to use Cudnn trained models on both GPU and CPU + for inference and training, one needs to: + * Create a CudnnOpaqueParamsSaveable subclass object to save RNN params in + canonical format. (This is done for you automatically during layer building + process.) + * When not using a Cudnn RNN class, use CudnnCompatibleRNN classes to load the + checkpoints. These classes are platform-independent and perform the same + computation as Cudnn for training and inference. + Similarly, CudnnCompatibleRNN-trained checkpoints can be loaded by CudnnRNN + classes seamlessly. + + Below is a typical workflow(using LSTM as an example): + for detailed performance guide. + + # Use Cudnn-trained checkpoints with CudnnCompatibleRNNs + ```python + with tf.Graph().as_default(): + lstm = CudnnLSTM(num_layers, num_units, direction, ...) + + outputs, output_states = lstm(inputs, initial_states, training=True) + + # If user plans to delay calling the cell with inputs, one can do + # lstm.build(input_shape) + + saver = Saver() + + # training subgraph + ... + + # Once in a while save the model. + saver.save(save_path) + + # Inference subgraph for unidirectional RNN on, e.g., CPU or mobile. + with tf.Graph().as_default(): + single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTM(num_units) + + # NOTE: Even if there's only one layer, the cell needs to be wrapped in + # MultiRNNCell. + cell = tf.nn.rnn_cell.MultiRNNCell( + [single_cell() for _ in range(num_layers)]) + + # Leave the scope arg unset. + outputs, final_state = tf.nn.dynamic_rnn(cell, inputs, initial_state, ...) + + saver = Saver() + + # Create session + sess = ... + + # Restores + saver.restore(sess, save_path) + + # Inference subgraph for bidirectional RNN + with tf.Graph().as_default(): + single_cell = lambda: tf.contrib.cudnn_rnn.CudnnCompatibleLSTM(num_units) + cells_fw = [single_cell() for _ in range(num_layers)] + cells_bw = [single_cell() for _ in range(num_layers)] + + # Leave the scope arg unset. + (outputs, output_state_fw, + output_state_bw) = tf.contrib.rnn.stack_bidirectional_dynamic_rnn( + cells_fw, cells_bw, inputs, ...) + saver = Saver() + + # Create session + sess = ... + + # Restores + saver.restore(sess, save_path) + ``` + """ + # pylint:enable=line-too-long + + # The following are constants defined by subclasses. + # Type of RNN cell. + _rnn_mode = None + # Number of cell weights(or biases) per layer. + _num_params_per_layer = None + # Custom SaveableObject class for the CudnnRNN class. + _saveable_cls = None + + def __init__(self, + num_layers, + num_units, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0., + seed=None, + dtype=dtypes.float32, + kernel_initializer=None, + bias_initializer=None, + name=None): + """Creates a CudnnRNN model from model spec. + + Args: + num_layers: the number of layers for the RNN model. + num_units: the number of units within the RNN model. + input_mode: indicate whether there is a linear projection between the + input and the actual computation before the first layer. It can be + 'linear_input', 'skip_input' or 'auto_select'. + 'linear_input' (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior). + 'skip_input' is only allowed when input_size == num_units; + 'auto_select' implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + direction: the direction model that the model operates. Can be either + 'unidirectional' or 'bidirectional' + dropout: dropout rate, a number between [0, 1]. Dropout is applied on + inputs of each layer. When set to 0, dropout is disabled. + seed: the op seed used for initializing dropout. See @{tf.set_random_seed} + for behavior. + dtype: tf.float16, tf.float32 or tf.float64 + kernel_initializer: starting value to initialize the weight. + bias_initializer: starting value to initialize the bias + (default is all zeros). + name: VariableScope for the created subgraph; defaults to class name. + This only serves the default scope if later no scope is specified when + invoking __call__(). + + Raises: + ValueError: if direction is invalid. Or dtype is not supported. + """ + super(_CudnnRNN, self).__init__(dtype=dtype, name=name) + cudnn_rnn_ops.check_direction(direction) + cudnn_rnn_ops.check_input_mode(input_mode) + + if dtype not in [dtypes.float16, dtypes.float32, dtypes.float64]: + raise ValueError( + "Only support float16, float32, float64, provided %s" % dtype) + # Layer self.dtype is type name, the original DType object is kept here. + self._plain_dtype = dtype + self._num_layers = num_layers + self._num_units = num_units + self._input_mode = input_mode + self._direction = direction + self._dropout = dropout + self._seed = seed + self._kernel_initializer = kernel_initializer + self._bias_initializer = bias_initializer + # Init input_size to None, which will be set after build(). + self._input_size = None + self._saveable = None + + @property + def num_layers(self): + return self._num_layers + + @property + def num_units(self): + return self._num_units + + @property + def input_mode(self): + """Input mode of first layer. + + Indicates whether there is a linear projection between the input and the + actual computation before the first layer. It can be + * 'linear_input': (default) always applies a linear projection of input + onto RNN hidden state. (standard RNN behavior) + * 'skip_input': 'skip_input' is only allowed when input_size == num_units. + * 'auto_select'. implies 'skip_input' when input_size == num_units; + otherwise, it implies 'linear_input'. + + Returns: + 'linear_input', 'skip_input' or 'auto_select'. + """ + return self._input_mode + + @property + def input_size(self): + if not self._input_size: + raise ValueError( + "\'input_size\' is unknown since layer has not been built.") + return self._input_size + + @property + def rnn_mode(self): + """Type of RNN cell used. + + Returns: + `lstm`, `gru`, `rnn_relu` or `rnn_tanh`. + """ + return self._rnn_mode + + @property + def direction(self): + """Returns `unidirectional` or `bidirectional`.""" + return self._direction + + @property + def num_dirs(self): + return 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 + + @property + def saveable(self): + return self._saveable + + @property + def canonical_weight_shapes(self): + """Shapes of Cudnn canonical weight tensors.""" + if not self._input_size: + raise RuntimeError( + "%s.canonical_weight_shapes invoked before input shape is known" % + type(self).__name__) + + shapes = [] + for i in range(self._num_layers): + shapes.extend(self._canonical_weight_shape(i)) + return shapes + + @property + def canonical_bias_shapes(self): + """Shapes of Cudnn canonical bias tensors.""" + return self._canonical_bias_shape(0) * self._num_layers + + def _update_trainable_weights(self, getter, *args, **kwargs): + """Custom getter for layer variables.""" + # Add variables to layer's `(non_)trainable_weights` list(s). + variable = getter(*args, **kwargs) + trainable = kwargs.get("trainable", True) + if trainable and variable not in self._trainable_weights: + self._trainable_weights.append(variable) + elif not trainable and variable not in self._non_trainable_weights: + self._non_trainable_weights.append(variable) + return variable + + def build(self, input_shape): + """Create variables of the Cudnn RNN. + + It can be called manually before `__call__()` or automatically through + `__call__()`. In the former case, subsequent `__call__()`s will skip + creating variables. + Args: + input_shape: network input tensor shape, a python list or a TensorShape + object with 3 dimensions. + Raises: + ValueError: if input_shape has wrong dimension or unknown 3rd dimension. + """ + if self.built: + return + + input_shape = tensor_shape.TensorShape(input_shape) + if input_shape.ndims != 3: + raise ValueError("Expecting input_shape with 3 dims, got %d" % + input_shape.ndims) + if input_shape[-1].value is None: + raise ValueError("The last dimension of the inputs to `CudnnRNN` " + "should be defined. Found `None`.") + self._input_size = input_shape[-1].value + self.input_spec = base_layer.InputSpec(ndim=3, axes={-1: self._input_size}) + + self._set_scope(None) + + # Not using base class `add_variable()` since the it calls + # `tf.get_variable()` with a callable initializer whereas here with a + # tensor. The difference is mandated to support forward-compatibility with + # Cudnn. + with vs.variable_scope( + self._scope, + reuse=self.built, + custom_getter=self._update_trainable_weights): + if self._kernel_initializer is None: + self._kernel_initializer = init_ops.glorot_uniform_initializer( + seed=self._seed, dtype=self._plain_dtype) + if self._bias_initializer is None: + self._bias_initializer = init_ops.constant_initializer( + 0.0, dtype=self._plain_dtype) + + weights = [ + self._kernel_initializer(sp, dtype=self._plain_dtype) + for sp in self.canonical_weight_shapes + ] + biases = [ + self._bias_initializer(sp, dtype=self._plain_dtype) + for sp in self.canonical_bias_shapes + ] + opaque_params_t = self._canonical_to_opaque(weights, biases) + + if vs.get_variable_scope().partitioner is not None: + logging.warn( + "Partitioner is not supported for Cudnn RNN layer variables, using " + "it will create forward-compatibility issues with future " + "CUDA/CuDNN generations.") + # Initialize opaque params with a tensor. + self.kernel = vs.get_variable( + "opaque_kernel", initializer=opaque_params_t, validate_shape=False) + # Create saveable in the outer scope of the cudnn subgraph, such that + # alternative subgraph with platform-independent rnn cells can load the + # checkpoints directly. + if not (self.built or vs.get_variable_scope().reuse): + self._create_saveable() + self.built = True + + def call(self, inputs, initial_state=None, training=True): + """Runs the forward step for the RNN model. + + Args: + inputs: `3-D` tensor with shape `[time_len, batch_size, input_size]`. + initial_state: a tuple of tensor(s) of shape + `[num_layers * num_dirs, batch_size, num_units]`. If not provided, use + zero initial states. The tuple size is 2 for LSTM and 1 for other RNNs. + training: whether this operation will be used in training or inference. + Returns: + output: a tensor of shape `[time_len, batch_size, num_dirs * num_units]`. + It is a `concat([fwd_output, bak_output], axis=2)`. + output_states: a tuple of tensor(s) of the same shape and structure as + `initial_state`. + Raises: + ValueError: initial_state is not a tuple. + """ + if initial_state is not None and not isinstance(initial_state, tuple): + raise ValueError("Invalid initial_state type: %s, expecting tuple.", + type(initial_state)) + dtype = self.dtype + inputs = ops.convert_to_tensor(inputs, dtype=dtype) + + batch_size = array_ops.shape(inputs)[1] + if initial_state is None: + initial_state = self._zero_state(batch_size) + if self._rnn_mode == CUDNN_LSTM: + h, c = initial_state # pylint:disable=unbalanced-tuple-unpacking,unpacking-non-sequence + else: + h, = initial_state # pylint:disable=unbalanced-tuple-unpacking,unpacking-non-sequence + h = ops.convert_to_tensor(h, dtype=dtype) + if self._rnn_mode == CUDNN_LSTM: + c = ops.convert_to_tensor(c, dtype=dtype) + else: + # For model that doesn't take input_c, replace with a dummy tensor. + c = array_ops.constant([], dtype=dtype) + outputs, (output_h, output_c) = self._forward(inputs, h, c, self.kernel, + training) + if self._rnn_mode == CUDNN_LSTM: + return outputs, (output_h, output_c) + else: + return outputs, (output_h,) + + def state_shape(self, batch_size): + raise NotImplementedError + + def _zero_state(self, batch_size): + res = [] + for sp in self.state_shape(batch_size): + res.append(array_ops.zeros(sp, dtype=self.dtype)) + return tuple(res) + + def _canonical_weight_shape(self, layer): + """Shapes of Cudnn canonical weight tensors for given layer.""" + if layer < 0 or layer >= self._num_layers: + raise ValueError("\'layer\' is not valid, got %s, expecting [%d, %d]" % + (layer, 0, self._num_layers-1)) + if not self._input_size: + raise RuntimeError( + "%s._canonical_weight_shape invoked before input shape is known" % + type(self).__name__) + + input_size = self._input_size + num_units = self._num_units + num_gates = self._num_params_per_layer // 2 + is_bidi = self._direction == CUDNN_RNN_BIDIRECTION + + if layer == 0: + wts_applied_on_inputs = [(num_units, input_size)] * num_gates + else: + if is_bidi: + wts_applied_on_inputs = [(num_units, 2 * num_units)] * num_gates + else: + wts_applied_on_inputs = [(num_units, num_units)] * num_gates + wts_applied_on_hidden_states = [(num_units, num_units)] * num_gates + tf_wts = wts_applied_on_inputs + wts_applied_on_hidden_states + return tf_wts if not is_bidi else tf_wts * 2 + + def _canonical_bias_shape(self, unused_layer): + """Shapes of Cudnn canonical bias tensors for given layer.""" + num_dirs = 1 if self._direction == CUDNN_RNN_UNIDIRECTION else 2 + return [[self._num_units]] * num_dirs * self._num_params_per_layer + + def _canonical_to_opaque(self, cu_weights, cu_biases): + if not self._input_size: + raise RuntimeError( + "%s._canonical_to_opaque invoked before input shape is known" % + type(self).__name__) + return cudnn_rnn_ops.cudnn_rnn_canonical_to_opaque_params( + rnn_mode=self._rnn_mode, + num_layers=self._num_layers, + num_units=self._num_units, + input_size=self._input_size, + weights=cu_weights, + biases=cu_biases, + input_mode=self._input_mode, + seed=self._seed, + dropout=self._dropout, + direction=self._direction) + + def _forward(self, inputs, h, c, opaque_params, training): + output, output_h, output_c = cudnn_rnn_ops._cudnn_rnn( # pylint:disable=protected-access + inputs, + h, + c, + opaque_params, + training, + self._rnn_mode, + input_mode=self._input_mode, + direction=self._direction, + dropout=self._dropout, + seed=self._seed) + return output, (output_h, output_c) + + def _create_saveable(self): + """Create custom saveable for the Cudnn layer. + + Called during layer building process to make sharing checkpoints between + Cudnn and Cudnn-compatible RNNs easy. + Returns: + a `CudnnOpaqueParamsSaveable` object. + Raises: + RuntimeError: if any custom saveable is already created for this layer. + """ + if self._saveable is not None: + raise RuntimeError("Cudnn saveable already created.") + self._saveable = self._saveable_cls( # pylint:disable=not-callable + self.trainable_variables[0], + self.num_layers, + self.num_units, + self.input_size, + self.input_mode, + self.direction, + scope=vs.get_variable_scope(), + name="%s_saveable" % self.trainable_variables[0].op.name) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, self._saveable) + + +class CudnnLSTM(_CudnnRNN): + """Cudnn implementation of LSTM layer.""" + _rnn_mode = CUDNN_LSTM + _num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER + _saveable_cls = cudnn_rnn_ops.CudnnLSTMSaveable + + def state_shape(self, batch_size): + """Shape of Cudnn LSTM states. + + Shape is a 2-element tuple. Each is + [num_layers * num_dirs, batch_size, num_units] + Args: + batch_size: an int + Returns: + a tuple of python arrays. + """ + return ([self.num_layers * self.num_dirs, batch_size, self.num_units], + [self.num_layers * self.num_dirs, batch_size, self.num_units]) + + +class _CudnnRNNNoInputC(_CudnnRNN): + """Abstract simple CudnnRNN layer without input_c.""" + + def state_shape(self, batch_size): + """Shape of the state of Cudnn RNN cells w/o. input_c. + + Shape is a 1-element tuple, + [num_layers * num_dirs, batch_size, num_units] + Args: + batch_size: an int + Returns: + a tuple of python arrays. + """ + return [self.num_layers * self.num_dirs, batch_size, self.num_units], + + +class CudnnGRU(_CudnnRNNNoInputC): + """Cudnn implementation of the GRU layer.""" + _rnn_mode = CUDNN_GRU + _num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER + _saveable_cls = cudnn_rnn_ops.CudnnGRUSaveable + + +class CudnnRNNTanh(_CudnnRNNNoInputC): + """Cudnn implementation of the RNN-tanh layer.""" + _rnn_mode = CUDNN_RNN_TANH + _num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER + _saveable_cls = cudnn_rnn_ops.CudnnRNNTanhSaveable + + +class CudnnRNNRelu(_CudnnRNNNoInputC): + """Cudnn implementation of the RNN-relu layer.""" + _rnn_mode = CUDNN_RNN_RELU + _num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER + _saveable_cls = cudnn_rnn_ops.CudnnRNNReluSaveable diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py index f6eeb016755b66a8ac2a4b4e711543ebdf468269..9f748996934ca608838e57756a96c35c67feaac9 100644 --- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py +++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py @@ -18,6 +18,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.cudnn_rnn.ops import gen_cudnn_rnn_ops +from tensorflow.contrib.rnn.python.ops import core_rnn_cell from tensorflow.contrib.rnn.python.ops import lstm_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import common_shapes @@ -65,7 +66,7 @@ class CudnnCompatibleLSTMCell(lstm_ops.LSTMBlockCell): def __init__(self, num_units, reuse=None): super(CudnnCompatibleLSTMCell, self).__init__( - num_units, forget_bias=0, clip_cell=False, use_peephole=False, + num_units, forget_bias=0, cell_clip=None, use_peephole=False, reuse=reuse) self._names.update({"scope": "cudnn_compatible_lstm_cell"}) @@ -121,18 +122,18 @@ class CudnnCompatibleGRUCell(rnn_cell_impl.GRUCell): bias_ones = init_ops.constant_initializer(1.0, dtype=dtype) # pylint: disable=protected-access value = math_ops.sigmoid( - rnn_cell_impl._linear([inputs, state], 2 * self._num_units, True, + core_rnn_cell._linear([inputs, state], 2 * self._num_units, True, bias_ones, self._kernel_initializer)) r, u = array_ops.split(value=value, num_or_size_splits=2, axis=1) # pylint: enable=protected-access with vs.variable_scope("candidate"): # pylint: disable=protected-access with vs.variable_scope("input_projection"): - hi = rnn_cell_impl._linear(inputs, self._num_units, True, + hi = core_rnn_cell._linear(inputs, self._num_units, True, self._bias_initializer, self._kernel_initializer) with vs.variable_scope("hidden_projection"): - hh = r * (rnn_cell_impl._linear(state, self._num_units, True, + hh = r * (core_rnn_cell._linear(state, self._num_units, True, self._bias_initializer, self._kernel_initializer)) # pylint: enable=protected-access @@ -717,12 +718,6 @@ _cudnn_rnn_common_doc_string = """ """ -def _check_direction(direction): - if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION): - raise ValueError("Invalid direction: %s, expect %s or %s" % - (direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION)) - - def _check_rnn_mode(rnn_mode): if rnn_mode not in (CUDNN_LSTM, CUDNN_GRU, CUDNN_RNN_TANH, CUDNN_RNN_RELU): raise ValueError("Invalid rnn_mode: %s, expect one of (%s, %s, %s, %s)" % @@ -737,14 +732,31 @@ def _get_seed(seed): return seed, seed2 +def check_direction(direction): + """Check validity of direction.""" + if direction not in (CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION): + raise ValueError("Invalid direction: %s, expecting %s or %s" % + (direction, CUDNN_RNN_UNIDIRECTION, CUDNN_RNN_BIDIRECTION)) + + +def check_input_mode(input_mode): + if input_mode not in (CUDNN_INPUT_LINEAR_MODE, CUDNN_INPUT_SKIP_MODE, + CUDNN_INPUT_AUTO_MODE): + raise ValueError("Invalid input_mode: %s, expect one of (%s, %s, %s)" % + (input_mode, CUDNN_INPUT_LINEAR_MODE, + CUDNN_INPUT_SKIP_MODE, CUDNN_INPUT_AUTO_MODE)) + + def _get_num_params(rnn_mode, num_layers, direction): """Return num params for given Cudnn config.""" if rnn_mode == CUDNN_LSTM: - num_params_per_layer = 8 + num_params_per_layer = CUDNN_LSTM_PARAMS_PER_LAYER elif rnn_mode == CUDNN_GRU: - num_params_per_layer = 6 - elif rnn_mode in (CUDNN_RNN_RELU, CUDNN_RNN_TANH): - num_params_per_layer = 2 + num_params_per_layer = CUDNN_GRU_PARAMS_PER_LAYER + elif rnn_mode == CUDNN_RNN_RELU: + num_params_per_layer = CUDNN_RNN_RELU_PARAMS_PER_LAYER + elif rnn_mode == CUDNN_RNN_TANH: + num_params_per_layer = CUDNN_RNN_TANH_PARAMS_PER_LAYER else: raise ValueError("Invalid \'rnn_mode\': %s", rnn_mode) num_params = num_layers * num_params_per_layer @@ -794,7 +806,8 @@ def _cudnn_rnn(inputs, outputs, output_h, output_c """ _check_rnn_mode(rnn_mode) - _check_direction(direction) + check_direction(direction) + check_input_mode(input_mode) seed, seed2 = random_seed.get_seed(seed) outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn( input=inputs, @@ -1017,16 +1030,16 @@ def cudnn_rnn_tanh(inputs, seed, name) -def cudnn_rnn_params_to_canonical(rnn_mode, - num_layers, - num_units, - input_size, - params, - input_mode=CUDNN_INPUT_LINEAR_MODE, - direction=CUDNN_RNN_UNIDIRECTION, - dropout=0, - seed=0, - name=None): +def cudnn_rnn_opaque_params_to_canonical(rnn_mode, + num_layers, + num_units, + input_size, + params, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0, + seed=0, + name=None): """Convert cudnn opaque params to canonical. Args: @@ -1058,7 +1071,8 @@ def cudnn_rnn_params_to_canonical(rnn_mode, """ _check_rnn_mode(rnn_mode) - _check_direction(direction) + check_direction(direction) + check_input_mode(input_mode) num_params = _get_num_params(rnn_mode, num_layers, direction) seed, seed2 = random_seed.get_seed(seed) weights, biases = gen_cudnn_rnn_ops.cudnn_rnn_params_to_canonical( @@ -1077,17 +1091,17 @@ def cudnn_rnn_params_to_canonical(rnn_mode, return weights, biases -def cudnn_rnn_canonical_to_params(rnn_mode, - num_layers, - num_units, - input_size, - weights, - biases, - input_mode=CUDNN_INPUT_LINEAR_MODE, - direction=CUDNN_RNN_UNIDIRECTION, - dropout=0, - seed=0, - name=None): +def cudnn_rnn_canonical_to_opaque_params(rnn_mode, + num_layers, + num_units, + input_size, + weights, + biases, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dropout=0, + seed=0, + name=None): """Converts params from the canonical format to a specific format of cuDNN. Args: @@ -1119,7 +1133,8 @@ def cudnn_rnn_canonical_to_params(rnn_mode, ValueError: if rnn_mode or direction is invalid. """ _check_rnn_mode(rnn_mode) - _check_direction(direction) + check_direction(direction) + check_input_mode(input_mode) seed, seed2 = random_seed.get_seed(seed) return gen_cudnn_rnn_ops.cudnn_rnn_canonical_to_params( rnn_mode=rnn_mode, @@ -1136,16 +1151,16 @@ def cudnn_rnn_canonical_to_params(rnn_mode, name=name) -def cudnn_opaque_params_size(rnn_mode, - num_layers, - num_units, - input_size, - input_mode=CUDNN_INPUT_LINEAR_MODE, - direction=CUDNN_RNN_UNIDIRECTION, - dtype=dtypes.float32, - dropout=0, - seed=0, - name=None): +def cudnn_rnn_opaque_params_size(rnn_mode, + num_layers, + num_units, + input_size, + input_mode=CUDNN_INPUT_LINEAR_MODE, + direction=CUDNN_RNN_UNIDIRECTION, + dtype=dtypes.float32, + dropout=0, + seed=0, + name=None): """Returns opaque params size for specific Cudnn config. Args: @@ -1176,7 +1191,8 @@ def cudnn_opaque_params_size(rnn_mode, ValueError: if rnn_mode or direction is invalid. """ _check_rnn_mode(rnn_mode) - _check_direction(direction) + check_direction(direction) + check_input_mode(input_mode) seed, seed2 = random_seed.get_seed(seed) return gen_cudnn_rnn_ops.cudnn_rnn_params_size( rnn_mode=rnn_mode, @@ -1278,7 +1294,7 @@ class _CudnnRNN(object): Returns: The calculated parameter buffer size. """ - return cudnn_opaque_params_size( + return cudnn_rnn_opaque_params_size( rnn_mode=self._rnn_mode, num_layers=self._num_layers, num_units=self._num_units, @@ -1327,7 +1343,7 @@ class _CudnnRNN(object): Returns: A function for the specific-to-canonical conversion. """ - return cudnn_rnn_params_to_canonical( + return cudnn_rnn_opaque_params_to_canonical( rnn_mode=self._rnn_mode, num_layers=self._num_layers, num_units=self._num_units, @@ -1348,7 +1364,7 @@ class _CudnnRNN(object): Returns: A function for the canonical-to-params-to-specific conversion.. """ - return cudnn_rnn_canonical_to_params( + return cudnn_rnn_canonical_to_opaque_params( rnn_mode=self._rnn_mode, num_layers=self._num_layers, num_units=self._num_units, diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD index 1c3a798c5fa722eb75c3d830aae643bba0733559..eaede0e00ecf1986873d50709d135d3f4b3ac9cd 100644 --- a/tensorflow/contrib/data/BUILD +++ b/tensorflow/contrib/data/BUILD @@ -4,16 +4,39 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) +load( + "//tensorflow:tensorflow.bzl", + "tf_custom_op_library", + "tf_gen_op_libs", +) + py_library( name = "data", srcs = ["__init__.py"], srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:sloppy_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:prefetching_py", + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:util", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + ], +) + +tf_custom_op_library( + name = "_prefetching_ops.so", + srcs = [ + "ops/prefetching_ops.cc", ], + deps = [ + "//tensorflow/contrib/data/kernels:prefetching_kernels", + ], +) + +tf_gen_op_libs( + op_lib_names = ["prefetching_ops"], ) filegroup( diff --git a/tensorflow/contrib/data/README.md b/tensorflow/contrib/data/README.md index 7c59a1ffc37085f17f8f4e693c0bc874c77f914a..848782e8d89b8670caf3b45de4912a7e0855c102 100644 --- a/tensorflow/contrib/data/README.md +++ b/tensorflow/contrib/data/README.md @@ -1,8 +1,39 @@ `tf.contrib.data` API ===================== -This directory contains the Python API for the `tf.contrib.data.Dataset` and -`tf.contrib.data.Iterator` classes, which can be used to build input pipelines. +NOTE: The `tf.contrib.data` module has been deprecated. Use `tf.data` instead. +We are continuing to support existing code using the `tf.contrib.data` APIs in +the current version of TensorFlow, but will eventually remove support. The +`tf.data` APIs are subject to backwards compatibility guarantees. -The documentation for this API has moved to the programmers' -guide, [here](../../docs_src/programmers_guide/datasets.md). +Porting your code to `tf.data` +------------------------------ + +The `tf.contrib.data.Dataset` class has been renamed to `tf.data.Dataset`, and +the `tf.contrib.data.Iterator` class has been renamed to `tf.data.Iterator`. +Most code can be ported by removing `.contrib` from the names of the classes. +However, there are some small differences, which are outlined below. + +The arguments accepted by the `Dataset.map()` transformation have changed: + +* `dataset.map(..., num_threads=T)` is now `dataset.map(num_parallel_calls=T)`. +* `dataset.map(..., output_buffer_size=B)` is now + `dataset.map(...).prefetch(B)`. + +Some transformations have been removed from `tf.data.Dataset`, and you must +instead apply them using `Dataset.apply()` transformation. The full list of +changes is as follows: + +* `dataset.dense_to_sparse_batch(...)` is now + `dataset.apply(tf.contrib.data.dense_to_sparse_batch(...)`. +* `dataset.enumerate(...)` is now + `dataset.apply(tf.contrib.data.enumerate_dataset(...))`. +* `dataset.group_by_window(...)` is now + `dataset.apply(tf.contrib.data.group_by_window(...))`. +* `dataset.ignore_errors()` is now + `dataset.apply(tf.contrib.data.ignore_errors())`. +* `dataset.unbatch()` is now `dataset.apply(tf.contrib.data.unbatch())`. + +The `Dataset.make_dataset_resource()` and `Iterator.dispose_op()` methods have +been removed from the API. Please open a GitHub issue if you have a need for +either of these. diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py index 67dff0a4ab093c4375b4d973ce97eb5ea04e8e62..824ac4298f88a0372743324793f6de453dae71c8 100644 --- a/tensorflow/contrib/data/__init__.py +++ b/tensorflow/contrib/data/__init__.py @@ -14,6 +14,8 @@ # ============================================================================== """`tf.contrib.data.Dataset` API for input pipelines. +See the @{$datasets$Importing Data} Programmer's Guide for an overview. + @@Dataset @@Iterator @@TFRecordDataset @@ -25,11 +27,14 @@ @@enumerate_dataset @@group_by_window @@ignore_errors +@@make_saveable_from_iterator @@read_batch_features @@unbatch +@@parallel_interleave @@rejection_resample @@sloppy_interleave +@@get_single_element """ from __future__ import absolute_import @@ -37,21 +42,25 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import -from tensorflow.contrib.data.python.ops.dataset_ops import batch_and_drop_remainder + +from tensorflow.contrib.data.python.ops.batching import batch_and_drop_remainder +from tensorflow.contrib.data.python.ops.batching import dense_to_sparse_batch +from tensorflow.contrib.data.python.ops.batching import unbatch from tensorflow.contrib.data.python.ops.dataset_ops import Dataset -from tensorflow.contrib.data.python.ops.dataset_ops import dense_to_sparse_batch -from tensorflow.contrib.data.python.ops.dataset_ops import enumerate_dataset -from tensorflow.contrib.data.python.ops.dataset_ops import FixedLengthRecordDataset -from tensorflow.contrib.data.python.ops.dataset_ops import group_by_window -from tensorflow.contrib.data.python.ops.dataset_ops import ignore_errors -from tensorflow.contrib.data.python.ops.dataset_ops import read_batch_features -from tensorflow.contrib.data.python.ops.dataset_ops import rejection_resample -from tensorflow.contrib.data.python.ops.dataset_ops import SqlDataset -from tensorflow.contrib.data.python.ops.dataset_ops import TextLineDataset -from tensorflow.contrib.data.python.ops.dataset_ops import TFRecordDataset -from tensorflow.contrib.data.python.ops.dataset_ops import unbatch -from tensorflow.contrib.data.python.ops.sloppy_ops import sloppy_interleave -from tensorflow.python.data.ops.dataset_ops import Iterator +from tensorflow.contrib.data.python.ops.dataset_ops import get_single_element +from tensorflow.contrib.data.python.ops.enumerate_ops import enumerate_dataset +from tensorflow.contrib.data.python.ops.error_ops import ignore_errors +from tensorflow.contrib.data.python.ops.grouping import group_by_window +from tensorflow.contrib.data.python.ops.interleave_ops import parallel_interleave +from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave +from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator +from tensorflow.contrib.data.python.ops.readers import FixedLengthRecordDataset +from tensorflow.contrib.data.python.ops.readers import read_batch_features +from tensorflow.contrib.data.python.ops.readers import SqlDataset +from tensorflow.contrib.data.python.ops.readers import TextLineDataset +from tensorflow.contrib.data.python.ops.readers import TFRecordDataset +from tensorflow.contrib.data.python.ops.resampling import rejection_resample +from tensorflow.python.data.ops.iterator_ops import Iterator # pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/contrib/data/kernels/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..4cb53741ebf8cd0db41b382c878bd2ccd1dcf7f1 --- /dev/null +++ b/tensorflow/contrib/data/kernels/BUILD @@ -0,0 +1,29 @@ +# Description: +# Contains kernels for datasets and iterators. +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +cc_library( + name = "prefetching_kernels", + srcs = ["prefetching_kernels.cc"], + deps = [ + "//tensorflow/core:framework_headers_lib", + "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", + ], + alwayslink = 1, +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), +) diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/contrib/data/kernels/prefetching_kernels.cc new file mode 100644 index 0000000000000000000000000000000000000000..c9a3537c70c711290fb1111a1594e6dea3bc07a9 --- /dev/null +++ b/tensorflow/contrib/data/kernels/prefetching_kernels.cc @@ -0,0 +1,378 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/resource_op_kernel.h" +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/util/device_name_utils.h" + +namespace tensorflow { + +struct BufferElement { + // The producer sets `status` if getting the input element fails. + Status status; + // The buffered data element. + std::vector value; +}; + +using FunctionBufferCallback = std::function; + +class FunctionBufferingResource : public ResourceBase { + public: + FunctionBufferingResource(FunctionLibraryRuntime* lib, + const NameAttrList& func, int64 buffer_size, + const string& source_device, + const string& target_device, + const std::vector& func_args, + int64 thread_pool_size) + : lib_(lib), + func_(func), + buffer_size_(buffer_size), + source_device_(source_device), + target_device_(target_device), + func_args_(func_args), + thread_pool_(new thread::ThreadPool(Env::Default(), ThreadOptions(), + "buffer_resource", thread_pool_size, + false /* low_latency_hint */)), + handle_(kInvalidHandle), + is_buffering_(false), + end_of_sequence_(false), + cancelled_(false) { + runner_ = [this](std::function c) { + thread_pool_->Schedule(std::move(c)); + }; + } + + ~FunctionBufferingResource() override { + Cancel(); + { + mutex_lock l(mu_); + while (is_buffering_) { + cond_var_.wait(l); + } + } + delete thread_pool_; + } + + string DebugString() override { + return strings::StrCat("FunctionBufferingResource. Size: ", buffer_size_, + "; target_device: ", target_device_); + } + + // Instantiates the function the first time it's called. After that it caches + // the handle. + Status Instantiate() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + // Re-use existing handle if it's been set, effectively caching it. + if (handle_ != kInvalidHandle) { + return Status::OK(); + } + AttrValueMap attr_values = func_.attr(); + AttrValue v; + v.set_s(target_device_); + AddAttr("_target", v, &attr_values); + + return lib_->Instantiate(func_.name(), AttrSlice(&attr_values), &handle_); + } + + // Returns true if we've got to the end of the sequence and exhausted the + // buffer. + bool Finished() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + return end_of_sequence_ && buffer_.empty(); + } + + // Cancels any buffering / prefetching going on. + void Cancel() LOCKS_EXCLUDED(mu_) { + mutex_lock l(mu_); + cancelled_ = true; + } + + // If the buffer has anything, runs `callback` on the first element in the + // buffer, else schedules the `callback` to be called. Requires `args` and + // `lib` in case more function calls need to be scheduled. + void MaybeGet(FunctionBufferCallback callback) LOCKS_EXCLUDED(mu_) { + bool start_buffering = false; + bool produced_output = false; + BufferElement buffer_element; + { + mutex_lock l(mu_); + if (!is_buffering_ && !end_of_sequence_) { + start_buffering = true; + } + if (!buffer_.empty()) { + produced_output = true; + std::swap(buffer_element, buffer_.front()); + buffer_.pop_front(); + } else { + produced_output = false; + requests_.push_back(std::move(callback)); + } + } + if (produced_output) { + callback(buffer_element); + } + if (start_buffering) { + FillBuffer(); + } + } + + private: + void FillBuffer() LOCKS_EXCLUDED(mu_) { + FunctionLibraryRuntime::Handle handle; + std::vector cancellation_callbacks; + std::vector cancellation_buffer_elements; + bool cancelled = false; + { + mutex_lock l(mu_); + handle = handle_; + if (cancelled_) { + cancelled = true; + // Run through and fulfill all pending requests, if possible. + while (!requests_.empty()) { + if (!buffer_.empty()) { + cancellation_buffer_elements.push_back(std::move(buffer_.front())); + buffer_.pop_front(); + cancellation_callbacks.push_back(std::move(requests_.front())); + requests_.pop_front(); + } else { + LOG(ERROR) << "Buffer ran out of elements and we couldn't satisfy: " + << requests_.size() << " requests"; + break; + } + } + is_buffering_ = false; + } else { + is_buffering_ = true; + } + } + if (cancelled) { + for (int i = 0; i < cancellation_callbacks.size(); ++i) { + cancellation_callbacks[i](cancellation_buffer_elements[i]); + } + // We only wait on cond_var_ in the destructor, so there would atmost be + // one waiter to notify. + cond_var_.notify_one(); + return; + } + FunctionLibraryRuntime::Options opts; + // Copied from CapturedFunction::generate_step_id(); + opts.step_id = -std::abs(static_cast(random::New64())); + opts.runner = &runner_; + opts.source_device = source_device_; + AllocatorAttributes arg_alloc_attr; + arg_alloc_attr.set_on_host(true); + opts.args_alloc_attrs.push_back(arg_alloc_attr); + if (opts.source_device != target_device_) { + opts.remote_execution = true; + } + opts.create_rendezvous = true; + auto* rets = new std::vector; + lib_->Run(opts, handle, func_args_, rets, + [this, rets](const Status& status) { + FunctionBufferCallback callback = nullptr; + BufferElement buffer_front; + bool restart_buffering = false; + { + mutex_lock l(mu_); + BufferElement buffer_element; + buffer_element.status = status; + if (!status.ok()) { + end_of_sequence_ = true; + is_buffering_ = false; + buffer_.push_back(std::move(buffer_element)); + return; + } + buffer_element.value.swap(*rets); + buffer_.push_back(std::move(buffer_element)); + if (!requests_.empty()) { + buffer_front = std::move(buffer_.front()); + buffer_.pop_front(); + callback = std::move(requests_.front()); + requests_.pop_front(); + } + if (buffer_.size() < buffer_size_) { + restart_buffering = true; + } else { + is_buffering_ = false; + } + } + if (callback != nullptr) { + callback(buffer_front); + } + if (restart_buffering) { + FillBuffer(); + } + }); + } + + mutex mu_; + FunctionLibraryRuntime* lib_; + NameAttrList func_; + const int64 buffer_size_; + const string source_device_; + const string target_device_; + const std::vector func_args_; + thread::ThreadPool* thread_pool_; + FunctionLibraryRuntime::Handle handle_ GUARDED_BY(mu_); + std::deque buffer_ GUARDED_BY(mu_); + std::deque requests_ GUARDED_BY(mu_); + std::function)> runner_ = nullptr; + bool is_buffering_ GUARDED_BY(mu_); + bool end_of_sequence_ GUARDED_BY(mu_); + bool cancelled_ GUARDED_BY(mu_); + condition_variable cond_var_; +}; + +class FunctionBufferResourceHandleOp : public OpKernel { + public: + explicit FunctionBufferResourceHandleOp(OpKernelConstruction* ctx) + : OpKernel(ctx) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("buffer_size", &buffer_size_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("container", &container_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &name_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("thread_pool_size", &thread_pool_size_)); + } + + void Compute(OpKernelContext* ctx) override { + const Tensor* string_arg; + OP_REQUIRES_OK(ctx, ctx->input("string_arg", &string_arg)); + std::vector func_args; + func_args.push_back(*string_arg); + + // Obtain and canonicalize target_device. + const Tensor* target_arg; + OP_REQUIRES_OK(ctx, ctx->input("target_device", &target_arg)); + const string& target_device = + DeviceNameUtils::CanonicalizeDeviceName(target_arg->scalar()()); + + FunctionLibraryRuntime* lib = ctx->function_library(); + OP_REQUIRES(ctx, lib != nullptr, + errors::Internal("No function library is provided.")); + + const string& source_device = ctx->device()->name(); + + ContainerInfo cinfo; + OP_REQUIRES_OK(ctx, cinfo.Init(ctx->resource_manager(), def())); + // Create the resource. + FunctionBufferingResource* buffer; + OP_REQUIRES_OK( + ctx, ctx->resource_manager()->LookupOrCreate( + cinfo.container(), cinfo.name(), &buffer, + [lib, &source_device, &target_device, func_args, + this](FunctionBufferingResource** ptr) { + *ptr = new FunctionBufferingResource( + lib, func_, buffer_size_, source_device, target_device, + func_args, thread_pool_size_); + return Status::OK(); + })); + OP_REQUIRES_OK(ctx, buffer->Instantiate()); + + OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput( + ctx, 0, cinfo.container(), cinfo.name(), + MakeTypeIndex())); + } + + private: + NameAttrList func_; + int64 buffer_size_; + string container_; + string name_; + int64 thread_pool_size_; +}; + +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource") + .Device(DEVICE_CPU) + .HostMemory("resource") + .HostMemory("string_arg") + .HostMemory("target_device"), + FunctionBufferResourceHandleOp); +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource") + .Device(DEVICE_GPU) + .HostMemory("resource") + .HostMemory("string_arg") + .HostMemory("target_device"), + FunctionBufferResourceHandleOp); +#if TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource") + .Device(DEVICE_SYCL) + .HostMemory("resource") + .HostMemory("string_arg") + .HostMemory("target_device"), + FunctionBufferResourceHandleOp); +#endif // TENSORFLOW_USE_SYCL + +// Prefetches and fills up a buffer by calling a function that provides the +// elements to buffer. +class FunctionBufferingResourceGetNextOp : public AsyncOpKernel { + public: + explicit FunctionBufferingResourceGetNextOp(OpKernelConstruction* ctx) + : AsyncOpKernel(ctx) {} + + ~FunctionBufferingResourceGetNextOp() override {} + + void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override { + ResourceHandle handle; + OP_REQUIRES_OK_ASYNC( + ctx, HandleFromInput(ctx, "function_buffer_resource", &handle), done); + FunctionBufferingResource* buffer = nullptr; + OP_REQUIRES_OK_ASYNC( + ctx, LookupResource(ctx, handle, &buffer), + done); + core::ScopedUnref s(buffer); + + if (buffer->Finished()) { + ctx->SetStatus(errors::OutOfRange("end_of_sequence")); + done(); + return; + } + + FunctionBufferCallback callback = + [ctx, done](const BufferElement& buffer_element) { + Status s = buffer_element.status; + if (!s.ok()) { + ctx->SetStatus(s); + done(); + return; + } + for (size_t i = 0; i < buffer_element.value.size(); ++i) { + ctx->set_output(i, buffer_element.value[i]); + } + done(); + }; + buffer->MaybeGet(std::move(callback)); + } +}; + +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext") + .Device(DEVICE_CPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceGetNextOp); +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext") + .Device(DEVICE_GPU) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceGetNextOp); +#if TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext") + .Device(DEVICE_SYCL) + .HostMemory("function_buffer_resource"), + FunctionBufferingResourceGetNextOp); +#endif // TENSORFLOW_USE_SYCL + +} // namespace tensorflow diff --git a/tensorflow/contrib/data/ops/prefetching_ops.cc b/tensorflow/contrib/data/ops/prefetching_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..23cb62b6f0dbfed15667dd00ae0039b33aa944d4 --- /dev/null +++ b/tensorflow/contrib/data/ops/prefetching_ops.cc @@ -0,0 +1,58 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { + +REGISTER_OP("FunctionBufferingResource") + .Input("string_arg: string") + .Input("target_device: string") + .Output("resource: resource") + .Attr("shared_name: string") + .Attr("container: string") + .Attr("f: func") + .Attr("buffer_size: int") + .Attr("thread_pool_size: int") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Creates a resource that fills up a buffer by making function calls. + +string_arg: String argument to the function call. +target_device: Target device to execute the function on. +resource: Handle to the resource created. +f: Function to be executed. +buffer_size: Size of the buffer. +thread_pool_size: Size of the threadpool doing the prefetching. +container: If non-empty, this resource is placed in the given container. + Otherwise, a default container is used. +shared_name: If non-empty, this resource will be shared under the given name + across multiple sessions. +)doc"); + +REGISTER_OP("FunctionBufferingResourceGetNext") + .Input("function_buffer_resource: resource") + .Attr("output_types: list(type)") + .Output("output: output_types") + .SetShapeFn(shape_inference::UnknownShape) + .Doc(R"doc( +Gets the next element from a FunctionBufferingResource. + +function_buffer_resource: The FunctionBufferingResource handle. +output: A list of return values. +output_types: The type list for the return values. +)doc"); + +} // namespace tensorflow diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD index aa047803e9e3a8095fbd7af5081457c08110c906..5877f42dcf9e99bca27ba0e6ce222c556dfbd159 100644 --- a/tensorflow/contrib/data/python/kernel_tests/BUILD +++ b/tensorflow/contrib/data/python/kernel_tests/BUILD @@ -7,55 +7,54 @@ exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "py_test") py_test( - name = "iterator_ops_test", + name = "batch_dataset_op_test", size = "small", - srcs = ["iterator_ops_test.py"], + srcs = ["batch_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = ["no_pip"], deps = [ + ":dataset_serialization_test", "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:framework_ops", - "//tensorflow/python:functional_ops", - "//tensorflow/python:gradients", "//tensorflow/python:math_ops", - "//tensorflow/python:parsing_ops", - "//tensorflow/python:script_ops", - "//tensorflow/python:training", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:string_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:util", "//third_party/py/numpy", ], ) py_test( - name = "iterator_ops_cluster_test", + name = "bucketing_test", size = "small", - srcs = ["iterator_ops_cluster_test.py"], + srcs = ["bucketing_test.py"], srcs_version = "PY2AND3", - tags = ["no_windows"], deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:functional_ops", - "//tensorflow/python:training", + "//tensorflow/python:math_ops", + "//tensorflow/python:string_ops", + "//tensorflow/python:tensor_shape", "//third_party/py/numpy", ], ) py_test( - name = "batch_dataset_op_test", + name = "cache_dataset_op_test", size = "small", - srcs = ["batch_dataset_op_test.py"], + srcs = ["cache_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", @@ -64,32 +63,26 @@ py_test( "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:math_ops", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:string_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:util", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", ], ) py_test( - name = "bucketing_test", + name = "concatenate_dataset_op_test", size = "small", - srcs = ["bucketing_test.py"], + srcs = ["concatenate_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/python:array_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:string_ops", "//tensorflow/python:tensor_shape", - "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], ) @@ -105,6 +98,8 @@ py_test( ], deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -115,11 +110,31 @@ py_test( "//tensorflow/python:resource_variable_ops", "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", + "//tensorflow/python:training", "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], ) +py_library( + name = "dataset_serialization_test", + testonly = 1, + srcs = [ + "dataset_serialization_test_base.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python:util", + "//third_party/py/numpy", + ], +) + py_test( name = "filter_dataset_op_test", size = "small", @@ -131,6 +146,7 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:functional_ops", "//tensorflow/python:math_ops", "//third_party/py/numpy", ], @@ -144,30 +160,87 @@ py_test( deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:session", "//tensorflow/python:training", "//third_party/py/numpy", ], ) py_test( - name = "sloppy_transformation_dataset_op_test", + name = "interleave_dataset_op_test", size = "small", - srcs = ["sloppy_transformation_dataset_op_test.py"], + srcs = ["interleave_dataset_op_test.py"], srcs_version = "PY2AND3", + tags = [ + "manual", # b/67958761 + ], deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/contrib/data/python/ops:sloppy_ops", + "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:math_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:training", + "//third_party/py/numpy", + ], +) + +py_test( + name = "iterator_ops_cluster_test", + size = "small", + srcs = ["iterator_ops_cluster_test.py"], + srcs_version = "PY2AND3", + tags = ["no_windows"], + deps = [ + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:function", + "//tensorflow/python:functional_ops", + "//tensorflow/python:session", + "//tensorflow/python/data/ops:iterator_ops", + ], +) + +py_test( + name = "iterator_ops_test", + size = "small", + srcs = ["iterator_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:readers", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:function", + "//tensorflow/python:functional_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:io_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:session", "//tensorflow/python:training", + "//tensorflow/python/data/ops:iterator_ops", "//third_party/py/numpy", ], ) @@ -194,12 +267,15 @@ py_test( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:data_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python:functional_ops", "//tensorflow/python:io_ops", "//tensorflow/python:lookup_ops", @@ -207,6 +283,7 @@ py_test( "//tensorflow/python:random_ops", "//tensorflow/python:script_ops", "//tensorflow/python:string_ops", + "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//third_party/py/numpy", @@ -220,6 +297,8 @@ py_test( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", @@ -227,9 +306,13 @@ py_test( "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", + "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", "//tensorflow/python:tensor_shape", + "//tensorflow/python:training", "//tensorflow/python:variables", + "//tensorflow/python/data/ops:iterator_ops", ], ) @@ -239,32 +322,23 @@ py_test( srcs = ["reader_dataset_ops_test.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", + "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:io_ops", "//tensorflow/python:lib", "//tensorflow/python:parsing_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:training", "//tensorflow/python:util", - ], -) - -py_test( - name = "sql_dataset_op_test", - size = "small", - srcs = ["sql_dataset_op_test.py"], - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:framework", - "//tensorflow/python:platform_test", + "//tensorflow/python/data/ops:iterator_ops", ], ) @@ -277,11 +351,11 @@ py_test( tags = ["noasan"], deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:transformation_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:string_ops", "//tensorflow/python:util", - "//tensorflow/python:variables", "//third_party/py/numpy", ], ) @@ -302,47 +376,50 @@ py_test( ) py_test( - name = "shuffle_dataset_op_test", + name = "shard_dataset_op_test", size = "small", - srcs = ["shuffle_dataset_op_test.py"], + srcs = ["shard_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//third_party/py/numpy", ], ) py_test( - name = "shard_dataset_op_test", + name = "shuffle_dataset_op_test", size = "small", - srcs = ["shard_dataset_op_test.py"], + srcs = ["shuffle_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:iterator_ops", + "//third_party/py/numpy", ], ) py_test( - name = "cache_dataset_op_test", + name = "sql_dataset_op_test", size = "small", - srcs = ["cache_dataset_op_test.py"], + srcs = ["sql_dataset_op_test.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:readers", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:errors", - "//tensorflow/python:variables", - "//third_party/py/numpy", ], ) @@ -353,26 +430,35 @@ py_test( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:iterator_ops", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python/data/util:nest", "//third_party/py/numpy", ], ) py_test( - name = "concatenate_dataset_op_test", + name = "prefetching_ops_test", size = "small", - srcs = ["concatenate_dataset_op_test.py"], + srcs = ["prefetching_ops_test.py"], srcs_version = "PY2AND3", + tags = [ + "manual", + "no_oss", + ], deps = [ "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/contrib/data/python/ops:prefetching_py", + "//tensorflow/core:protos_all_py", "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", - "//tensorflow/python:tensor_shape", - "//tensorflow/python/data/util:nest", - "//third_party/py/numpy", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:resource_variable_ops", ], ) diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py index 4a7fb1b8b067fddd2884e3e61b07c0199d02dbec..670f622c3c372dd08870390298f2e28db7e85596 100644 --- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py @@ -21,6 +21,8 @@ import math import numpy as np +from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base +from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -230,7 +232,7 @@ class BatchDatasetTest(test.TestCase): components = np.random.randint(12, size=(100,)).astype(np.int32) iterator = (dataset_ops.Dataset.from_tensor_slices(components) .map(lambda x: array_ops.fill([x], x)).apply( - dataset_ops.dense_to_sparse_batch(4, [12])) + batching.dense_to_sparse_batch(4, [12])) .make_initializable_iterator()) init_op = iterator.initializer get_next = sparse_tensor.SparseTensor(*iterator.get_next()) @@ -252,11 +254,52 @@ class BatchDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testDenseToSparseBatchDatasetWithUnknownShape(self): + components = np.random.randint(5, size=(40,)).astype(np.int32) + iterator = (dataset_ops.Dataset.from_tensor_slices(components) + .map(lambda x: array_ops.fill([x, x], x)).apply( + batching.dense_to_sparse_batch( + 4, [5, -1])).make_initializable_iterator()) + init_op = iterator.initializer + get_next = sparse_tensor.SparseTensor(*iterator.get_next()) + + with self.test_session() as sess: + sess.run(init_op) + + for start in range(0, len(components), 4): + results = sess.run(get_next) + self.assertAllEqual( + [[i, j, z] for i, c in enumerate(components[start:start+4]) + for j in range(c) for z in range(c)], results.indices) + self.assertAllEqual( + [c for c in components[start:start+4] + for _ in range(c) for _ in range(c)], + results.values) + self.assertAllEqual( + [min(4, len(components) - start), + 5, + np.max(components[start:start+4])], + results.dense_shape) + + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testDenseToSparseBatchDatasetWithInvalidShape(self): + input_tensor = array_ops.constant([[1]]) + iterator = (dataset_ops.Dataset.from_tensors(input_tensor) + .apply(batching.dense_to_sparse_batch(4, [-2])) + .make_initializable_iterator()) + init_op = iterator.initializer + + with self.test_session() as sess: + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Dimension -2 must be >= -1"): + sess.run(init_op) + def testDenseToSparseBatchDatasetShapeErrors(self): input_tensor = array_ops.placeholder(dtypes.int32) iterator = (dataset_ops.Dataset.from_tensors(input_tensor).apply( - dataset_ops.dense_to_sparse_batch(4, [12])) - .make_initializable_iterator()) + batching.dense_to_sparse_batch(4, [12])).make_initializable_iterator()) init_op = iterator.initializer get_next = sparse_tensor.SparseTensor(*iterator.get_next()) @@ -279,7 +322,7 @@ class BatchDatasetTest(test.TestCase): expected_types = (dtypes.int32,) * 3 data = data.batch(2) self.assertEqual(expected_types, data.output_types) - data = data.apply(dataset_ops.unbatch()) + data = data.apply(batching.unbatch()) self.assertEqual(expected_types, data.output_types) iterator = data.make_one_shot_iterator() @@ -298,7 +341,7 @@ class BatchDatasetTest(test.TestCase): expected_types = ((dtypes.int32,),) * 3 data = data.batch(2) self.assertEqual(expected_types, data.output_types) - data = data.apply(dataset_ops.unbatch()) + data = data.apply(batching.unbatch()) self.assertEqual(expected_types, data.output_types) iterator = data.make_one_shot_iterator() @@ -319,7 +362,7 @@ class BatchDatasetTest(test.TestCase): expected_types = ((dtypes.int32, dtypes.string),) * 3 data = data.batch(2) self.assertAllEqual(expected_types, data.output_types) - data = data.apply(dataset_ops.unbatch()) + data = data.apply(batching.unbatch()) self.assertAllEqual(expected_types, data.output_types) iterator = data.make_one_shot_iterator() @@ -342,8 +385,8 @@ class BatchDatasetTest(test.TestCase): batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - iterator = (dataset_ops.Dataset.from_tensor_slices(components) - .apply(dataset_ops.batch_and_drop_remainder(batch_size)) + iterator = (dataset_ops.Dataset.from_tensor_slices(components).apply( + batching.batch_and_drop_remainder(batch_size)) .make_initializable_iterator()) next_element = iterator.get_next() @@ -367,8 +410,8 @@ class BatchDatasetTest(test.TestCase): dtypes.int32, shape=[20, 30]))) # Test with a statically known batch size. - dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .apply(dataset_ops.batch_and_drop_remainder(128))) + dataset = (dataset_ops.Dataset.from_tensor_slices(components).apply( + batching.batch_and_drop_remainder(128))) self.assertIs(None, dataset.output_shapes[0].ndims) self.assertEqual([128], dataset.output_shapes[1][0].as_list()) @@ -377,13 +420,131 @@ class BatchDatasetTest(test.TestCase): # Test with a dynamic batch size: the static shape will be unknown, because # `batch_size` is a placeholder. batch_size = array_ops.placeholder(dtypes.int64) - dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .apply(dataset_ops.batch_and_drop_remainder(batch_size))) + dataset = (dataset_ops.Dataset.from_tensor_slices(components).apply( + batching.batch_and_drop_remainder(batch_size))) self.assertIs(None, dataset.output_shapes[0].ndims) self.assertEqual([None], dataset.output_shapes[1][0].as_list()) self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list()) + def testBatchAndMapDataset(self): + """Test a dataset that maps a TF function across its input elements.""" + # The pipeline is TensorSliceDataset -> + # RepeatDataset(count) -> BatchAndMapDataset(square_3, batch_size). + components = (np.arange(7), + np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis], + np.array(37.0) * np.arange(7)) + + count = array_ops.placeholder(dtypes.int64, shape=[]) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + + def _map_fn(x, y, z): + return math_ops.square(x), math_ops.square(y), math_ops.square(z) + + iterator = (dataset_ops.Dataset.from_tensor_slices(components).repeat(count) + .apply(batching.map_and_batch(_map_fn, batch_size)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual([[None] + list(c.shape[1:]) for c in components], + [t.shape.as_list() for t in get_next]) + + with self.test_session() as sess: + # Batch of a finite input, where the batch_size divides the + # total number of elements. + sess.run(init_op, feed_dict={count: 28, batch_size: 14}) + num_batches = (28 * 7) // 14 + for i in range(num_batches): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(14): + self.assertAllEqual(component[(i*14 + j) % 7]**2, + result_component[j]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Batch of a finite input, where the batch_size does not + # divide the total number of elements. + sess.run(init_op, feed_dict={count: 14, batch_size: 8}) + + # We expect (num_batches - 1) full-sized batches. + num_batches = int(math.ceil((14 * 7) / 8)) + for i in range(num_batches - 1): + result = sess.run(get_next) + for component, result_component in zip(components, result): + for j in range(8): + self.assertAllEqual(component[(i*8 + j) % 7]**2, + result_component[j]) + # The last batch should fail with `OutOfRange`. + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Batch of an empty input should fail straight away. + sess.run(init_op, feed_dict={count: 0, batch_size: 8}) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Empty batch should be an initialization time error. + with self.assertRaises(errors.InvalidArgumentError): + sess.run(init_op, feed_dict={count: 14, batch_size: 0}) + + def testBatchAndMapDatasetFails(self): + """Test a dataset that maps a TF function across its input elements.""" + dataset = dataset_ops.Dataset.from_tensors( + array_ops.check_numerics( + constant_op.constant(1.0) / constant_op.constant(0.0), "oops")) + batch_size = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = (dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) + .make_initializable_iterator()) + init_op = iterator.initializer + with self.test_session() as sess: + with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"): + sess.run(init_op, feed_dict={batch_size: 14}) + + def testBatchAndMapDatasetShapeMismatch(self): + """Test a dataset that maps a TF function across its input elements.""" + def generator(): + yield [1] + yield [2] + yield [3] + yield [[4, 5, 6]] + + dataset = dataset_ops.Dataset.from_generator( + generator, output_types=dtypes.int32) + batch_size = 4 + iterator = ( + dataset.apply(batching.map_and_batch(lambda x: x, batch_size)) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + with self.test_session() as sess: + sess.run(init_op) + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "number of elements does not match"): + sess.run(get_next) + + +class BatchDatasetSerializationTest( + dataset_serialization_test_base.DatasetSerializationTestBase): + + def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2): + components = ( + np.arange(tensor_slice_len), + np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis], + np.array(multiplier) * np.arange(tensor_slice_len)) + + return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size) + + def testCore(self): + tensor_slice_len = 8 + batch_size = 2 + num_outputs = tensor_slice_len // batch_size + self.run_core_tests( + lambda: self.build_dataset(15.0, tensor_slice_len, batch_size), + lambda: self.build_dataset(20.0, tensor_slice_len, batch_size), + num_outputs) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py index 9c16eebcf5d243c81b29ee628567fff7cd2769be..765ed53618958a8c49b26e416c57be28ea3bba73 100644 --- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -35,10 +36,12 @@ class GroupByWindowTest(test.TestCase): def testSimple(self): components = np.random.randint(100, size=(200,)).astype(np.int64) - iterator = dataset_ops.Iterator.from_dataset( + iterator = ( dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x) - .apply(dataset_ops.group_by_window(lambda x: x % 2, - lambda _, xs: xs.batch(4), 4))) + .apply( + grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), + 4)) + .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -61,10 +64,10 @@ class GroupByWindowTest(test.TestCase): def testImmediateOutput(self): components = np.array( [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64) - iterator = dataset_ops.Iterator.from_dataset( + iterator = ( dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply( - dataset_ops.group_by_window(lambda x: x % 3, - lambda _, xs: xs.batch(4), 4))) + grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), + 4)).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -82,10 +85,10 @@ class GroupByWindowTest(test.TestCase): def testSmallGroups(self): components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64) - iterator = dataset_ops.Iterator.from_dataset( + iterator = ( dataset_ops.Dataset.from_tensor_slices(components).apply( - dataset_ops.group_by_window(lambda x: x % 2, - lambda _, xs: xs.batch(4), 4))) + grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4), + 4)).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -109,10 +112,11 @@ class GroupByWindowTest(test.TestCase): padded_shapes=(tensor_shape.TensorShape([]), constant_op.constant([5], dtype=dtypes.int64) * -1)) - iterator = dataset_ops.Iterator.from_dataset( + iterator = ( dataset_ops.Dataset.from_tensor_slices(components) .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply( - dataset_ops.group_by_window(lambda x, _: x % 2, reduce_func, 32))) + grouping.group_by_window(lambda x, _: x % 2, reduce_func, 32)) + .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -133,12 +137,13 @@ class GroupByWindowTest(test.TestCase): window.padded_batch( 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),)) - iterator = dataset_ops.Iterator.from_dataset( + iterator = ( dataset_ops.Dataset.from_tensor_slices(components) .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x)) - .apply(dataset_ops.group_by_window( + .apply(grouping.group_by_window( lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64), - reduce_func, 4))) + reduce_func, 4)) + .make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -180,11 +185,11 @@ class BucketTest(test.TestCase): dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn)) bucketed_dataset = input_dataset.apply( - dataset_ops.group_by_window( + grouping.group_by_window( lambda x, y, z: 0, lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)) - iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset) + iterator = bucketed_dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -215,11 +220,11 @@ class BucketTest(test.TestCase): dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn)) bucketed_dataset = input_dataset.apply( - dataset_ops.group_by_window( + grouping.group_by_window( lambda x, y, z: math_ops.cast(x % 2, dtypes.int64), lambda k, bucket: self._dynamicPad(k, bucket, 32), 32)) - iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset) + iterator = bucketed_dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -285,11 +290,11 @@ class BucketTest(test.TestCase): .filter(lambda d: math_ops.equal(d["x"] % 2, 0))) bucketed_dataset = input_dataset.apply( - dataset_ops.group_by_window( + grouping.group_by_window( lambda d: math_ops.cast(d["x"] % 2, dtypes.int64), lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32)) - iterator = dataset_ops.Iterator.from_dataset(bucketed_dataset) + iterator = bucketed_dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -320,10 +325,9 @@ class BucketTest(test.TestCase): return window_sizes[key] dataset = dataset_ops.Dataset.from_tensor_slices(components).apply( - dataset_ops.group_by_window( - lambda x: x % 2, lambda _, xs: xs.batch(20), None, - window_size_func)) - iterator = dataset_ops.Iterator.from_dataset(dataset) + grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20), + None, window_size_func)) + iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() diff --git a/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py index 364c1be8eafccb77d4a54241ed758fc6cadbd00b..9818020680afb9d0f0197d272ec5339c6358db36 100644 --- a/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/cache_dataset_op_test.py @@ -24,6 +24,7 @@ import tempfile import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -59,8 +60,8 @@ class FilesystemCacheDatasetTest(test.TestCase): # Create initialization ops for iterators without and with # caching, respectively. - iterator = dataset_ops.Iterator.from_structure(cache_dataset.output_types, - cache_dataset.output_shapes) + iterator = iterator_ops.Iterator.from_structure(cache_dataset.output_types, + cache_dataset.output_shapes) init_fifo_op = iterator.make_initializer(repeat_dataset) init_cache_op = iterator.make_initializer(cache_dataset) diff --git a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py index a77f3232ceb5bb34a3c35711d0d1cad13fbe2e0b..870352209a08e6bc08bcca227ba455ad1851e8bf 100644 --- a/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/concatenate_dataset_op_test.py @@ -17,13 +17,17 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import iterator_ops from tensorflow.python.data.util import nest from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib class ConcatenateDatasetTest(test.TestCase): @@ -129,6 +133,140 @@ class ConcatenateDatasetTest(test.TestCase): with self.assertRaisesRegexp(TypeError, "have different types"): input_dataset.concatenate(dataset_to_concatenate) + def _iterator_checkpoint_prefix(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _build_graph(self, input_components, to_concatenate_components): + input_dataset = dataset_ops.Dataset.from_tensor_slices(input_components) + dataset_to_concatenate = dataset_ops.Dataset.from_tensor_slices( + to_concatenate_components) + iterator = input_dataset.concatenate( + dataset_to_concatenate).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + saveable = iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + # TODO(shivaniagrawal) : non-intuitive way, add support in mata_graph + for t in nest.flatten(get_next): + ops.add_to_collection("get_next", t) + return init_op, get_next + + def _testSaveRestoreUtility(self, start, break_range, stop): + path = self._iterator_checkpoint_prefix() + step = 0 + meta_filename = path + "-%d.meta" % step + + input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile( + np.array([[12], [13], [14], [15]]), 4)) + to_concatenate_components = (np.tile( + np.array([[5], [6], [7], [8], [9]]), 20), np.tile( + np.array([[16], [17], [18], [19], [20]]), 15)) + + with ops.Graph().as_default() as g: + init_op, get_next = self._build_graph(input_components, + to_concatenate_components) + saver = saver_lib.Saver() + with self.test_session(graph=g) as sess: + sess.run(init_op) + for i in range(start, break_range): + result = sess.run(get_next) + if i < 4: + for component, result_component in zip(input_components, result): + self.assertAllEqual(component[i], result_component) + else: + for component, result_component in zip(to_concatenate_components, + result): + self.assertAllEqual(component[i - 4], result_component) + saver.save(sess, path, step) + + with ops.Graph().as_default() as g: + saver = saver_lib.import_meta_graph(meta_filename) + with self.test_session(graph=g) as sess: + get_next = nest.pack_sequence_as(("a", "b"), + ops.get_collection("get_next")) + saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) + for i in range(break_range, stop): + result = sess.run(get_next) + if i < 4: + for component, result_component in zip(input_components, result): + self.assertAllEqual(component[i], result_component) + else: + for component, result_component in zip(to_concatenate_components, + result): + self.assertAllEqual(component[i - 4], result_component) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testRestoreAtFirstDataset(self): + start = 0 + stop = 9 + break_range = 3 + self._testSaveRestoreUtility(start, break_range, stop) + + def testRestoreAtSecondDataset(self): + start = 0 + stop = 9 + break_range = 6 + self._testSaveRestoreUtility(start, break_range, stop) + + def testRestoreAtBetweenDatasets(self): + start = 0 + stop = 9 + break_range = 4 + self._testSaveRestoreUtility(start, break_range, stop) + + def testRestoreExhaustedIterator(self): + start = 0 + stop = 9 + break_range = 9 + self._testSaveRestoreUtility(start, break_range, stop) + + def testRestoreInModifiedGraph(self): + start = 0 + stop = 9 + break_range = 6 + path = self._iterator_checkpoint_prefix() + step = 0 + + input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile( + np.array([[12], [13], [14], [15]]), 4)) + to_concatenate_components = (np.tile( + np.array([[5], [6], [7], [8], [9]]), 20), np.tile( + np.array([[16], [17], [18], [19], [20]]), 15)) + + with ops.Graph().as_default() as g: + init_op, get_next = self._build_graph(input_components, + to_concatenate_components) + saver = saver_lib.Saver(allow_empty=True) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for i in range(start, break_range): + result = sess.run(get_next) + if i < 4: + for component, result_component in zip(input_components, result): + self.assertAllEqual(component[i], result_component) + else: + for component, result_component in zip(to_concatenate_components, + result): + self.assertAllEqual(component[i - 4], result_component) + saver.save(sess, path, step) + + new_to_concatenate_components = (np.array([[5], [6], [7], [8], [9]]), + np.array([[16], [17], [18], [19], [20]])) + with ops.Graph().as_default() as g: + init_op, get_next = self._build_graph(input_components, + new_to_concatenate_components) + saver = saver_lib.Saver() + with self.test_session(graph=g) as sess: + saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) + for i in range(break_range, stop): + result = sess.run(get_next) + for component, result_component in zip(to_concatenate_components, + result): + self.assertAllEqual(component[i - 4], result_component) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py index acbd117a3312ee0374dc6fff215fe2c22db2e366..c3d6bfc097798530008f186cce68906b6af8fe47 100644 --- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py @@ -17,11 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import threading import numpy as np +from tensorflow.contrib.data.python.ops import batching from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import iterator_ops from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.data.util import nest @@ -33,6 +36,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib class DatasetConstructorTest(test.TestCase): @@ -433,6 +437,30 @@ class DatasetConstructorTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testFromGeneratorImplicitConversion(self): + def generator(): + yield [1] + yield [2] + yield [3] + + for dtype in [dtypes.int8, dtypes.int32, dtypes.int64]: + iterator = (dataset_ops.Dataset.from_generator( + generator, output_types=dtype, output_shapes=[1]) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + + self.assertEqual(dtype, get_next.dtype) + + with self.test_session() as sess: + sess.run(init_op) + for expected in [[1], [2], [3]]: + next_val = sess.run(get_next) + self.assertEqual(dtype.as_numpy_dtype, next_val.dtype) + self.assertAllEqual(expected, next_val) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testFromGeneratorTypeError(self): def generator(): yield np.array([1, 2, 3], dtype=np.int64) @@ -450,7 +478,7 @@ class DatasetConstructorTest(test.TestCase): sess.run(init_op) self.assertAllEqual([1, 2, 3], sess.run(get_next)) self.assertAllEqual([4, 5, 6], sess.run(get_next)) - with self.assertRaisesOpError(r"element of type .*int64.* was expected"): + with self.assertRaisesOpError(r"invalid literal for long\(\)"): sess.run(get_next) self.assertAllEqual([7, 8, 9], sess.run(get_next)) with self.assertRaises(errors.OutOfRangeError): @@ -523,8 +551,7 @@ class DatasetConstructorTest(test.TestCase): for new_types, new_shape_lists in test_cases: # pylint: disable=protected-access - new = dataset_ops._RestructuredDataset( - dataset, new_types, new_shape_lists) + new = batching._RestructuredDataset(dataset, new_types, new_shape_lists) # pylint: enable=protected-access self.assertEqual(new_types, new.output_types) if new_shape_lists is not None: @@ -544,10 +571,139 @@ class DatasetConstructorTest(test.TestCase): for new_types, new_shape_lists in fail_cases: with self.assertRaises(ValueError): # pylint: disable=protected-access - new = dataset_ops._RestructuredDataset( - dataset, new_types, new_shape_lists) + new = batching._RestructuredDataset(dataset, new_types, new_shape_lists) # pylint: enable=protected-access + def _iterator_checkpoint_prefix(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _testSaveRestoreFromTensorsUtility(self, start, break_range, stop): + path = self._iterator_checkpoint_prefix() + step = 0 + meta_filename = path + "-%d.meta" % step + + components = (np.array(1), np.array([1, 2, 3]), np.array(37.0)) + + with ops.Graph().as_default() as g: + iterator = ( + dataset_ops.Dataset.from_tensors(components) + .make_initializable_iterator()) + init_op = iterator.initializer + get_next = iterator.get_next() + saveable = iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + for t in nest.flatten(get_next): + ops.add_to_collection("get_next", t) + saver = saver_lib.Saver() + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(start, break_range): + result = sess.run(get_next) + for component, result_component in zip(components, result): + self.assertAllEqual(component, result_component) + saver.save(sess, path, step) + + with ops.Graph().as_default() as g: + saver = saver_lib.import_meta_graph(meta_filename) + with self.test_session(graph=g) as sess: + get_next = nest.pack_sequence_as(("a", "b", "c"), + ops.get_collection("get_next")) + saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) + for _ in range(break_range, stop): + result = sess.run(get_next) + for component, result_component in zip(components, result): + self.assertAllEqual(component, result_component) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testRestoreFromTensors(self): + self._testSaveRestoreFromTensorsUtility(0, 0, 1) + + def testRestoreExhuatedIteratorFromTensors(self): + self._testSaveRestoreFromTensorsUtility(0, 1, 1) + + def _build_graph_tensor_slices(self, components): + iterator = dataset_ops.Dataset.from_tensor_slices( + components).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + saveable = iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + for t in nest.flatten(get_next): + ops.add_to_collection("get_next", t) + return init_op, get_next + + def _testSaveRestoreFromTensorSlicesUtility(self, start, break_range, stop): + path = self._iterator_checkpoint_prefix() + step = 0 + meta_filename = path + "-%d.meta" % step + + components = (np.tile(np.array([[1], [2], [3], [4]]), 20), np.tile( + np.array([[12], [13], [14], [15]]), 22), + np.array([37.0, 38.0, 39.0, 40.0])) + + with ops.Graph().as_default() as g: + init_op, get_next = self._build_graph_tensor_slices(components) + saver = saver_lib.Saver() + with self.test_session(graph=g) as sess: + sess.run(init_op) + for i in range(start, break_range): + result = sess.run(get_next) + for component, result_component in zip(components, result): + self.assertAllEqual(component[i], result_component) + saver.save(sess, path, step) + + with ops.Graph().as_default() as g: + saver = saver_lib.import_meta_graph(meta_filename) + with self.test_session(graph=g) as sess: + get_next = nest.pack_sequence_as(("a", "b", "c"), + ops.get_collection("get_next")) + saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) + for i in range(break_range, stop): + result = sess.run(get_next) + for component, result_component in zip(components, result): + self.assertAllEqual(component[i], result_component) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testRestoreFromTensorSlices(self): + self._testSaveRestoreFromTensorSlicesUtility(0, 4, 2) + + def testRestoreExhaustedIteratorFromTensorSlices(self): + self._testSaveRestoreFromTensorSlicesUtility(0, 4, 4) + + def tesRestoreFromTensorSlicesWithDict(self): + + path = self._iterator_checkpoint_prefix() + step = 0 + meta_filename = path + "-%d.meta" % step + + components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]} + + with ops.Graph().as_default() as g: + init_op, get_next = self._build_graph_tensor_slices(components) + saver = saver_lib.Saver() + with self.test_session(graph=g) as sess: + sess.run(init_op) + for i in range(2): + results = sess.run(get_next) + self.assertEqual(components["foo"][i], results["foo"]) + self.assertEqual(components["bar"][i], results["bar"]) + saver.save(sess, path, step) + + with ops.Graph().as_default() as g: + saver = saver_lib.import_meta_graph(meta_filename) + with self.test_session(graph=g) as sess: + get_next = nest.pack_sequence_as(("a", "b"), + ops.get_collection("get_next")) + saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) + for i in range(2, 3): + results = sess.run(get_next) + self.assertEqual(components["foo"][i], results["foo"]) + self.assertEqual(components["bar"][i], results["bar"]) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py new file mode 100644 index 0000000000000000000000000000000000000000..8713640985b1e23da378603af265eec894023e34 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/dataset_serialization_test_base.py @@ -0,0 +1,405 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Base class for testing serializable datasets.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np + +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.util import nest + + +class DatasetSerializationTestBase(test.TestCase): + """Base class for testing finite serializable datasets.""" + + def tearDown(self): + self._delete_ckpt() + + def run_core_tests(self, ds_fn1, ds_fn2, num_outputs): + """Runs the core tests. + + Args: + ds_fn1: 0-argument function that returns a Dataset. + ds_fn2: 0-argument function that returns a Dataset different from + ds_fn1. If None, verify_restore_in_modified_graph test is not run. + num_outputs: Total number of outputs expected from this Dataset. + + Raises: + AssertionError if any test fails. + """ + self.verify_unused_iterator(ds_fn1, num_outputs) + self.verify_fully_used_iterator(ds_fn1, num_outputs) + self.verify_exhausted_iterator(ds_fn1, num_outputs) + self.verify_init_before_restore(ds_fn1, num_outputs) + self.verify_multiple_breaks(ds_fn1, num_outputs) + self.verify_reset_restored_iterator(ds_fn1, num_outputs) + if ds_fn2: + self.verify_restore_in_modified_graph(ds_fn1, ds_fn2, num_outputs) + + def verify_unused_iterator(self, ds_fn, num_outputs): + """Verifies that saving and restoring an unused iterator works. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + + Raises: + AssertionError if any test fails. + """ + self.verify_run_with_breaks(ds_fn, [0], num_outputs) + + def verify_fully_used_iterator(self, ds_fn, num_outputs): + """Verifies that saving and restoring a fully used iterator works. + + Note that this only checks saving and restoring an iterator from which + `num_outputs` items have been produced but does not check for an + exhausted iterator, i.e., one from which an OutOfRange error has been + returned. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + + Raises: + AssertionError if test fails. + """ + self.verify_run_with_breaks(ds_fn, [num_outputs], num_outputs) + + def verify_exhausted_iterator(self, ds_fn, num_outputs): + """Verifies that saving and restoring an exhausted iterator works. + + An exhausted iterator is one which has returned an OutOfRange error. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + + Raises: + AssertionError if any test fails. + """ + self.gen_outputs(ds_fn, [], num_outputs, verify_exhausted=True) + actual = self.gen_outputs( + ds_fn, [], 0, ckpt_saved=True, verify_exhausted=True) + self.assertEqual(len(actual), 0) + + def verify_init_before_restore(self, ds_fn, num_outputs): + """Verifies that retoring into an already initilized iterator works. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + + Raises: + AssertionError if any test fails. + """ + self.verify_run_with_breaks( + ds_fn, + self.gen_break_points(num_outputs), + num_outputs, + init_before_restore=True) + + def verify_multiple_breaks(self, ds_fn, num_outputs, num_breaks=10): + """Attempts to save/restore at multiple break points. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + num_breaks: The number of break points. These are uniformly spread in + [0, num_outputs] both inclusive. + + Raises: + AssertionError if any test fails. + """ + self.verify_run_with_breaks(ds_fn, + self.gen_break_points(num_outputs, num_breaks), + num_outputs) + + def verify_reset_restored_iterator(self, ds_fn, num_outputs, + break_point=None): + """Attempts to re-initialize a restored iterator. + + This is useful when restoring a training checkpoint during validation. + + Args: + ds_fn: See `run_core_tests`. + num_outputs: See `run_core_tests`. + break_point: Break point. Optional. Defaults to num_outputs/2. + + Raises: + AssertionError if any test fails. + """ + break_point = num_outputs // 2 if not break_point else break_point + + # Collect ground truth containing all outputs. + expected = self.gen_outputs(ds_fn, [], num_outputs, verify_exhausted=True) + + # Skip some items and save checkpoint. + self.gen_outputs(ds_fn, [], break_point, verify_exhausted=False) + + actual = [] + # Restore from checkpoint and then run init_op. + with ops.Graph().as_default() as g: + saver = self._import_meta_graph() + init_op, get_next_op = self._get_iterator_ops_from_collection(ds_fn) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + sess.run(init_op) + for _ in range(num_outputs): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self.match(expected, actual) + + def verify_restore_in_modified_graph(self, + ds_fn1, + ds_fn2, + num_outputs, + break_point=None): + """Attempts to restore an iterator in a modified graph. + + Builds an input pipeline using ds_fn1, runs it for `break_point` steps + and saves a checkpoint. Then builds a new graph using ds_fn2, restores + the checkpoint from ds_fn1 and verifies that the restore is successful. + + Args: + ds_fn1: See `run_core_tests`. + ds_fn2: See `run_core_tests`. + num_outputs: See `run_core_tests`. + break_point: Break point. Optional. Defaults to num_outputs/2. + + Raises: + AssertionError if any test fails. + """ + break_point = num_outputs // 2 if not break_point else break_point + + # Skip `break_point` items and store the remaining produced from ds_fn1 + # in `expected`. + self.gen_outputs(ds_fn1, [], break_point) + expected = self.gen_outputs( + ds_fn1, [], + num_outputs - break_point, + ckpt_saved=True, + verify_exhausted=True) + + # Generate `break_point` items from ds_fn1 and save checkpoint. + self.gen_outputs(ds_fn1, [], break_point) + + # Build graph for ds_fn2 but load checkpoint for ds_fn1. + actual = self.gen_outputs( + ds_fn2, [], break_point, ckpt_saved=True, verify_exhausted=True) + + self.match(expected, actual) + + def verify_run_with_breaks(self, + ds_fn, + break_points, + num_outputs, + init_before_restore=False): + """Verifies that ds_fn() produces the same outputs with and without breaks. + + 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it + *without* stopping at break points. + 2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it + with stopping at break points. + + Deep matches outputs from 1 and 2. + + Args: + ds_fn: See `gen_outputs`. + break_points: See `gen_outputs`. + num_outputs: See `gen_outputs`. + init_before_restore: See `gen_outputs`. + + Raises: + AssertionError if any test fails. + """ + expected = self.gen_outputs( + ds_fn, [], + num_outputs, + verify_exhausted=True, + init_before_restore=init_before_restore) + actual = self.gen_outputs( + ds_fn, + break_points, + num_outputs, + verify_exhausted=True, + init_before_restore=init_before_restore) + self.match(expected, actual) + + def gen_outputs(self, + ds_fn, + break_points, + num_outputs, + ckpt_saved=False, + init_before_restore=False, + verify_exhausted=False): + """Generates elements from input dataset while stopping at break points. + + Produces `num_outputs` outputs and saves the state of the iterator in the + Saver checkpoint. + + Args: + ds_fn: 0-argument function that returns the dataset. + break_points: A list of integers. For each `break_point` in + `break_points`, we produce outputs till `break_point` number of items + have been produced and then checkpoint the state. The current graph + and session are destroyed and a new graph and session are used to + produce outputs till next checkpoint or till `num_outputs` elements + have been produced. `break_point` must be <= `num_outputs`. + num_outputs: The total number of outputs to produce from the iterator. + ckpt_saved: Whether a checkpoint already exists. If False, we build the + graph from ds_fn. + init_before_restore: Whether init should be called before saver.restore. + This is just so that we can verify that restoring an already initialized + iterator works. + verify_exhausted: Whether to verify that the iterator has been exhausted + after producing `num_outputs` elements. + + Returns: + A list if `num_outputs` items. + """ + outputs = [] + + def get_ops(): + if ckpt_saved: + saver = self._import_meta_graph() + init_op, get_next_op = self._get_iterator_ops_from_collection(ds_fn) + else: + init_op, get_next_op, saver = self._build_graph(ds_fn) + return init_op, get_next_op, saver + + for i in range(len(break_points) + 1): + with ops.Graph().as_default() as g: + init_op, get_next_op, saver = get_ops() + with self.test_session(graph=g) as sess: + if ckpt_saved: + if init_before_restore: + sess.run(init_op) + self._restore(saver, sess) + else: + sess.run(init_op) + start = break_points[i - 1] if i > 0 else 0 + end = break_points[i] if i < len(break_points) else num_outputs + num_iters = end - start + for _ in range(num_iters): + outputs.append(sess.run(get_next_op)) + self._save(sess, saver) + ckpt_saved = True + if i == len(break_points) and verify_exhausted: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + return outputs + + def match(self, expected, actual): + """Matches nested structures. + + Recursively matches shape and values of `expected` and `actual`. + Handles scalars, numpy arrays and other python sequence containers + e.g. list, dict. + + Args: + expected: Nested structure 1. + actual: Nested structure 2. + + Raises: + AssertionError if matching fails. + """ + if isinstance(expected, np.ndarray): + expected = expected.tolist() + if isinstance(actual, np.ndarray): + actual = actual.tolist() + self.assertEqual(type(expected), type(actual)) + + if nest.is_sequence(expected): + self.assertEqual(len(expected), len(actual)) + if isinstance(expected, dict): + for key1, key2 in sorted(expected, actual): + self.assertEqual(key1, key2) + self.match(expected[key1], actual[key2]) + else: + for item1, item2 in zip(expected, actual): + self.match(item1, item2) + else: + self.assertEqual(expected, actual) + + def does_not_match(self, expected, actual): + with self.assertRaises(AssertionError): + self.match(expected, actual) + + def gen_break_points(self, num_outputs, num_samples=10): + """Generates `num_samples` breaks points in [0, num_outputs].""" + return np.linspace(0, num_outputs, num_samples, dtype=int) + + def _build_graph(self, ds_fn): + iterator = ds_fn().make_initializable_iterator() + + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + get_next = iterator.get_next() + self._add_iterator_ops_to_collection(init_op, get_next) + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + def _add_iterator_ops_to_collection(self, init_op, get_next): + ops.add_to_collection("iterator_ops", init_op) + # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections + # do not support tuples we flatten the tensors and restore the shape in + # `_get_iterator_ops_from_collection`. + for el in nest.flatten(get_next): + ops.add_to_collection("iterator_ops", el) + + def _get_iterator_ops_from_collection(self, ds_fn): + all_ops = ops.get_collection("iterator_ops") + return all_ops[0], nest.pack_sequence_as( + self._get_output_types(ds_fn), all_ops[1:]) + + def _get_output_types(self, ds_fn): + with ops.Graph().as_default(): + return ds_fn().output_types + + def _ckpt_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _latest_ckpt(self): + return saver_lib.latest_checkpoint(self.get_temp_dir()) + + def _save(self, sess, saver): + saver.save(sess, self._ckpt_path()) + + def _restore(self, saver, sess): + saver.restore(sess, self._latest_ckpt()) + + def _import_meta_graph(self): + meta_file_path = self._ckpt_path() + ".meta" + return saver_lib.import_meta_graph(meta_file_path) + + def _delete_ckpt(self): + # Remove all checkpoint files. + prefix = self._ckpt_path() + pattern = prefix + "*" + files = gfile.Glob(pattern) + map(gfile.Remove, files) diff --git a/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py similarity index 84% rename from tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py rename to tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py index 880e01dc069a70ac4ccbbbc18865f631ddea74d8..0aa9ea88de82b0851b0236d9412039d6573ab291 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sloppy_transformation_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py @@ -25,7 +25,7 @@ import time from six.moves import zip_longest from tensorflow.contrib.data.python.ops import dataset_ops -from tensorflow.contrib.data.python.ops import sloppy_ops +from tensorflow.contrib.data.python.ops import interleave_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops @@ -34,12 +34,13 @@ from tensorflow.python.ops import script_ops from tensorflow.python.platform import test -class SloppyInterleaveDatasetTest(test.TestCase): +class ParallelInterleaveDatasetTest(test.TestCase): def setUp(self): self.input_values = array_ops.placeholder(dtypes.int64, shape=[None]) self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[]) self.block_length = array_ops.placeholder(dtypes.int64, shape=[]) + self.sloppy = array_ops.placeholder(dtypes.bool, shape=[]) self.repeat_count = 2 @@ -69,9 +70,9 @@ class SloppyInterleaveDatasetTest(test.TestCase): self.dataset = (dataset_ops.Dataset.from_tensor_slices(self.input_values) .repeat(self.repeat_count).apply( - sloppy_ops.sloppy_interleave( + interleave_ops.parallel_interleave( interleave_fn, self.cycle_length, - self.block_length))) + self.block_length, self.sloppy))) self.iterator = self.dataset.make_initializable_iterator() self.init_op = self.iterator.initializer self.next_element = self.iterator.get_next() @@ -161,7 +162,7 @@ class SloppyInterleaveDatasetTest(test.TestCase): for i in range(4, 7): self.write_coordination_events[i].set() - def testSingleThreaded(self): + def _testSingleThreaded(self, sloppy=False): # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and # `Dataset.flat_map()` and is single-threaded. No synchronization required. with self.test_session() as sess: @@ -171,7 +172,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 1, - self.block_length: 1 + self.block_length: 1, + self.sloppy: sloppy }) for expected_element in self._interleave( @@ -182,7 +184,13 @@ class SloppyInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testTwoThreadsNoContention(self): + def testSingleThreaded(self): + self._testSingleThreaded() + + def testSingleThreadedSloppy(self): + self._testSingleThreaded(sloppy=True) + + def _testTwoThreadsNoContention(self, sloppy=False): # num_threads > 1. # Explicit coordination should result in `Dataset.interleave()` behavior with self.test_session() as sess: @@ -193,7 +201,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 2, - self.block_length: 1 + self.block_length: 1, + self.sloppy: sloppy }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -211,11 +220,20 @@ class SloppyInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testTwoThreadsNoContentionWithRaces(self): + def testTwoThreadsNoContention(self): + self._testTwoThreadsNoContention() + + def testTwoThreadsNoContentionSloppy(self): + self._testTwoThreadsNoContention(sloppy=True) + + def _testTwoThreadsNoContentionWithRaces(self, sloppy=False): """Tests where all the workers race in producing elements. Note: this is in contrast with the prevous test which carefully sequences the execution of the map functions. + + Args: + sloppy: Whether to be sloppy or not. """ with self.test_session() as sess: self._clear_coordination_events() @@ -225,7 +243,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 2, - self.block_length: 1 + self.block_length: 1, + self.sloppy: sloppy, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -247,7 +266,13 @@ class SloppyInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testTwoThreadsNoContentionBlockLength(self): + def testTwoThreadsNoContentionWithRaces(self): + self._testTwoThreadsNoContentionWithRaces() + + def testTwoThreadsNoContentionWithRacesSloppy(self): + self._testTwoThreadsNoContentionWithRaces(sloppy=True) + + def _testTwoThreadsNoContentionBlockLength(self, sloppy=False): # num_threads > 1. # Explicit coordination should result in `Dataset.interleave()` behavior with self.test_session() as sess: @@ -258,7 +283,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 2, - self.block_length: 2 + self.block_length: 2, + self.sloppy: sloppy }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -276,11 +302,21 @@ class SloppyInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testTwoThreadsNoContentionWithRacesAndBlocking(self): + def testTwoThreadsNoContentionBlockLength(self): + self._testTwoThreadsNoContentionBlockLength() + + def testTwoThreadsNoContentionBlockLengthSloppy(self): + self._testTwoThreadsNoContentionBlockLength(sloppy=True) + + def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False): """Tests where all the workers race in producing elements. Note: this is in contrast with the prevous test which carefully sequences the execution of the map functions. + + + Args: + sloppy: Whether to be sloppy or not. """ with self.test_session() as sess: self._clear_coordination_events() @@ -290,7 +326,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 2, - self.block_length: 2 + self.block_length: 2, + self.sloppy: sloppy }) for i, expected_element in enumerate( self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2, @@ -312,7 +349,13 @@ class SloppyInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testEmptyInput(self): + def testTwoThreadsNoContentionWithRacesAndBlocking(self): + self._testTwoThreadsNoContentionWithRacesAndBlocking() + + def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self): + self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True) + + def _testEmptyInput(self, sloppy=False): with self.test_session() as sess: # Empty input. self._clear_coordination_events() @@ -321,12 +364,19 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [], self.cycle_length: 2, - self.block_length: 3 + self.block_length: 3, + self.sloppy: sloppy }) with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testNonEmptyInputIntoEmptyOutputs(self): + def testEmptyInput(self): + self._testEmptyInput() + + def testEmptyInputSloppy(self): + self._testEmptyInput(sloppy=True) + + def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False): # Non-empty input leading to empty output. with self.test_session() as sess: self._clear_coordination_events() @@ -335,12 +385,19 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [0, 0, 0], self.cycle_length: 2, - self.block_length: 3 + self.block_length: 3, + self.sloppy: sloppy }) with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testPartiallyEmptyOutputs(self): + def testNonEmptyInputIntoEmptyOutputs(self): + self._testNonEmptyInputIntoEmptyOutputs() + + def testNonEmptyInputIntoEmptyOutputsSloppy(self): + self._testNonEmptyInputIntoEmptyOutputs(sloppy=True) + + def _testPartiallyEmptyOutputs(self, sloppy=False): # Mixture of non-empty and empty interleaved datasets. with self.test_session() as sess: self._clear_coordination_events() @@ -350,7 +407,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 0, 6], self.cycle_length: 2, - self.block_length: 1 + self.block_length: 1, + self.sloppy: sloppy, }) for i, expected_element in enumerate( self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)): @@ -367,7 +425,13 @@ class SloppyInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testDelayedOutput(self): + def testPartiallyEmptyOutputs(self): + self._testPartiallyEmptyOutputs() + + def testPartiallyEmptyOutputsSloppy(self): + self._testPartiallyEmptyOutputs(sloppy=True) + + def testDelayedOutputSloppy(self): # Explicitly control the sequence of events to ensure we correctly avoid # head-of-line blocking. with self.test_session() as sess: @@ -377,7 +441,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 2, - self.block_length: 1 + self.block_length: 1, + self.sloppy: True, }) mis_ordering = [ @@ -391,7 +456,7 @@ class SloppyInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testBlockLengthWithContention(self): + def testBlockLengthWithContentionSloppy(self): with self.test_session() as sess: self._clear_coordination_events() done_first_event = False @@ -400,7 +465,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 2, - self.block_length: 3 + self.block_length: 3, + self.sloppy: True }) # Test against a generating sequence that differs from the uncontended # case, in order to prove sloppy correctness. @@ -422,7 +488,7 @@ class SloppyInterleaveDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(self.next_element) - def testEarlyExit(self): + def _testEarlyExit(self, sloppy=False): # Exiting without consuming all input should not block with self.test_session() as sess: self._clear_coordination_events() @@ -431,7 +497,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): feed_dict={ self.input_values: [4, 5, 6], self.cycle_length: 3, - self.block_length: 2 + self.block_length: 2, + self.sloppy: sloppy }) for i in range(4, 7): self.write_coordination_events[i].set() @@ -445,7 +512,13 @@ class SloppyInterleaveDatasetTest(test.TestCase): self.read_coordination_events[i].acquire() self.write_coordination_events[i].set() - def testTooManyReaders(self): + def testEarlyExit(self): + self._testEarlyExit() + + def testEarlyExitSloppy(self): + self._testEarlyExit(sloppy=True) + + def _testTooManyReaders(self, sloppy=False): def interleave_fn(x): dataset = dataset_ops.Dataset.from_tensors(x) @@ -455,8 +528,8 @@ class SloppyInterleaveDatasetTest(test.TestCase): dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6]) dataset = dataset.repeat(self.repeat_count) dataset = dataset.apply( - sloppy_ops.sloppy_interleave(interleave_fn, cycle_length=16, - block_length=2)) + interleave_ops.parallel_interleave( + interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy)) iterator = dataset.make_one_shot_iterator() with self.test_session() as sess: @@ -468,6 +541,11 @@ class SloppyInterleaveDatasetTest(test.TestCase): [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2) self.assertItemsEqual(output_values, expected_values) + def testTooManyReaders(self): + self._testTooManyReaders() + + def testTooManyReadersSloppy(self): + self._testTooManyReaders(sloppy=True) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py index faad6e925d78e273d4c308d42598aa12edc792e2..02379d064d4ab857ce9c7d13881a3ae37eea0980 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_cluster_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import function @@ -44,7 +45,7 @@ class IteratorClusterTest(test.TestCase): iterator_3_handle = iterator_3.string_handle() with ops.device("/job:worker/replica:0/task:0/cpu:0"): - remote_it = dataset_ops.Iterator.from_string_handle( + remote_it = iterator_ops.Iterator.from_string_handle( iterator_3_handle, dataset_3.output_types, dataset_3.output_shapes) get_next_op = remote_it.get_next() @@ -52,24 +53,19 @@ class IteratorClusterTest(test.TestCase): with self.assertRaises(errors.InvalidArgumentError): sess.run(get_next_op) - def testRemoteIteratorUsingRemoteCallOp(self): - worker_config = config_pb2.ConfigProto() - worker_config.device_count["CPU"] = 2 - worker, _ = test_util.create_local_cluster( - 1, 1, worker_config=worker_config) - - with ops.device("/job:worker/replica:0/task:0/cpu:1"): + def _testRemoteIteratorHelper(self, device0, device1, target): + with ops.device(device1): dataset_3 = dataset_ops.Dataset.from_tensor_slices([1, 2, 3]) iterator_3 = dataset_3.make_one_shot_iterator() iterator_3_handle = iterator_3.string_handle() @function.Defun(dtypes.string) def _remote_fn(h): - remote_iterator = dataset_ops.Iterator.from_string_handle( + remote_iterator = iterator_ops.Iterator.from_string_handle( h, dataset_3.output_types, dataset_3.output_shapes) return remote_iterator.get_next() - with ops.device("/job:worker/replica:0/task:0/cpu:0"): + with ops.device(device0): target_placeholder = array_ops.placeholder(dtypes.string, shape=[]) remote_op = functional_ops.remote_call( args=[iterator_3_handle], @@ -77,32 +73,35 @@ class IteratorClusterTest(test.TestCase): f=_remote_fn, target=target_placeholder) - with session.Session(worker[0].target) as sess: - elem = sess.run( - remote_op, - feed_dict={target_placeholder: "/job:worker/replica:0/task:0/cpu:1"}) + with session.Session(target) as sess: + elem = sess.run(remote_op, feed_dict={target_placeholder: device1}) self.assertEqual(elem, [1]) # Fails when target is cpu:0 where the resource is not located. with self.assertRaises(errors.InvalidArgumentError): - sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:worker/replica:0/task:0/cpu:0" - }) - elem = sess.run( - remote_op, - feed_dict={target_placeholder: "/job:worker/replica:0/task:0/cpu:1"}) + sess.run(remote_op, feed_dict={target_placeholder: device0}) + elem = sess.run(iterator_3.get_next()) self.assertEqual(elem, [2]) - elem = sess.run( - remote_op, - feed_dict={target_placeholder: "/job:worker/replica:0/task:0/cpu:1"}) + elem = sess.run(remote_op, feed_dict={target_placeholder: device1}) self.assertEqual(elem, [3]) with self.assertRaises(errors.OutOfRangeError): - sess.run( - remote_op, - feed_dict={ - target_placeholder: "/job:worker/replica:0/task:0/cpu:1" - }) + sess.run(remote_op, feed_dict={target_placeholder: device1}) + + def testRemoteIteratorUsingRemoteCallOp(self): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + worker, _ = test_util.create_local_cluster( + 1, 1, worker_config=worker_config) + + self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0", + "/job:worker/replica:0/task:0/cpu:1", + worker[0].target) + + def testRemoteIteratorUsingRemoteCallOpCrossProcess(self): + workers, _ = test_util.create_local_cluster(2, 1) + + self._testRemoteIteratorHelper("/job:worker/replica:0/task:0/cpu:0", + "/job:worker/replica:0/task:1/cpu:0", + workers[0].target) if __name__ == "__main__": diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py index 87e83b8d12bc4986c25185cd142155c789603545..bda9a2a4a37e9c3d35ff99041d1150ffc43f4c43 100644 --- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py @@ -17,11 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import readers from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -30,7 +33,9 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import functional_ops +from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import io_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import script_ops @@ -236,7 +241,7 @@ class IteratorTest(test.TestCase): # functions in this graph, to ensure that we are not # accidentally redefining functions with the same names in the # new graph. - iterator = dataset_ops.Iterator.from_structure( + iterator = iterator_ops.Iterator.from_structure( shared_name="shared_iterator", output_types=(dtypes.int64, dtypes.int64, dtypes.float64), output_shapes=([], [3], [])) @@ -266,8 +271,8 @@ class IteratorTest(test.TestCase): constant_op.constant([1, 2, 3])) dataset_4 = dataset_ops.Dataset.from_tensors( constant_op.constant([4, 5, 6, 7])) - iterator = dataset_ops.Iterator.from_structure(dataset_3.output_types, - [None]) + iterator = iterator_ops.Iterator.from_structure(dataset_3.output_types, + [None]) dataset_3_init_op = iterator.make_initializer(dataset_3) dataset_4_init_op = iterator.make_initializer(dataset_4) @@ -303,12 +308,12 @@ class IteratorTest(test.TestCase): def testReinitializableIteratorStaticErrors(self): # Non-matching structure for types and shapes. with self.assertRaises(TypeError): - iterator = dataset_ops.Iterator.from_structure((dtypes.int64, - dtypes.float64), [None]) + iterator = iterator_ops.Iterator.from_structure((dtypes.int64, + dtypes.float64), [None]) # Test validation of dataset argument. - iterator = dataset_ops.Iterator.from_structure((dtypes.int64, - dtypes.float64)) + iterator = iterator_ops.Iterator.from_structure((dtypes.int64, + dtypes.float64)) # Incompatible structure. with self.assertRaises(ValueError): @@ -325,7 +330,7 @@ class IteratorTest(test.TestCase): [4., 5., 6., 7.], dtype=dtypes.float32)))) # Incompatible shapes. - iterator = dataset_ops.Iterator.from_structure( + iterator = iterator_ops.Iterator.from_structure( (dtypes.int64, dtypes.float64), ([None], [])) with self.assertRaises(TypeError): iterator.make_initializer( @@ -341,7 +346,7 @@ class IteratorTest(test.TestCase): iterator_4 = dataset_4.make_one_shot_iterator() handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - feedable_iterator = dataset_ops.Iterator.from_string_handle( + feedable_iterator = iterator_ops.Iterator.from_string_handle( handle_placeholder, dataset_3.output_types, dataset_3.output_shapes) next_element = feedable_iterator.get_next() @@ -388,11 +393,11 @@ class IteratorTest(test.TestCase): handle_placeholder = array_ops.placeholder(dtypes.string, shape=[]) - feedable_int_scalar = dataset_ops.Iterator.from_string_handle( + feedable_int_scalar = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32, []) - feedable_int_vector = dataset_ops.Iterator.from_string_handle( + feedable_int_vector = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32, [None]) - feedable_int_any = dataset_ops.Iterator.from_string_handle( + feedable_int_any = iterator_ops.Iterator.from_string_handle( handle_placeholder, dtypes.int32) with self.test_session() as sess: @@ -432,7 +437,7 @@ class IteratorTest(test.TestCase): @function.Defun(dtypes.string) def _remote_fn(h): - remote_iterator = dataset_ops.Iterator.from_string_handle( + remote_iterator = iterator_ops.Iterator.from_string_handle( h, dataset_3.output_types, dataset_3.output_shapes) return remote_iterator.get_next() @@ -492,7 +497,7 @@ class IteratorTest(test.TestCase): @function.Defun(dtypes.uint8) def _remote_fn(h): handle = script_ops.py_func(_encode_raw, [h], dtypes.string) - remote_iterator = dataset_ops.Iterator.from_string_handle( + remote_iterator = iterator_ops.Iterator.from_string_handle( handle, dataset_3.output_types, dataset_3.output_shapes) return remote_iterator.get_next() @@ -532,6 +537,89 @@ class IteratorTest(test.TestCase): target_placeholder: "/job:localhost/replica:0/task:0/cpu:0" }) + def testIncorrectIteratorRestore(self): + + def _path(): + return os.path.join(self.get_temp_dir(), "iterator") + + def _save_op(iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + _path(), parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(_path()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + + def _build_range_dataset_graph(): + start = 1 + stop = 10 + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = _save_op(iterator._iterator_resource) + restore_op = _restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + def _build_reader_dataset_graph(): + filenames = ["test"] # Does not exist but we don't care in this test. + iterator = readers.FixedLengthRecordDataset( + filenames, 1, 0, 0).make_initializable_iterator() + init_op = iterator.initializer + get_next_op = iterator.get_next() + save_op = _save_op(iterator._iterator_resource) + restore_op = _restore_op(iterator._iterator_resource) + return init_op, get_next_op, save_op, restore_op + + # Saving iterator for RangeDataset graph. + with ops.Graph().as_default() as g: + init_op, _, save_op, _ = _build_range_dataset_graph() + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(save_op) + + # Attempt to restore the saved iterator into an IteratorResource of + # incompatible type. An iterator of RangeDataset has output type int64, + # while an iterator of FixedLengthRecordDataset has output type string. + # So an InvalidArgumentError should be raised by + # IteratorResource::set_iterator. + with ops.Graph().as_default() as g: + _, _, _, restore_op = _build_reader_dataset_graph() + with self.test_session(graph=g) as sess: + with self.assertRaises(errors.InvalidArgumentError): + sess.run(restore_op) + + def testToSingleElement(self): + skip_value = array_ops.placeholder(dtypes.int64, shape=[]) + take_value = array_ops.placeholder_with_default( + constant_op.constant(1, dtype=dtypes.int64), shape=[]) + + dataset = (dataset_ops.Dataset.range(100) + .skip(skip_value) + .map(lambda x: x * x) + .take(take_value)) + + element = dataset_ops.get_single_element(dataset) + + with self.test_session() as sess: + self.assertEqual(0, sess.run(element, feed_dict={skip_value: 0})) + self.assertEqual(25, sess.run(element, feed_dict={skip_value: 5})) + self.assertEqual(100, sess.run(element, feed_dict={skip_value: 10})) + + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Dataset was empty."): + sess.run(element, feed_dict={skip_value: 100}) + + with self.assertRaisesRegexp(errors.InvalidArgumentError, + "Dataset had more than one element."): + sess.run(element, feed_dict={skip_value: 0, take_value: 2}) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py index 49d3d4c260bfc564c8449c51cb5bf0ebd5838ed8..8a1d99499be702d91f87f65f443261b47ce5c5cd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py @@ -24,6 +24,7 @@ from collections import namedtuple import numpy as np +from tensorflow.contrib.data.python.ops import error_ops from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -181,7 +182,9 @@ class MapDatasetTest(test.TestCase): (1, 1), (1, 2), (2, 2), (2, 4), (8, 8), (8, 16)]: do_test(num_threads_val, output_buffer_size_val) - def _testDisposeParallelMapDataset(self, explicit_dispose): + def testImplicitDisposeParallelMapDataset(self): + # Tests whether a parallel map dataset will be cleaned up correctly when + # the pipeline does not run it until exhaustion. # The pipeline is TensorSliceDataset -> MapDataset(square_3) -> # RepeatDataset(1000). components = (np.arange(1000), @@ -194,21 +197,11 @@ class MapDatasetTest(test.TestCase): iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - if explicit_dispose: - dispose_op = iterator.dispose_op() with self.test_session() as sess: sess.run(init_op) for _ in range(3): sess.run(get_next) - if explicit_dispose: - sess.run(dispose_op) - - def testExplicitDisposeParallelMapDataset(self): - self._testDisposeParallelMapDataset(True) - - def testImplicitDisposeParallelMapDataset(self): - self._testDisposeParallelMapDataset(False) def testParallelMapUnspecifiedOutputSize(self): components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32) @@ -272,7 +265,7 @@ class MapDatasetTest(test.TestCase): dataset = (dataset_ops.Dataset.from_tensor_slices(components) .map(lambda x: array_ops.check_numerics(x, "message")).apply( - dataset_ops.ignore_errors())) + error_ops.ignore_errors())) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -290,7 +283,7 @@ class MapDatasetTest(test.TestCase): dataset = (dataset_ops.Dataset.from_tensor_slices(components).map( lambda x: array_ops.check_numerics(x, "message"), num_threads=2, - output_buffer_size=2).apply(dataset_ops.ignore_errors())) + output_buffer_size=2).apply(error_ops.ignore_errors())) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() @@ -313,7 +306,7 @@ class MapDatasetTest(test.TestCase): dataset = (dataset_ops.Dataset.from_tensor_slices(filenames).map( io_ops.read_file, num_threads=2, output_buffer_size=2).apply( - dataset_ops.ignore_errors())) + error_ops.ignore_errors())) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..dc3e38db59301bf1819999f479171af35930e9d2 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py @@ -0,0 +1,112 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for prefetching_ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import threading + +from tensorflow.contrib.data.python.ops import prefetching_ops +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test + + +class StagingAreaOpsTest(test.TestCase): + + def setUp(self): + self._event = threading.Event() + + def _prefetch_fn_helper(self, buffer_name, device0, device1): + worker_config = config_pb2.ConfigProto() + worker_config.device_count["CPU"] = 2 + + def gen(): + for i in itertools.count(start=1, step=1): + yield [i + 0.0] + if i == 6: + self._event.set() + + with ops.device(device0): + dataset_3 = dataset_ops.Dataset.from_generator(gen, (dtypes.float32)) + iterator_3 = dataset_3.make_one_shot_iterator() + iterator_3_handle = iterator_3.string_handle() + + @function.Defun(dtypes.string) + def _remote_fn(h): + remote_iterator = iterator_ops.Iterator.from_string_handle( + h, dataset_3.output_types, dataset_3.output_shapes) + return remote_iterator.get_next() + + target = constant_op.constant(device0) + with ops.device(device1): + buffer_resource_handle = prefetching_ops.function_buffering_resource( + f=_remote_fn, + target_device=target, + string_arg=iterator_3_handle, + buffer_size=3, + thread_pool_size=2, + shared_name=buffer_name) + + with ops.device(device1): + prefetch_op = prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=buffer_resource_handle, + output_types=[dtypes.float32]) + + with self.test_session(config=worker_config) as sess: + elem = sess.run(prefetch_op) + self.assertEqual(elem, [1.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [2.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [3.0]) + elem = sess.run(prefetch_op) + self.assertEqual(elem, [4.0]) + self._event.wait() + elem = sess.run(prefetch_op) + self.assertEqual(elem, [5.0]) + sess.run( + resource_variable_ops.destroy_resource_op( + buffer_resource_handle, ignore_lookup_error=True)) + + def testSameDeviceCPU(self): + self._prefetch_fn_helper("same_device_cpu", + "/job:localhost/replica:0/task:0/cpu:0", + "/job:localhost/replica:0/task:0/cpu:0") + + def testDifferentDeviceCPU(self): + self._prefetch_fn_helper("diff_device_cpu", + "/job:localhost/replica:0/task:0/cpu:0", + "/job:localhost/replica:0/task:0/cpu:1") + + def testDifferentDeviceCPUGPU(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + self._prefetch_fn_helper("cpu_gpu", "/job:localhost/replica:0/task:0/cpu:0", + "/job:localhost/replica:0/task:0/gpu:0") + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py index faa4d187aca445af31988dde691956ce82afda19..f59ac760dc83a504e563f055b91f1002cb0c80fc 100644 --- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py @@ -18,7 +18,11 @@ from __future__ import division from __future__ import print_function import os + from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import enumerate_ops +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors @@ -26,9 +30,12 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops +from tensorflow.python.ops import parsing_ops from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib class RangeDatasetTest(test.TestCase): @@ -170,7 +177,7 @@ class RangeDatasetTest(test.TestCase): start = constant_op.constant(20, dtype=dtypes.int64) iterator = (dataset_ops.Dataset.from_tensor_slices(components).apply( - dataset_ops.enumerate_dataset(start)).make_initializable_iterator()) + enumerate_ops.enumerate_dataset(start)).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() @@ -190,6 +197,21 @@ class RangeDatasetTest(test.TestCase): def _iterator_checkpoint_prefix(self): return os.path.join(self.get_temp_dir(), "iterator") + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_prefix(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_prefix()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + def testSaveRestore(self): def _build_graph(start, stop): @@ -197,10 +219,8 @@ class RangeDatasetTest(test.TestCase): stop).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op # Saving and restoring in different sessions. @@ -241,6 +261,257 @@ class RangeDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) + def testSaveRestoreUsingSaverFromMetaGraph(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + ops.add_to_collection("iterator_ops", init_op) + ops.add_to_collection("iterator_ops", get_next) + saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(iterator) + # Add the SaveableObject to the `SAVEABLE_OBJECTS` collection + # so that it can be automatically picked up by the Saver. + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) + saver = saver_lib.Saver() + return init_op, get_next, saver + + start = 2 + stop = 10 + break_point = 5 + path = self._iterator_checkpoint_prefix() + meta_filename = path + ".meta" + + # Execute input pipeline for a few steps and save iterator state. + with ops.Graph().as_default() as g: + init_op, get_next, saver = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + saver.save(sess, path) + + # Build the saver from the MetaGraph using import_meta_graph and + # check that the iterator state is restored. + with ops.Graph().as_default() as g: + saver = saver_lib.import_meta_graph(meta_filename) + init_op, get_next = ops.get_collection("iterator_ops") + with self.test_session(graph=g) as sess: + saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSaveRestoreUsingBuiltSaver(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + ops.add_to_collection("iterator_ops", init_op) + ops.add_to_collection("iterator_ops", get_next) + # Add the SaveableObject to the `SAVEABLE_OBJECTS` collection + # so that it can be automatically picked up by the Saver. + saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) + saver = saver_lib.Saver() + return init_op, get_next, saver + + start = 2 + stop = 10 + stop_new = 15 + break_point = 5 + path = self._iterator_checkpoint_prefix() + + # Execute input pipeline for a few steps and save iterator state. + with ops.Graph().as_default() as g: + init_op, get_next, saver = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + saver.save(sess, path) + + # Manually build a modified Graph and Saver instead of importing + # MetaGraph and verify that original iterator state gets restored. + with ops.Graph().as_default() as g: + init_op, get_next, saver = _build_graph(start, stop_new) + with self.test_session(graph=g) as sess: + saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testSaveRestoreUsingSaverThenInit(self): + + def _build_graph(start, stop): + iterator = dataset_ops.Dataset.range(start, + stop).make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + ops.add_to_collection("iterator_ops", init_op) + ops.add_to_collection("iterator_ops", get_next) + # Add the SaveableObject to the `SAVEABLE_OBJECTS` collection + # so that it can be automatically picked up by the Saver. + saveable_obj = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) + saver = saver_lib.Saver() + return init_op, get_next, saver + + start = 2 + stop = 10 + stop_new = 15 + break_point = 5 + path = self._iterator_checkpoint_prefix() + + # Execute input pipeline for a few steps and save iterator state. + with ops.Graph().as_default() as g: + init_op, get_next, saver = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + saver.save(sess, path) + + # Restore iterator state call and then call init_op for the iterator and + # verify that the new iterator hides the restored iterator. + with ops.Graph().as_default() as g: + init_op, get_next, saver = _build_graph(start, stop_new) + with self.test_session(graph=g) as sess: + saver.restore(sess, saver_lib.latest_checkpoint(self.get_temp_dir())) + sess.run(init_op) + for i in range(start, stop_new): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testRestoreWithoutBuildingDatasetGraph(self): + + def _build_graph(start, stop, num_epochs): + dataset = dataset_ops.Dataset.range(start, stop).repeat(num_epochs) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + num_epochs = 5 + break_point = 5 + break_epoch = 3 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop, num_epochs) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for _ in range(break_epoch): + for i in range(start, stop): + self.assertEqual(i, sess.run(get_next)) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + # Create an empty IteratorResource and restore the Iterator into it. + output_types = dtypes.int64 + output_shapes = tensor_shape.scalar() + iterator = iterator_ops.Iterator.from_structure(output_types, + output_shapes) + restore_op = self._restore_op(iterator._iterator_resource) + get_next = iterator.get_next() + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + for _ in range(break_epoch + 1, num_epochs): + for i in range(start, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testRestoreInModifiedGraph(self): + + def _build_graph(start, stop): + dataset = dataset_ops.Dataset.range(start, stop) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + stop_1 = 8 + break_point = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + # Intentionally build a graph with a different value for stop to make sure + # the original dataset graph is actually getting loaded. + init_op, get_next, _, restore_op = _build_graph(start, stop_1) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testInitThenRestore(self): + # Note: Calling init_op before restore_op is redundant. This test just makes + # sure we do not fail if restore is called on an already initialized + # iterator resource. + + def _build_graph(start, stop): + dataset = dataset_ops.Dataset.range(start, stop) + iterator = dataset.make_initializable_iterator() + init_op = iterator.initializer + get_next = iterator.get_next() + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) + return init_op, get_next, save_op, restore_op + + # Saving and restoring in different sessions. + start = 2 + stop = 10 + break_point = 5 + with ops.Graph().as_default() as g: + init_op, get_next, save_op, _ = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(variables.global_variables_initializer()) + sess.run(init_op) + for i in range(start, break_point): + self.assertEqual(i, sess.run(get_next)) + sess.run(save_op) + + with ops.Graph().as_default() as g: + init_op, get_next, _, restore_op = _build_graph(start, stop) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for i in range(break_point, stop): + self.assertEqual(i, sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + def testMultipleSaves(self): def _build_graph(start, stop): @@ -248,10 +519,8 @@ class RangeDatasetTest(test.TestCase): stop).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op start = 2 @@ -271,7 +540,6 @@ class RangeDatasetTest(test.TestCase): with ops.Graph().as_default() as g: init_op, get_next, save_op, restore_op = _build_graph(start, stop) with self.test_session(graph=g) as sess: - sess.run(init_op) sess.run(restore_op) for i in range(break_point1, break_point2): self.assertEqual(i, sess.run(get_next)) @@ -281,7 +549,6 @@ class RangeDatasetTest(test.TestCase): with ops.Graph().as_default() as g: init_op, get_next, save_op, restore_op = _build_graph(start, stop) with self.test_session(graph=g) as sess: - sess.run(init_op) sess.run(restore_op) for i in range(break_point2, stop): self.assertEqual(i, sess.run(get_next)) @@ -295,10 +562,8 @@ class RangeDatasetTest(test.TestCase): start, stop).repeat(num_epochs).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op start = 2 @@ -326,7 +591,6 @@ class RangeDatasetTest(test.TestCase): with ops.Graph().as_default() as g: init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs) with self.test_session(graph=g) as sess: - sess.run(init_op) sess.run(restore_op) for i in range(break_range, stop): self.assertEqual(i, sess.run(get_next)) @@ -343,10 +607,8 @@ class RangeDatasetTest(test.TestCase): start, stop).repeat(num_epochs).make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() - path = self._iterator_checkpoint_prefix() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next, save_op, restore_op start = 2 @@ -372,7 +634,6 @@ class RangeDatasetTest(test.TestCase): with ops.Graph().as_default() as g: init_op, get_next, _, restore_op = _build_graph(start, stop, num_epochs) with self.test_session(graph=g) as sess: - sess.run(init_op) sess.run(restore_op) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next) diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py index d631fbc76e3b24e0b121d12451e25a294e765324..3ae8f71d77fa6ecf08e42bedac702b8f75eec309 100644 --- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py @@ -21,18 +21,23 @@ import gzip import os import zlib -from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.contrib.data.python.ops import readers from tensorflow.core.example import example_pb2 from tensorflow.core.example import feature_pb2 +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape from tensorflow.python.lib.io import python_io from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import io_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib from tensorflow.python.util import compat @@ -54,7 +59,7 @@ class TextLineDatasetTest(test.TestCase): for j in range(num_lines): contents.append(self._lineText(i, j)) # Always include a newline after the record unless it is - # at the end of the file, in which case we include it sometimes. + # at the end of the file, in which case we include it if j + 1 != num_lines or i == 0: contents.append(b"\r\n" if crlf else b"\n") contents = b"".join(contents) @@ -81,11 +86,11 @@ class TextLineDatasetTest(test.TestCase): num_epochs = array_ops.placeholder(dtypes.int64, shape=[]) batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - repeat_dataset = dataset_ops.TextLineDataset( + repeat_dataset = readers.TextLineDataset( filenames, compression_type=compression_type).repeat(num_epochs) batch_dataset = repeat_dataset.batch(batch_size) - iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types) + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) init_op = iterator.make_initializer(repeat_dataset) init_batch_op = iterator.make_initializer(batch_dataset) get_next = iterator.get_next() @@ -150,7 +155,7 @@ class TextLineDatasetTest(test.TestCase): def testTextLineDatasetBuffering(self): test_filenames = self._createFiles(2, 5, crlf=True) - repeat_dataset = dataset_ops.TextLineDataset(test_filenames, buffer_size=10) + repeat_dataset = readers.TextLineDataset(test_filenames, buffer_size=10) iterator = repeat_dataset.make_one_shot_iterator() with self.test_session() as sess: @@ -160,6 +165,277 @@ class TextLineDatasetTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) + def _ckpt_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _latest_ckpt(self): + return saver_lib.latest_checkpoint(self.get_temp_dir()) + + def _save(self, saver, sess): + saver.save(sess, self._ckpt_path()) + + def _restore(self, saver, sess): + saver.restore(sess, self._latest_ckpt()) + + def _import_meta_graph(self): + meta_file_path = self._ckpt_path() + ".meta" + return saver_lib.import_meta_graph(meta_file_path) + + def _build_graph(self, + test_filenames, + compression_type=None, + build_saveable=True): + ds = readers.TextLineDataset( + test_filenames, compression_type=compression_type, buffer_size=10) + iterator = ds.make_initializable_iterator() + if build_saveable: + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + get_next = iterator.get_next() + ops.add_to_collection("iterator_ops", init_op) + ops.add_to_collection("iterator_ops", get_next) + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + def _testReadWithBreaks(self, breaks, num_files=5, lines_per_file=5): + """Tests reading from input pipeline with regular breaks. + + At each break point the iterator state gets saved using Saver and reloaded + in a new Graph and session. + + Args: + breaks: List of counts of records after reading which iterator state is + checkpointed. Must to in non-decreasing order. + num_files: Total number of files. + lines_per_file: Total number of lines per file. + """ + compression_types = [None, "GZIP", "ZLIB"] + for compression_type in compression_types: + test_filenames = self._createFiles( + num_files, + lines_per_file, + crlf=True, + compression_type=compression_type) + + # Collect ground truth. + total_records = num_files * lines_per_file + expected_records = [] + with ops.Graph().as_default() as g: + init_op, get_next, saver = self._build_graph( + test_filenames, compression_type=compression_type) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(total_records): + expected_records.append(sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Simulate run with breaks. + actual_records = [] + next_record_index = 0 + load_from_ckpt = False + breaks.append(total_records) + for break_index in breaks: + with ops.Graph().as_default() as g: + if not load_from_ckpt: + init_op, get_next, saver = self._build_graph( + test_filenames, compression_type=compression_type) + else: + saver = self._import_meta_graph() + init_op, get_next = ops.get_collection("iterator_ops") + + with self.test_session(graph=g) as sess: + if not load_from_ckpt: + sess.run(init_op) + else: + self._restore(saver, sess) + while next_record_index != break_index: + actual_records.append(sess.run(get_next)) + next_record_index += 1 + if break_index == total_records: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + self._save(saver, sess) + load_from_ckpt = True + self.assertEqual(actual_records, expected_records) + + def testSaveAtFileBoundary(self): + self._testReadWithBreaks([10]) + + def testSaveWithinFile(self): + self._testReadWithBreaks([12]) + + def testSaveUnusedIterator(self): + self._testReadWithBreaks([0]) + + def testSaveRestoreIdempotence(self): + # Attempt to save an iterator immediately after it has been + # restored. + self._testReadWithBreaks([0, 0]) + self._testReadWithBreaks([10, 10]) + self._testReadWithBreaks([12, 12]) + + def testMultipleBreaks(self): + self._testReadWithBreaks([0, 4, 20]) + + def testRestoreExhaustedIterator(self): + num_files = 2 + lines_per_file = 5 + test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) + + with ops.Graph().as_default() as g: + init_op, get_next, saver = self._build_graph(test_filenames) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(num_files * lines_per_file): + sess.run(get_next) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + self._save(saver, sess) + + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + saver = self._import_meta_graph() + self._restore(saver, sess) + _, get_next = ops.get_collection("iterator_ops") + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + def testInitThenRestore(self): + num_files = 5 + lines_per_file = 5 + total_records = num_files * lines_per_file + break_record = 8 + test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) + + expected_records = [] + with ops.Graph().as_default() as g: + init_op, get_next, saver = self._build_graph(test_filenames) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_record): + sess.run(get_next) + self._save(saver, sess) + for _ in range(total_records - break_record): + expected_records.append(sess.run(get_next)) + + actual_records = [] + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + saver = self._import_meta_graph() + init_op, get_next = ops.get_collection("iterator_ops") + sess.run(init_op) + self._restore(saver, sess) + for _ in range(total_records - break_record): + actual_records.append(sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + self.assertEqual(actual_records, expected_records) + + def testRestoreInModifiedGraph(self): + num_files = 5 + lines_per_file = 5 + total_records = num_files * lines_per_file + break_record = 8 + test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) + + expected_records = [] + with ops.Graph().as_default() as g: + init_op, get_next, saver = self._build_graph(test_filenames) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_record): + sess.run(get_next) + self._save(saver, sess) + for _ in range(total_records - break_record): + expected_records.append(sess.run(get_next)) + + actual_records = [] + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + init_op, get_next, saver = self._build_graph( + test_filenames, compression_type="GZIP") + self._restore(saver, sess) + for _ in range(total_records - break_record): + actual_records.append(sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + self.assertEqual(actual_records, expected_records) + + def testRestoreInModifiedGraphThenInit(self): + num_files = 5 + lines_per_file = 5 + total_records = num_files * lines_per_file + break_record = 8 + test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) + + expected_records = [] + with ops.Graph().as_default() as g: + init_op, get_next, saver = self._build_graph(test_filenames) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_record): + expected_records.append(sess.run(get_next)) + self._save(saver, sess) + for _ in range(total_records - break_record): + expected_records.append(sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + # Test that calling the init_op overrides the restored iterator. The + # iterator for the old graph was build to read uncompressed files and + # would fail when trying to read the new files. + actual_records = [] + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + test_filenames = self._createFiles( + num_files, lines_per_file, crlf=True, compression_type="GZIP") + init_op, get_next, saver = self._build_graph( + test_filenames, compression_type="GZIP") + self._restore(saver, sess) + sess.run(init_op) + for _ in range(total_records): + actual_records.append(sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + self.assertEqual(actual_records, expected_records) + + def testDoNotRestoreIterator(self): + num_files = 5 + lines_per_file = 5 + total_records = num_files * lines_per_file + break_record = 8 + test_filenames = self._createFiles(num_files, lines_per_file, crlf=True) + + expected_records = [] + with ops.Graph().as_default() as g: + init_op, get_next, saver = self._build_graph(test_filenames) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_record): + expected_records.append(sess.run(get_next)) + self._save(saver, sess) + for _ in range(total_records - break_record): + expected_records.append(sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + + actual_records = [] + with ops.Graph().as_default() as g: + with self.test_session(graph=g) as sess: + init_op, get_next, saver = self._build_graph( + test_filenames, build_saveable=False) + self._restore(saver, sess) + with self.assertRaises(errors.FailedPreconditionError): + sess.run(get_next) + sess.run(init_op) + for _ in range(total_records): + actual_records.append(sess.run(get_next)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next) + self.assertEqual(actual_records, expected_records) + class FixedLengthRecordReaderTest(test.TestCase): @@ -192,12 +468,12 @@ class FixedLengthRecordReaderTest(test.TestCase): num_epochs = array_ops.placeholder(dtypes.int64, shape=[]) batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - repeat_dataset = (dataset_ops.FixedLengthRecordDataset( + repeat_dataset = (readers.FixedLengthRecordDataset( filenames, self._record_bytes, self._header_bytes, self._footer_bytes) .repeat(num_epochs)) batch_dataset = repeat_dataset.batch(batch_size) - iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types) + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) init_op = iterator.make_initializer(repeat_dataset) init_batch_op = iterator.make_initializer(batch_dataset) get_next = iterator.get_next() @@ -256,7 +532,7 @@ class FixedLengthRecordReaderTest(test.TestCase): def testFixedLengthRecordDatasetBuffering(self): test_filenames = self._createFiles() - dataset = dataset_ops.FixedLengthRecordDataset( + dataset = readers.FixedLengthRecordDataset( test_filenames, self._record_bytes, self._header_bytes, @@ -271,20 +547,44 @@ class FixedLengthRecordReaderTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(iterator.get_next()) + def _iterator_checkpoint_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _save_op(self, iterator_resource): + iterator_state_variant = gen_dataset_ops.serialize_iterator( + iterator_resource) + save_op = io_ops.write_file( + self._iterator_checkpoint_path(), + parsing_ops.serialize_tensor(iterator_state_variant)) + return save_op + + def _restore_op(self, iterator_resource): + iterator_state_variant = parsing_ops.parse_tensor( + io_ops.read_file(self._iterator_checkpoint_path()), dtypes.variant) + restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource, + iterator_state_variant) + return restore_op + def _build_iterator_graph(self, num_epochs): filenames = self._createFiles() - path = os.path.join(self.get_temp_dir(), "iterator") - dataset = (dataset_ops.FixedLengthRecordDataset( + dataset = (readers.FixedLengthRecordDataset( filenames, self._record_bytes, self._header_bytes, self._footer_bytes) .repeat(num_epochs)) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next_op = iterator.get_next() - save_op = gen_dataset_ops.save_iterator(iterator._iterator_resource, path) - restore_op = gen_dataset_ops.restore_iterator(iterator._iterator_resource, - path) + save_op = self._save_op(iterator._iterator_resource) + restore_op = self._restore_op(iterator._iterator_resource) return init_op, get_next_op, save_op, restore_op + def _restore_iterator(self): + output_types = dtypes.string + output_shapes = tensor_shape.scalar() + iterator = iterator_ops.Iterator.from_structure(output_types, output_shapes) + get_next = iterator.get_next() + restore_op = self._restore_op(iterator._iterator_resource) + return restore_op, get_next + def testSaveRestore(self): num_epochs = 10 epoch_break = 5 @@ -318,11 +618,164 @@ class FixedLengthRecordReaderTest(test.TestCase): with self.assertRaises(errors.OutOfRangeError): sess.run(get_next_op) + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch < epoch_break or + (epoch == epoch_break and f < file_break) or + (epoch == epoch_break and f == file_break and + r < record_break)): + continue + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testInitThenRestore(self): + # Note: Calling init_op before restore_op is redundant. This test just makes + # sure we do not fail if restore is called on an already initialized + # iterator resource. + num_epochs = 10 + epoch_break = 5 + file_break = self._num_files // 2 + record_break = self._num_records // 2 + with ops.Graph().as_default() as g: init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( num_epochs=num_epochs) with self.test_session(graph=g) as sess: sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch == epoch_break and f == file_break and + r == record_break): + sess.run(save_op) + break + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + else: + continue + break + else: + continue + break + else: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch < epoch_break or + (epoch == epoch_break and f < file_break) or + (epoch == epoch_break and f == file_break and + r < record_break)): + continue + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testRestoreInModifiedGraph(self): + num_epochs = 10 + num_epochs_1 = 20 + epoch_break = 5 + file_break = self._num_files // 2 + record_break = self._num_records // 2 + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch == epoch_break and f == file_break and + r == record_break): + sess.run(save_op) + break + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + else: + continue + break + else: + continue + break + else: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs_1) + with self.test_session(graph=g) as sess: + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch < epoch_break or + (epoch == epoch_break and f < file_break) or + (epoch == epoch_break and f == file_break and + r < record_break)): + continue + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testRestoreWithoutBuildingDatasetGraph(self): + num_epochs = 10 + epoch_break = 5 + file_break = self._num_files // 2 + record_break = self._num_records // 2 + + with ops.Graph().as_default() as g: + init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( + num_epochs=num_epochs) + with self.test_session(graph=g) as sess: + sess.run(init_op) + # Note: There is no checkpoint saved currently so a NotFoundError is + # raised. + with self.assertRaises(errors.NotFoundError): + sess.run(restore_op) + for epoch in range(num_epochs): + for f in range(self._num_files): + for r in range(self._num_records): + if (epoch == epoch_break and f == file_break and + r == record_break): + sess.run(save_op) + break + self.assertEqual(self._record(f, r), sess.run(get_next_op)) + else: + continue + break + else: + continue + break + else: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + restore_op, get_next_op = self._restore_iterator() + with self.test_session(graph=g) as sess: sess.run(restore_op) for epoch in range(num_epochs): for f in range(self._num_files): @@ -353,7 +806,6 @@ class FixedLengthRecordReaderTest(test.TestCase): init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( num_epochs=num_epochs) with self.test_session(graph=g) as sess: - sess.run(init_op) sess.run(restore_op) for _ in range(num_epochs * self._num_files * self._num_records): sess.run(get_next_op) @@ -384,7 +836,6 @@ class FixedLengthRecordReaderTest(test.TestCase): init_op, get_next_op, save_op, restore_op = self._build_iterator_graph( num_epochs=num_epochs) with self.test_session(graph=g) as sess: - sess.run(init_op) sess.run(restore_op) with self.assertRaises(errors.OutOfRangeError): sess.run(get_next_op) @@ -405,11 +856,12 @@ class TFRecordDatasetTest(test.TestCase): self.compression_type = array_ops.placeholder_with_default("", shape=[]) self.batch_size = array_ops.placeholder(dtypes.int64, shape=[]) - repeat_dataset = dataset_ops.TFRecordDataset( - self.filenames, self.compression_type).repeat(self.num_epochs) + repeat_dataset = readers.TFRecordDataset(self.filenames, + self.compression_type).repeat( + self.num_epochs) batch_dataset = repeat_dataset.batch(self.batch_size) - iterator = dataset_ops.Iterator.from_structure(batch_dataset.output_types) + iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types) self.init_op = iterator.make_initializer(repeat_dataset) self.init_batch_op = iterator.make_initializer(batch_dataset) self.get_next = iterator.get_next() @@ -539,8 +991,7 @@ class TFRecordDatasetTest(test.TestCase): def testReadWithBuffer(self): one_mebibyte = 2**20 - d = dataset_ops.TFRecordDataset( - self.test_filenames, buffer_size=one_mebibyte) + d = readers.TFRecordDataset(self.test_filenames, buffer_size=one_mebibyte) iterator = d.make_one_shot_iterator() with self.test_session() as sess: for j in range(self._num_files): @@ -563,7 +1014,7 @@ class ReadBatchFeaturesTest(test.TestCase): self.num_epochs = num_epochs self.batch_size = batch_size - return dataset_ops.read_batch_features( + return readers.read_batch_features( file_pattern=self.filenames, batch_size=self.batch_size, features={ @@ -571,7 +1022,7 @@ class ReadBatchFeaturesTest(test.TestCase): "record": parsing_ops.FixedLenFeature([], dtypes.int64), "keywords": parsing_ops.VarLenFeature(dtypes.string) }, - reader=dataset_ops.TFRecordDataset, + reader=readers.TFRecordDataset, randomize_input=False, num_epochs=self.num_epochs) @@ -715,7 +1166,7 @@ class ReadBatchFeaturesTest(test.TestCase): "file": parsing_ops.FixedLenFeature([], dtypes.int64), "record": parsing_ops.FixedLenFeature([], dtypes.int64), } - dataset = (dataset_ops.TFRecordDataset(self.test_filenames) + dataset = (readers.TFRecordDataset(self.test_filenames) .map(lambda x: parsing_ops.parse_single_example(x, features)) .repeat(10).batch(2)) iterator = dataset.make_initializable_iterator() diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/contrib/data/python/kernel_tests/resample_test.py index 79f9ba332f143191a165dacae8737e4d0829d0b9..0ac8d7359f7234d98167277724780bf31555e6fb 100644 --- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/resample_test.py @@ -20,12 +20,10 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import resampling from tensorflow.python.framework import errors -from tensorflow.python.framework import ops from tensorflow.python.ops import string_ops -from tensorflow.python.ops import variables from tensorflow.python.platform import test -from tensorflow.python.training import device_setter from tensorflow.python.util import compat @@ -41,20 +39,17 @@ class ResampleTest(test.TestCase): classes = np.random.randint(5, size=(20000,)) # Uniformly sampled target_dist = [0.9, 0.05, 0.05, 0.0, 0.0] initial_dist = [0.2] * 5 if initial_known else None - iterator = (dataset_ops.Dataset.from_tensor_slices(classes) - .shuffle(200, seed=21) - .map(lambda c: (c, string_ops.as_string(c))) - .apply(dataset_ops.rejection_resample(target_dist=target_dist, - initial_dist=initial_dist, - class_func=lambda c, _: c, - seed=27)) - .make_initializable_iterator()) + iterator = (dataset_ops.Dataset.from_tensor_slices(classes).shuffle( + 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).apply( + resampling.rejection_resample( + target_dist=target_dist, + initial_dist=initial_dist, + class_func=lambda c, _: c, + seed=27)).make_initializable_iterator()) init_op = iterator.initializer get_next = iterator.get_next() - variable_init_op = variables.local_variables_initializer() with self.test_session() as sess: - sess.run(variable_init_op) sess.run(init_op) returned = [] with self.assertRaises(errors.OutOfRangeError): @@ -75,22 +70,6 @@ class ResampleTest(test.TestCase): returned_dist = class_counts / total_returned self.assertAllClose(target_dist, returned_dist, atol=1e-2) - def testVariableDevicePlacement(self): - classes = np.random.randint(5, size=(20000,)) # Uniformly sampled - target_dist = [0.9, 0.05, 0.05, 0.0, 0.0] - with ops.device( - device_setter.replica_device_setter(ps_tasks=1, ps_device="/cpu:0")): - _ = (dataset_ops.Dataset.from_tensor_slices(classes) - .shuffle(200, seed=21) - .map(lambda c: (c, string_ops.as_string(c))) - .apply(dataset_ops.rejection_resample( - target_dist=target_dist, initial_dist=None, - class_func=lambda c, _: c, seed=27))) - - self.assertEqual(1, len(variables.local_variables())) - self.assertEqual(b"", - compat.as_bytes(variables.local_variables()[0].device)) - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5338ec56bf275e481a984964e39aa0c1ade3a752 --- /dev/null +++ b/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py @@ -0,0 +1,128 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the experimental input pipeline ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools + +import numpy as np + +from tensorflow.contrib.data.python.ops import scan_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class ScanDatasetTest(test.TestCase): + + def _count(self, start, step): + return dataset_ops.Dataset.from_tensors(0).repeat(None).apply( + scan_ops.scan(start, lambda state, _: (state + step, state))) + + def testCount(self): + start = array_ops.placeholder(dtypes.int32, shape=[]) + step = array_ops.placeholder(dtypes.int32, shape=[]) + take = array_ops.placeholder(dtypes.int64, shape=[]) + iterator = self._count(start, step).take(take).make_initializable_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + + for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10), + (10, 2, 10), (10, -1, 10), + (10, -2, 10)]: + sess.run(iterator.initializer, + feed_dict={start: start_val, step: step_val, take: take_val}) + for expected, _ in zip( + itertools.count(start_val, step_val), range(take_val)): + self.assertEqual(expected, sess.run(next_element)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testFibonacci(self): + iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply( + scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])) + ).make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + self.assertEqual(1, sess.run(next_element)) + self.assertEqual(1, sess.run(next_element)) + self.assertEqual(2, sess.run(next_element)) + self.assertEqual(3, sess.run(next_element)) + self.assertEqual(5, sess.run(next_element)) + self.assertEqual(8, sess.run(next_element)) + + def testChangingStateShape(self): + # Test the fixed-point shape invariant calculations: start with + # initial values with known shapes, and use a scan function that + # changes the size of the state on each element. + def _scan_fn(state, input_value): + # Statically known rank, but dynamic length. + ret_longer_vector = array_ops.concat([state[0], state[0]], 0) + # Statically unknown rank. + ret_larger_rank = array_ops.expand_dims(state[1], 0) + return (ret_longer_vector, ret_larger_rank), (state, input_value) + + dataset = dataset_ops.Dataset.from_tensors(0).repeat(5).apply( + scan_ops.scan(([0], 1), _scan_fn)) + self.assertEqual([None], dataset.output_shapes[0][0].as_list()) + self.assertIs(None, dataset.output_shapes[0][1].ndims) + self.assertEqual([], dataset.output_shapes[1].as_list()) + + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + + with self.test_session() as sess: + for i in range(5): + (longer_vector_val, larger_rank_val), _ = sess.run(next_element) + self.assertAllEqual([0] * (2**i), longer_vector_val) + self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val) + with self.assertRaises(errors.OutOfRangeError): + sess.run(next_element) + + def testIncorrectStateType(self): + + def _scan_fn(state, _): + return constant_op.constant(1, dtype=dtypes.int64), state + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + TypeError, + "The element types for the new state must match the initial state."): + dataset.apply( + scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) + + def testIncorrectReturnType(self): + + def _scan_fn(unused_state, unused_input_value): + return constant_op.constant(1, dtype=dtypes.int64) + + dataset = dataset_ops.Dataset.range(10) + with self.assertRaisesRegexp( + TypeError, + "The scan function must return a pair comprising the new state and the " + "output value."): + dataset.apply( + scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py index d9bfca30bbfd66afead842c2bc3020e9d4bcc2d9..6b5b53cc0f8f2d1df5622a5bc5e2f8ef04c6342a 100644 --- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py @@ -18,15 +18,22 @@ from __future__ import division from __future__ import print_function import collections +import os import numpy as np -from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops +from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.platform import gfile from tensorflow.python.platform import test +from tensorflow.python.training import saver as saver_lib class ShuffleDatasetTest(test.TestCase): @@ -41,8 +48,9 @@ class ShuffleDatasetTest(test.TestCase): buffer_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) seed_placeholder = array_ops.placeholder(dtypes.int64, shape=[]) - repeat_dataset = (dataset_ops.Dataset.from_tensor_slices(components) - .repeat(count_placeholder)) + repeat_dataset = ( + contrib_dataset_ops.Dataset.from_tensor_slices(components) + .repeat(count_placeholder)) shuffle_dataset = repeat_dataset.shuffle(buffer_size_placeholder, seed_placeholder) @@ -52,7 +60,7 @@ class ShuffleDatasetTest(test.TestCase): # Create initialization ops for iterators without and with # shuffling, respectively. - iterator = dataset_ops.Iterator.from_structure( + iterator = iterator_ops.Iterator.from_structure( shuffle_dataset.output_types, shuffle_dataset.output_shapes) init_fifo_op = iterator.make_initializer(repeat_dataset) init_shuffle_op = iterator.make_initializer(shuffle_dataset) @@ -133,8 +141,9 @@ class ShuffleDatasetTest(test.TestCase): def testDefaultArguments(self): components = [0, 1, 2, 3, 4] - iterator = (dataset_ops.Dataset.from_tensor_slices(components).shuffle(5) - .repeat().make_one_shot_iterator()) + iterator = ( + contrib_dataset_ops.Dataset.from_tensor_slices(components).shuffle(5) + .repeat().make_one_shot_iterator()) get_next = iterator.get_next() @@ -148,5 +157,322 @@ class ShuffleDatasetTest(test.TestCase): self.assertEqual(10, counts[i]) +class ShuffleDatasetSerializationTest(test.TestCase): + + def tearDown(self): + # Remove all checkpoint files. + prefix = self._ckpt_path() + pattern = prefix + "*" + files = gfile.Glob(pattern) + map(gfile.Remove, files) + + def _build_graph(self, + range_limit=10, + num_repeats=5, + buffer_size=5, + seed=None, + reshuffle_each_iteration=None, + build_saveable=True): + iterator = dataset_ops.Dataset.range(range_limit).shuffle( + buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration).repeat( + num_repeats).make_initializable_iterator() + if build_saveable: + saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator) + ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable) + init_op = iterator.initializer + get_next = iterator.get_next() + ops.add_to_collection("iterator_ops", init_op) + ops.add_to_collection("iterator_ops", get_next) + saver = saver_lib.Saver(allow_empty=True) + return init_op, get_next, saver + + def _ckpt_path(self): + return os.path.join(self.get_temp_dir(), "iterator") + + def _latest_ckpt(self): + return saver_lib.latest_checkpoint(self.get_temp_dir()) + + def _save(self, sess, saver): + saver.save(sess, self._ckpt_path()) + + def _restore(self, saver, sess): + saver.restore(sess, self._latest_ckpt()) + + def _import_meta_graph(self): + meta_file_path = self._ckpt_path() + ".meta" + return saver_lib.import_meta_graph(meta_file_path) + + def _testReadWithBreaks(self, break_points, init_before_restore=False): + seed = 55 + range_limit = 10 + num_repeats = 5 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 8, 10, 25, 50] + reshuffle_each_iteration = False + for buffer_size in buffer_sizes: + expected = [] + actual = [] + # Generate the ground truth. + with ops.Graph().as_default() as g: + g.seed = 10 + init_op, get_next_op, _ = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(num_outputs): + expected.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + # Run and checkpoint after first break_point. + with ops.Graph().as_default() as g: + g.seed = 10 + init_op, get_next_op, saver = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_points[0]): + actual.append(sess.run(get_next_op)) + self._save(sess, saver) + + # Load from checkpoint and continue running while stopping at each + # subsequent checkpoint. + for i in range(len(break_points)): + with ops.Graph().as_default() as g: + saver = self._import_meta_graph() + init_op, get_next_op = ops.get_collection("iterator_ops") + with self.test_session(graph=g) as sess: + if init_before_restore: + sess.run(init_op) + self._restore(saver, sess) + start = break_points[i] + end = break_points[ + i + 1] if i < len(break_points) - 1 else num_outputs + for _ in range(end - start): + actual.append(sess.run(get_next_op)) + self._save(sess, saver) + if end == num_outputs: + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self.assertEqual(expected, actual) + + def testSaveRestore(self): + self._testReadWithBreaks([8]) # rng buffer_size: 0 + self._testReadWithBreaks([13]) # rng buffer_size: 1 + self._testReadWithBreaks([18]) # rng buffer_size: 2 + self._testReadWithBreaks([23]) # rng buffer_size: 3 + + def testSaveUnusedIterator(self): + self._testReadWithBreaks([0]) + + def testSaveFullyUsedIterator(self): + self._testReadWithBreaks([50]) + + def testMultipleBreaks(self): + self._testReadWithBreaks([0, 5, 9, 15, 25, 32]) + + def testIdempotence(self): + # Attempt to save iterator immediately after restoring. + self._testReadWithBreaks([1, 1, 5, 5, 5, 25, 32]) + + def testInitThenRestore(self): + self._testReadWithBreaks([0, 5, 9, 15, 25, 32], init_before_restore=True) + + def testRestoreExhaustedIterator(self): + seed = 55 + range_limit = 10 + num_repeats = 5 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 8, 10, 25, 50] + reshuffle_each_iteration = False + for buffer_size in buffer_sizes: + with ops.Graph().as_default() as g: + g.seed = 10 + init_op, get_next_op, saver = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(num_outputs): + sess.run(get_next_op) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + self._save(sess, saver) + + with ops.Graph().as_default() as g: + saver = self._import_meta_graph() + init_op, get_next_op = ops.get_collection("iterator_ops") + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + def testResetRestoredIterator(self): + seed = 55 + range_limit = 10 + num_repeats = 5 + num_outputs = range_limit * num_repeats + buffer_sizes = [1, 3, 8, 10, 25, 50] + reshuffle_each_iteration = False + for buffer_size in buffer_sizes: + with ops.Graph().as_default() as g: + g.seed = 10 + init_op, get_next_op, saver = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(num_outputs // 2): + sess.run(get_next_op) + self._save(sess, saver) + + outputs = [] + with ops.Graph().as_default() as g: + saver = self._import_meta_graph() + init_op, get_next_op = ops.get_collection("iterator_ops") + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + sess.run(init_op) + for _ in range(num_outputs): + outputs.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + expected_outputs_sorted = sorted( + np.array([range(range_limit) + for _ in range(num_repeats)]).flatten()) + self.assertEqual(expected_outputs_sorted, sorted(outputs)) + + def testRestoreInModifiedGraph(self): + seed = 55 + break_point = 25 + range_limit = 10 + num_repeats = 5 + num_outputs = range_limit * num_repeats + buffer_sizes = [3, 8, 10, 25, 50] + reshuffle_each_iteration = False + for buffer_size in buffer_sizes: + expected = [] + actual_without_restore = [] + actual = [] + with ops.Graph().as_default() as g: + g.seed = 10 + init_op, get_next_op, saver = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_point): + expected.append(sess.run(get_next_op)) + actual.extend(expected) + self._save(sess, saver) + for _ in range(num_outputs - break_point): + expected.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + g.seed = 20 # Different seed than previous graph for shuffle rngs. + init_op, get_next_op, saver = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(num_outputs): + actual_without_restore.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + with ops.Graph().as_default() as g: + g.seed = 20 # Different seed than previous graph for shuffle rngs. + init_op, get_next_op, saver = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration) + with self.test_session(graph=g) as sess: + self._restore(saver, sess) + for _ in range(num_outputs - break_point): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + + # Since the modified graph has a different random seed it produces a + # different order of examples. + self.assertNotEqual(expected, actual_without_restore) + self.assertEqual(sorted(expected), sorted(actual_without_restore)) + self.assertEqual(expected, actual) + + def testDoNotBuildSaveable(self): + seed = 55 + break_point = 25 + range_limit = 10 + num_repeats = 5 + num_outputs = range_limit * num_repeats + buffer_sizes = [3, 8, 10, 25, 50] + reshuffle_each_iteration = False + for buffer_size in buffer_sizes: + actual = [] + with ops.Graph().as_default() as g: + g.seed = 10 + init_op, get_next_op, saver = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration) + with self.test_session(graph=g) as sess: + sess.run(init_op) + for _ in range(break_point): + sess.run(get_next_op) + self._save(sess, saver) + + with ops.Graph().as_default() as g: + g.seed = 20 # Different seed than previous graph for shuffle rngs. + init_op, get_next_op, saver = self._build_graph( + range_limit=range_limit, + num_repeats=num_repeats, + buffer_size=buffer_size, + seed=seed, + reshuffle_each_iteration=reshuffle_each_iteration, + build_saveable=False) + with self.test_session(graph=g) as sess: + # Since the SaveableObject was not added to Saver's list + # of saveables, iterator state is not restored by saver.restore(). + self._restore(saver, sess) + with self.assertRaises(errors.FailedPreconditionError): + sess.run(get_next_op) + sess.run(init_op) + for _ in range(num_outputs): + actual.append(sess.run(get_next_op)) + with self.assertRaises(errors.OutOfRangeError): + sess.run(get_next_op) + expected_outputs_sorted = sorted( + np.array([range(range_limit) for _ in range(num_repeats)]).flatten()) + self.assertEqual(expected_outputs_sorted, sorted(actual)) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py index b3de7795776a1e66d3b947874d3ee57ce95d59e1..efd864f866611bfd3bac1edcf98d84be852410fd 100644 --- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py +++ b/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py @@ -21,7 +21,7 @@ from __future__ import print_function import os import sqlite3 -from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.contrib.data.python.ops import readers from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.ops import array_ops @@ -31,9 +31,8 @@ from tensorflow.python.platform import test class SqlDatasetTest(test.TestCase): def _createSqlDataset(self, output_types, num_repeats=1): - dataset = dataset_ops.SqlDataset(self.driver_name, self.data_source_name, - self.query, - output_types).repeat(num_repeats) + dataset = readers.SqlDataset(self.driver_name, self.data_source_name, + self.query, output_types).repeat(num_repeats) iterator = dataset.make_initializable_iterator() init_op = iterator.initializer get_next = iterator.get_next() diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD index f429cc49de0ad11080098546ccee6033dc25595a..1b81cf5be9190ffab646192fb9a72fd3da7deee1 100644 --- a/tensorflow/contrib/data/python/ops/BUILD +++ b/tensorflow/contrib/data/python/ops/BUILD @@ -4,44 +4,127 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) +load( + "//tensorflow:tensorflow.bzl", + "tf_gen_op_wrapper_py", + "tf_kernel_library", +) +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") + py_library( name = "dataset_ops", - srcs = ["dataset_ops.py"], + srcs = [ + "dataset_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":transformation_ops", + "//tensorflow/python:util", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/util:nest", + ], +) + +py_library( + name = "iterator_ops", + srcs = [ + "iterator_ops.py", + ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dataset_ops_gen", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + ], +) + +py_library( + name = "readers", + srcs = [ + "readers.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":dataset_ops", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:function", - "//tensorflow/python:logging_ops", - "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", - "//tensorflow/python:random_ops", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:script_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:tensor_shape", - "//tensorflow/python:tensor_util", + "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:readers", "//tensorflow/python/data/util:nest", - "//third_party/py/numpy", ], ) py_library( - name = "sloppy_ops", - srcs = ["sloppy_ops.py"], + name = "transformation_ops", + srcs = [ + "batching.py", + "enumerate_ops.py", + "error_ops.py", + "grouping.py", + "interleave_ops.py", + "resampling.py", + "scan_ops.py", + ], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:function", + "//tensorflow/python:logging_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python:tensor_util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/util:nest", + "//third_party/py/numpy", + ], +) + +tf_gen_op_wrapper_py( + name = "prefetching_ops", + out = "gen_prefetching_ops.py", + deps = ["//tensorflow/contrib/data:prefetching_ops_op_lib"], +) + +tf_kernel_library( + name = "prefetching_ops_kernels", + deps = [ + "//tensorflow/contrib/data/kernels:prefetching_kernels", + "//tensorflow/core:framework", + ], + alwayslink = 1, +) + +tf_custom_op_py_library( + name = "prefetching_py", + srcs = ["prefetching_ops.py"], + dso = ["//tensorflow/contrib/data:_prefetching_ops.so"], + kernels = [ + ":prefetching_ops_kernels", + "//tensorflow/contrib/data:prefetching_ops_op_lib", + ], + srcs_version = "PY2AND3", + deps = [ + ":prefetching_ops", + "//tensorflow/contrib/util:util_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", ], ) diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py new file mode 100644 index 0000000000000000000000000000000000000000..abc9212a87550745490b974d25a929a66287f785 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/batching.py @@ -0,0 +1,350 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Batching dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import math_ops + + +def dense_to_sparse_batch(batch_size, row_shape): + """A transformation that batches ragged elements into `tf.SparseTensor`s. + + Like `Dataset.padded_batch()`, this transformation combines multiple + consecutive elements of the dataset, which might have different + shapes, into a single element. The resulting element has three + components (`indices`, `values`, and `dense_shape`), which + comprise a `tf.SparseTensor` that represents the same data. The + `row_shape` represents the dense shape of each row in the + resulting `tf.SparseTensor`, to which the effective batch size is + prepended. For example: + + ```python + # NOTE: The following examples use `{ ... }` to represent the + # contents of a dataset. + a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] } + + a.apply(tf.contrib.data.dense_to_sparse_batch(batch_size=2, row_shape=[6])) == + { + ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], # indices + ['a', 'b', 'c', 'a', 'b'], # values + [2, 6]), # dense_shape + ([[0, 0], [0, 1], [0, 2], [0, 3]], + ['a', 'b', 'c', 'd'], + [1, 6]) + } + ``` + + Args: + batch_size: A `tf.int64` scalar `tf.Tensor`, representing the + number of consecutive elements of this dataset to combine in a + single batch. + row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like + object representing the equivalent dense shape of a row in the + resulting `tf.SparseTensor`. Each element of this dataset must + have the same rank as `row_shape`, and must have size less + than or equal to `row_shape` in each dimension. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return DenseToSparseBatchDataset(dataset, batch_size, row_shape) + + return _apply_fn + + +def unbatch(): + """A Transformation which splits the elements of a dataset. + + For example, if elements of the dataset are shaped `[B, a0, a1, ...]`, + where `B` may vary from element to element, then for each element in + the dataset, the unbatched dataset will contain `B` consecutive elements + of shape `[a0, a1, ...]`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + + def unbatch_map(arg, *rest): + if rest: + return dataset_ops.Dataset.from_tensor_slices((arg,) + rest) + else: + return dataset_ops.Dataset.from_tensor_slices(arg) + + return dataset.flat_map(map_func=unbatch_map) + + return _apply_fn + + +def batch_and_drop_remainder(batch_size): + """A batching transformation that omits the final small batch (if present). + + Like @{tf.data.Dataset.batch}, this transformation combines + consecutive elements of this dataset into batches. However, if the batch + size does not evenly divide the input dataset size, this transformation will + drop the final smaller element. + + The following example illustrates the difference between this + transformation and `Dataset.batch()`: + + ```python + dataset = tf.data.Dataset.range(200) + batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(128)) + print(batched.output_shapes) # ==> "(128,)" (the batch dimension is known) + ``` + + By contrast, `dataset.batch(128)` would yield a two-element dataset with + shapes `(128,)` and `(72,)`, so the batch dimension would not be statically + known. + + Args: + batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements of this dataset to combine in a single batch. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply} + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + tensor_batch_size = ops.convert_to_tensor( + batch_size, dtype=dtypes.int64, name="batch_size") + + batched = dataset.batch(tensor_batch_size) + flattened = _RestructuredDataset(batched, + tuple(nest.flatten(batched.output_types))) + + def _predicate(*xs): + """Return `True` if this element is a full batch.""" + # Extract the dynamic batch size from the first component of the flattened + # batched element. + first_component = xs[0] + first_component_batch_size = array_ops.shape( + first_component, out_type=dtypes.int64)[0] + + return math_ops.equal(first_component_batch_size, tensor_batch_size) + + filtered = flattened.filter(_predicate) + + maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size) + + def _set_first_dimension(shape): + return shape.merge_with( + tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:])) + + known_shapes = nest.map_structure(_set_first_dimension, + batched.output_shapes) + return _RestructuredDataset(filtered, batched.output_types, known_shapes) + + return _apply_fn + + +class DenseToSparseBatchDataset(dataset_ops.Dataset): + """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s.""" + + def __init__(self, input_dataset, batch_size, row_shape): + """See `Dataset.dense_to_sparse_batch()` for more details.""" + super(DenseToSparseBatchDataset, self).__init__() + if not isinstance(input_dataset.output_types, dtypes.DType): + raise TypeError("DenseToSparseDataset requires an input whose elements " + "have a single component, whereas the input has %r." % + input_dataset.output_types) + self._input_dataset = input_dataset + self._batch_size = batch_size + # pylint: disable=protected-access + self._row_shape = dataset_ops._partial_shape_to_tensor(row_shape) + # pylint: enable=protected-access + + def _as_variant_tensor(self): + return gen_dataset_ops.dense_to_sparse_batch_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._batch_size, + self._row_shape, + output_shapes=self.output_shapes, + output_types=self.output_types) + + @property + def output_shapes(self): + num_elements = tensor_shape.Dimension(None) + return (tensor_shape.matrix(num_elements, self._row_shape.shape[0] + 1), + tensor_shape.vector(num_elements), + tensor_shape.vector(self._row_shape.shape[0] + 1)) + + @property + def output_types(self): + return (dtypes.int64, self._input_dataset.output_types, dtypes.int64) + + +class _RestructuredDataset(dataset_ops.Dataset): + """An internal helper for changing the structure and shape of a dataset.""" + + def __init__(self, dataset, output_types, output_shapes=None): + """Creates a new dataset with the given output types and shapes. + + The given `dataset` must have a structure that is convertible: + * `dataset.output_types` must be the same as `output_types` module nesting. + * Each shape in `dataset.output_shapes` must be compatible with each shape + in `output_shapes` (if given). + + Note: This helper permits "unsafe casts" for shapes, equivalent to using + `tf.Tensor.set_shape()` where domain-specific knowledge is available. + + Args: + dataset: A `Dataset` object. + output_types: A nested structure of `tf.DType` objects. + output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects. + If omitted, the shapes will be inherited from `dataset`. + + Raises: + ValueError: If either `output_types` or `output_shapes` is not compatible + with the structure of `dataset`. + """ + super(_RestructuredDataset, self).__init__() + self._dataset = dataset + + # Validate that the types are compatible. + output_types = nest.map_structure(dtypes.as_dtype, output_types) + flat_original_types = nest.flatten(dataset.output_types) + flat_new_types = nest.flatten(output_types) + if flat_original_types != flat_new_types: + raise ValueError( + "Dataset with output types %r cannot be restructured to have output " + "types %r" % (dataset.output_types, output_types)) + + self._output_types = output_types + + if output_shapes is None: + # Inherit shapes from the original `dataset`. + self._output_shapes = nest.pack_sequence_as(output_types, + nest.flatten( + dataset.output_shapes)) + else: + # Validate that the shapes are compatible. + nest.assert_same_structure(output_types, output_shapes) + flat_original_shapes = nest.flatten(dataset.output_shapes) + flat_new_shapes = nest.flatten_up_to(output_types, output_shapes) + + for original_shape, new_shape in zip(flat_original_shapes, + flat_new_shapes): + if not original_shape.is_compatible_with(new_shape): + raise ValueError( + "Dataset with output shapes %r cannot be restructured to have " + "incompatible output shapes %r" % (dataset.output_shapes, + output_shapes)) + self._output_shapes = nest.map_structure_up_to( + output_types, tensor_shape.as_shape, output_shapes) + + def _as_variant_tensor(self): + return self._dataset._as_variant_tensor() # pylint: disable=protected-access + + @property + def output_types(self): + return self._output_types + + @property + def output_shapes(self): + return self._output_shapes + + +class _MapAndBatchDataset(dataset_ops.MapDataset): + """A `Dataset` that maps a function over a batch of elements.""" + + def __init__(self, input_dataset, map_func, batch_size, num_parallel_batches): + """See `Dataset.map()` for details.""" + super(_MapAndBatchDataset, self).__init__(input_dataset, map_func) + + self._batch_size = ops.convert_to_tensor( + batch_size, dtype=dtypes.int64, name="batch_size") + self._num_parallel_batches = ops.convert_to_tensor( + num_parallel_batches, dtype=dtypes.int64, name="num_parallel_batches") + + def _as_variant_tensor(self): + # pylint: disable=protected-access + input_resource = self._input_dataset._as_variant_tensor() + return gen_dataset_ops.map_and_batch_dataset( + input_resource, + self._map_func.captured_inputs, + f=self._map_func, + batch_size=self._batch_size, + num_parallel_batches=self._num_parallel_batches, + output_types=nest.flatten(self.output_types), + output_shapes=nest.flatten(self.output_shapes)) + # pylint: enable=protected-access + + @property + def output_shapes(self): + return nest.pack_sequence_as(self._output_shapes, [ + tensor_shape.vector(tensor_util.constant_value( + self._batch_size)).concatenate(s) + for s in nest.flatten(self._output_shapes) + ]) + + @property + def output_types(self): + return self._output_types + + +def map_and_batch(map_func, batch_size, num_parallel_batches=1): + """Fused implementation of `map` and `batch`. + + Maps `map_func` across `batch_size` consecutive elements of this dataset + and then combines them into a batch. Similarly to `batch_and_drop_remainder`, + if the batch size does not evenly divide the input dataset size, this + transformation will drop the final smaller element. + + + Functionally, it is equivalent to `map` followed by + `batch_and_drop_remainder`. However, by fusing the two transformations + together, the implementation can be more efficient. This transformation is a + stop gap solution for performance critical workloads. Once automatic input + pipeline optimization are implemented, the fusing of map and batch will not + need to be exposed at the API level and this method will be removed. + + Args: + map_func: A function mapping a nested structure of tensors to another + nested structure of tensors. + batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements of this dataset to combine in a single batch. + num_parallel_batches: A `tf.int64` scalar `tf.Tensor`, representing the + number of batches to create in parallel. On one hand, higher values can + help mitigate the effect of stragglers. On the other hand, higher values + can increasing contention if CPU is scarce. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.contrib.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return _MapAndBatchDataset(dataset, map_func, batch_size, + num_parallel_batches) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/dataset_ops.py b/tensorflow/contrib/data/python/ops/dataset_ops.py index 062bdf4e7164f5573b45e31bceec00a42ff830b7..45d6dbe7438957029b4d6b71e181cb1fc3596ecb 100644 --- a/tensorflow/contrib/data/python/ops/dataset_ops.py +++ b/tensorflow/contrib/data/python/ops/dataset_ops.py @@ -17,30 +17,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import enumerate_ops +from tensorflow.contrib.data.python.ops import error_ops +from tensorflow.contrib.data.python.ops import grouping from tensorflow.python.data.ops import dataset_ops -# pylint: disable=unused-import -from tensorflow.python.data.ops.dataset_ops import Iterator -# pylint: enable=unused-import from tensorflow.python.data.util import nest -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import function -from tensorflow.python.framework import ops -from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib -from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import gen_io_ops -from tensorflow.python.ops import logging_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import parsing_ops -from tensorflow.python.ops import random_ops -from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.ops import script_ops -from tensorflow.python.platform import gfile +from tensorflow.python.util import deprecation class Dataset(dataset_ops.Dataset): @@ -55,8 +41,12 @@ class Dataset(dataset_ops.Dataset): super(Dataset, self).__init__() self._dataset = dataset + @deprecation.deprecated(None, "Use `ds._as_variant_tensor()`.") def make_dataset_resource(self): - return self._dataset.make_dataset_resource() + return self._as_variant_tensor() + + def _as_variant_tensor(self): + return self._dataset._as_variant_tensor() # pylint: disable=protected-access @property def output_shapes(self): @@ -67,6 +57,7 @@ class Dataset(dataset_ops.Dataset): return self._dataset.output_types @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensors()`.") def from_tensors(tensors): """Creates a `Dataset` with a single element, comprising the given tensors. @@ -79,6 +70,7 @@ class Dataset(dataset_ops.Dataset): return Dataset(dataset_ops.TensorDataset(tensors)) @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.") def from_tensor_slices(tensors): """Creates a `Dataset` whose elements are slices of the given tensors. @@ -92,6 +84,8 @@ class Dataset(dataset_ops.Dataset): return Dataset(dataset_ops.TensorSliceDataset(tensors)) @staticmethod + @deprecation.deprecated(None, + "Use `tf.data.Dataset.from_sparse_tensor_slices()`.") def from_sparse_tensor_slices(sparse_tensor): """Splits each rank-N `tf.SparseTensor` in this dataset row-wise. @@ -104,6 +98,7 @@ class Dataset(dataset_ops.Dataset): return Dataset(dataset_ops.SparseTensorSliceDataset(sparse_tensor)) @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.from_generator()`.") def from_generator(generator, output_types, output_shapes=None): """Creates a `Dataset` whose elements are generated by `generator`. @@ -141,125 +136,11 @@ class Dataset(dataset_ops.Dataset): Returns: A `Dataset`. """ - if not callable(generator): - raise TypeError("`generator` must be callable.") - if output_shapes is None: - output_shapes = nest.map_structure( - lambda _: tensor_shape.TensorShape(None), output_types) - else: - output_shapes = nest.map_structure_up_to( - output_types, tensor_shape.as_shape, output_shapes) - - flattened_types = nest.flatten(output_types) - flattened_shapes = nest.flatten(output_shapes) - - generator_state = dataset_ops.Dataset._GeneratorState(generator) - - def get_iterator_id_map_fn(unused_dummy): - """Creates a unique `iterator_id` for each pass over the dataset. - - The "iterator_id" disambiguates between multiple concurrently - existing iterators. - - Args: - unused_dummy: Ignored value. - - Returns: - A `tf.int64` tensor whose value uniquely identifies an iterator in - `generator_state`. - """ - return script_ops.py_func( - generator_state.get_next_id, [], dtypes.int64, stateful=True) - - def generator_map_fn(iterator_id_t): - """Generates the next element from iterator with ID `iterator_id_t`. - - We map this function across an infinite repetition of the - `iterator_id_t`, and raise `StopIteration` to terminate the iteration. - - Args: - iterator_id_t: A `tf.int64` tensor whose value uniquely identifies - the iterator in `generator_state` from which to generate an element. - - Returns: - A nested structure of tensors representing an element from the iterator. - """ - - def generator_py_func(iterator_id): - """A `py_func` that will be called to invoke the iterator.""" - try: - values = next(generator_state.get_iterator(iterator_id)) - except StopIteration: - generator_state.iterator_completed(iterator_id) - raise StopIteration("Iteration finished.") - - # Use the same _convert function from the py_func() implementation to - # convert the returned values to arrays early, so that we can inspect - # their values. - # pylint: disable=protected-access - ret_arrays = [ - script_ops.FuncRegistry._convert(ret) - for ret in nest.flatten_up_to(output_types, values) - ] - # pylint: enable=protected-access - - # Additional type and shape checking to ensure that the components - # of the generated element match the `output_types` and `output_shapes` - # arguments. - for (ret_array, expected_dtype, expected_shape) in zip( - ret_arrays, flattened_types, flattened_shapes): - if ret_array.dtype != expected_dtype.as_numpy_dtype: - raise TypeError( - "`generator` yielded an element of type %s where an element " - "of type %s was expected." % (ret_array.dtype, - expected_dtype.as_numpy_dtype)) - if not expected_shape.is_compatible_with(ret_array.shape): - raise ValueError( - "`generator` yielded an element of shape %s where an element " - "of shape %s was expected." % (ret_array.shape, expected_shape)) - - return ret_arrays - - flat_values = script_ops.py_func( - generator_py_func, [iterator_id_t], flattened_types, stateful=True) - - # The `py_func()` op drops the inferred shapes, so we add them back in - # here. - if output_shapes is not None: - for ret_t, shape in zip(flat_values, flattened_shapes): - ret_t.set_shape(shape) - - return nest.pack_sequence_as(output_types, flat_values) - - # This function associates each traversal of `generator` with a unique - # iterator ID. - def flat_map_fn(iterator_id_t): - # First, generate an infinite dataset containing the iterator ID repeated - # forever. - repeated_id = Dataset.from_tensors(iterator_id_t).repeat(None) - - # The `generator_map_fn` gets the next element from the iterator with the - # relevant ID, and raises StopIteration when that iterator contains no - # more elements. - return repeated_id.map(generator_map_fn) - - # A single-element dataset that, each time it is evaluated, contains a - # freshly-generated and unique (for the returned dataset) int64 - # ID that will be used to identify the appropriate Python state, which - # is encapsulated in `generator_state`, and captured in - # `get_iterator_id_map_fn`. - dummy = 0 - id_dataset = Dataset.from_tensors(dummy).map(get_iterator_id_map_fn) - - # A dataset that contains all of the elements generated by a - # single iterator created from `generator`, identified by the - # iterator ID contained in `id_dataset`. Lifting the iteration - # into a flat_map here enables multiple repetitions and/or nested - # versions of the returned dataset to be created, because it forces - # the generation of a new ID for each version. - return id_dataset.flat_map(flat_map_fn) + return Dataset(dataset_ops.Dataset.from_generator( + generator, output_types, output_shapes)) @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.range()`.") def range(*args): """Creates a `Dataset` of a step-separated range of values. @@ -289,6 +170,7 @@ class Dataset(dataset_ops.Dataset): return Dataset(dataset_ops.RangeDataset(*args)) @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.zip()`.") def zip(datasets): """Creates a `Dataset` by zipping together the given datasets. @@ -368,6 +250,7 @@ class Dataset(dataset_ops.Dataset): return Dataset(dataset_ops.PrefetchDataset(self._dataset, buffer_size)) @staticmethod + @deprecation.deprecated(None, "Use `tf.data.Dataset.list_files()`.") def list_files(file_pattern): """A dataset of all files matching a pattern. @@ -404,10 +287,12 @@ class Dataset(dataset_ops.Dataset): """ return Dataset(dataset_ops.RepeatDataset(self._dataset, count)) + @deprecation.deprecated( + None, "Use `ds.apply(tf.contrib.data.enumerate_dataset())`.") def enumerate(self, start=0): """Deprecated: Use `Dataset.apply(tf.contrib.data.enumerate_dataset(..)`.""" - return self.apply(enumerate_dataset(start)) + return self.apply(enumerate_ops.enumerate_dataset(start)) def shuffle(self, buffer_size, seed=None): """Randomly shuffles the elements of this dataset. @@ -521,10 +406,12 @@ class Dataset(dataset_ops.Dataset): """ return Dataset(self._dataset.shard(num_shards, index)) + @deprecation.deprecated( + None, "Use `ds.apply(tf.contrib.data.ignore_errors())`.") def ignore_errors(self): - """Deprecated: Use `Dataset.apply(tf.contrib.data.ignore_errors()`.""" + """Deprecated: Use `Dataset.apply(tf.contrib.data.ignore_errors())`.""" - return self.apply(ignore_errors()) + return self.apply(error_ops.ignore_errors()) def batch(self, batch_size): """Combines consecutive elements of this dataset into batches. @@ -569,16 +456,26 @@ class Dataset(dataset_ops.Dataset): dataset_ops.PaddedBatchDataset(self._dataset, batch_size, padded_shapes, padding_values)) + @deprecation.deprecated( + None, "Use `ds.apply(tf.contrib.data.dense_to_sparse_batch())`.") def dense_to_sparse_batch(self, batch_size, row_shape): """Use: `Dataset.apply(tf.contrib.data.dense_to_sparse_batch(...))`.""" - return self.apply(dense_to_sparse_batch(batch_size, row_shape)) + return self.apply(batching.dense_to_sparse_batch(batch_size, row_shape)) + @deprecation.deprecated( + None, "Use `ds.apply(tf.contrib.data.group_by_window())`.") def group_by_window(self, key_func, reduce_func, window_size): """Deprecated: Use `Dataset.apply(tf.contrib.data.group_by_window(...))`.""" - return self.apply(group_by_window(key_func, reduce_func, window_size)) + return self.apply( + grouping.group_by_window(key_func, reduce_func, window_size)) + @deprecation.deprecated_args( + None, + "Replace `num_threads=T` with `num_parallel_calls=T`. Replace " + "`output_buffer_size=N` with `ds.prefetch(N)` on the returned dataset.", + "num_threads", "output_buffer_size") def map(self, map_func, num_threads=None, @@ -700,10 +597,11 @@ class Dataset(dataset_ops.Dataset): dataset_ops.InterleaveDataset(self._dataset, map_func, cycle_length, block_length)) + @deprecation.deprecated(None, "Use `ds.apply(tf.contrib.data.unbatch())`.") def unbatch(self): """Deprecated: Use `Dataset.apply(tf.contrib.data.unbatch()`.""" - return self.apply(unbatch()) + return self.apply(batching.unbatch()) def filter(self, predicate): """Filters this dataset according to `predicate`. @@ -746,935 +644,46 @@ class Dataset(dataset_ops.Dataset): return Dataset(dataset) -class TextLineDataset(Dataset): - """A `Dataset` comprising lines from one or more text files.""" - - def __init__(self, filenames, compression_type=None, buffer_size=None): - """Creates a `TextLineDataset`. - - Args: - filenames: A `tf.string` tensor containing one or more filenames. - compression_type: (Optional.) A `tf.string` scalar evaluating to one of - `""` (no compression), `"ZLIB"`, or `"GZIP"`. - buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes - to buffer. A value of 0 results in the default buffering values chosen - based on the compression type. - """ - dataset = dataset_ops.TextLineDataset(filenames, compression_type, - buffer_size) - super(TextLineDataset, self).__init__(dataset) - - -class TFRecordDataset(Dataset): - """A `Dataset` comprising records from one or more TFRecord files.""" - - def __init__(self, filenames, compression_type=None, buffer_size=None): - """Creates a `TFRecordDataset`. - - Args: - filenames: A `tf.string` tensor containing one or more filenames. - compression_type: (Optional.) A `tf.string` scalar evaluating to one of - `""` (no compression), `"ZLIB"`, or `"GZIP"`. - buffer_size: (Optional.) A `tf.int64` scalar representing the number of - bytes in the read buffer. 0 means no buffering. - """ - dataset = dataset_ops.TFRecordDataset(filenames, compression_type, - buffer_size) - super(TFRecordDataset, self).__init__(dataset) - - -class FixedLengthRecordDataset(Dataset): - """A `Dataset` of fixed-length records from one or more binary files.""" - - def __init__(self, - filenames, - record_bytes, - header_bytes=None, - footer_bytes=None, - buffer_size=None): - """Creates a `FixedLengthRecordDataset`. +def get_single_element(dataset): + """Returns the single element in `dataset` as a nested structure of tensors. - Args: - filenames: A `tf.string` tensor containing one or more filenames. - record_bytes: A `tf.int64` scalar representing the number of bytes in - each record. - header_bytes: (Optional.) A `tf.int64` scalar representing the number of - bytes to skip at the start of a file. - footer_bytes: (Optional.) A `tf.int64` scalar representing the number of - bytes to ignore at the end of a file. - buffer_size: (Optional.) A `tf.int64` scalar representing the number of - bytes to buffer when reading. - """ - dataset = dataset_ops.FixedLengthRecordDataset( - filenames, record_bytes, header_bytes, footer_bytes, buffer_size) - super(FixedLengthRecordDataset, self).__init__(dataset) - - -def enumerate_dataset(start=0): - """A transformation that enumerate the elements of a dataset. - - It is Similar to python's `enumerate`. + This function enables you to use a @{tf.data.Dataset} in a stateless + "tensor-in tensor-out" expression, without creating a @{tf.data.Iterator}. + This can be useful when your preprocessing transformations are expressed + as a `Dataset`, and you want to use the transformation at serving time. For example: ```python - # NOTE: The following examples use `{ ... }` to represent the - # contents of a dataset. - a = { 1, 2, 3 } - b = { (7, 8), (9, 10) } - - # The nested structure of the `datasets` argument determines the - # structure of elements in the resulting dataset. - a.apply(tf.contrib.data.enumerate(start=5)) == { (5, 1), (6, 2), (7, 3) } - b.apply(tf.contrib.data.enumerate()) == { (0, (7, 8)), (1, (9, 10)) } - ``` - - Args: - start: A `tf.int64` scalar `tf.Tensor`, representing the start - value for enumeration. - - Returns: - A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. - """ - - def _apply_fn(dataset): - max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max - return Dataset.zip((Dataset.range(start, max_value), dataset)) - - return _apply_fn - - -def ignore_errors(): - """Creates a `Dataset` from another `Dataset` and silently ignores any errors. - - Use this transformation to produce a dataset that contains the same elements - as the input, but silently drops any elements that caused an error. For - example: - - ```python - dataset = tf.contrib.data.Dataset.from_tensor_slices([1., 2., 0., 4.]) - - # Computing `tf.check_numerics(1. / 0.)` will raise an InvalidArgumentError. - dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error")) - - # Using `ignore_errors()` will drop the element that causes an error. - dataset = - dataset.apply(tf.contrib.data.ignore_errors()) # ==> { 1., 0.5, 0.2 } - ``` - - Returns: - A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. - """ - - def _apply_fn(dataset): - return IgnoreErrorsDataset(dataset) - - return _apply_fn - - -def dense_to_sparse_batch(batch_size, row_shape): - """A transformation that batches ragged elements into `tf.SparseTensor`s. - - Like `Dataset.padded_batch()`, this transformation combines multiple - consecutive elements of the dataset, which might have different - shapes, into a single element. The resulting element has three - components (`indices`, `values`, and `dense_shape`), which - comprise a `tf.SparseTensor` that represents the same data. The - `row_shape` represents the dense shape of each row in the - resulting `tf.SparseTensor`, to which the effective batch size is - prepended. For example: - - ```python - # NOTE: The following examples use `{ ... }` to represent the - # contents of a dataset. - a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] } - - a.apply(tf.contrib.data.dense_to_sparse_batch(batch_size=2, row_shape=[6])) == - { - ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], # indices - ['a', 'b', 'c', 'a', 'b'], # values - [2, 6]), # dense_shape - ([[2, 0], [2, 1], [2, 2], [2, 3]], - ['a', 'b', 'c', 'd'], - [1, 6]) - } - ``` - - Args: - batch_size: A `tf.int64` scalar `tf.Tensor`, representing the - number of consecutive elements of this dataset to combine in a - single batch. - row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like - object representing the equivalent dense shape of a row in the - resulting `tf.SparseTensor`. Each element of this dataset must - have the same rank as `row_shape`, and must have size less - than or equal to `row_shape` in each dimension. - - Returns: - A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. - """ - - def _apply_fn(dataset): - return DenseToSparseBatchDataset(dataset, batch_size, row_shape) - - return _apply_fn - - -def unbatch(): - """A Transformation which splits the elements of a dataset. - - For example, if elements of the dataset are shaped `[B, a0, a1, ...]`, - where `B` may vary from element to element, then for each element in - the dataset, the unbatched dataset will contain `B` consecutive elements - of shape `[a0, a1, ...]`. - - Returns: - A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. - """ - - def _apply_fn(dataset): - - def unbatch_map(arg, *rest): - if rest: - return Dataset.from_tensor_slices((arg,) + rest) - else: - return Dataset.from_tensor_slices(arg) - - return dataset.flat_map(map_func=unbatch_map) - - return _apply_fn - - -def rejection_resample(class_func, - target_dist, - initial_dist=None, - seed=None): - """A transformation that resamples a dataset to achieve a target distribution. - - **NOTE** Resampling is performed via rejection sampling; some fraction - of the input values will be dropped. - - Args: - class_func: A function mapping an element of the input dataset to a scalar - `tf.int32` tensor. Values should be in `[0, num_classes)`. - target_dist: A floating point type tensor, shaped `[num_classes]`. - initial_dist: (Optional.) A floating point type tensor, shaped - `[num_classes]`. If not provided, the true class distribution is - estimated live in a streaming fashion. - seed: (Optional.) Python integer seed for the resampler. - - Returns: - A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. - """ - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - dist_estimation_batch_size = 32 - target_dist_t = ops.convert_to_tensor(target_dist, name="initial_dist") - class_values_ds = dataset.map(class_func) - if initial_dist is not None: - initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") - acceptance_dist = _calculate_acceptance_probs( - initial_dist_t, target_dist_t) - initial_dist_ds = Dataset.from_tensors(initial_dist_t).repeat() - acceptance_dist_ds = Dataset.from_tensors(acceptance_dist).repeat() - else: - num_classes = (target_dist_t.shape[0].value or - array_ops.shape(target_dist_t)[0]) - smoothing_constant = 10 - # Disable device functions and colocation constraints so that the variable - # will be placed with the eventual DT_VARIANT dataset tensor. - with ops.colocate_with(None, ignore_existing=True): - num_examples_per_class_seen = resource_variable_ops.ResourceVariable( - initial_value=array_ops.fill([num_classes], - np.int64(smoothing_constant)), - trainable=False, - collections=[ops.GraphKeys.LOCAL_VARIABLES], - name="local_class_count", - dtype=dtypes.int64) - - def update_estimate_and_tile(c): - return array_ops.tile( - array_ops.expand_dims( - _estimate_data_distribution(c, num_examples_per_class_seen), 0), - [dist_estimation_batch_size, 1]) - - initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size) - .map(update_estimate_and_tile).apply(unbatch())) - acceptance_dist_ds = initial_dist_ds.map( - lambda initial: _calculate_acceptance_probs(initial, target_dist_t)) - - def maybe_warn_on_large_rejection(accept_dist, initial_dist): - proportion_rejected = math_ops.reduce_sum( - (1 - accept_dist) * initial_dist) - return control_flow_ops.cond( - math_ops.less(proportion_rejected, .5), - lambda: accept_dist, - lambda: logging_ops.Print( # pylint: disable=g-long-lambda - accept_dist, [proportion_rejected, initial_dist, accept_dist], - message="Proportion of examples rejected by sampler is high: ", - summarize=100, - first_n=10)) - - acceptance_dist_ds = (Dataset.zip((acceptance_dist_ds, initial_dist_ds)) - .map(maybe_warn_on_large_rejection)) - - current_probabilities_ds = Dataset.zip( - (acceptance_dist_ds, class_values_ds)).map(array_ops.gather) - filtered_ds = ( - Dataset.zip((class_values_ds, current_probabilities_ds, dataset)) - .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) - return filtered_ds.map(lambda class_value, _, data: (class_value, data)) - - return _apply_fn - - -def _calculate_acceptance_probs(initial_probs, target_probs): - """Calculate the per-class acceptance rates. + input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE]) - Args: - initial_probs: The class probabilities of the data. - target_probs: The desired class proportion in minibatches. - Returns: - A list of the per-class acceptance probabilities. - - This method is based on solving the following analysis: - - Let F be the probability of a rejection (on any example). - Let p_i be the proportion of examples in the data in class i (init_probs) - Let a_i is the rate the rejection sampler should *accept* class i - Let t_i is the target proportion in the minibatches for class i (target_probs) + def preprocessing_fn(input_str): + # ... + return image, label - ``` - F = sum_i(p_i * (1-a_i)) - = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1 - ``` + dataset = (tf.data.Dataset.from_tensor_slices(input_batch) + .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE) + .batch(BATCH_SIZE)) - An example with class `i` will be accepted if `k` rejections occur, then an - example with class `i` is seen by the rejector, and it is accepted. This can - be written as follows: - - ``` - t_i = sum_k=0^inf(F^k * p_i * a_i) - = p_i * a_j / (1 - F) using geometric series identity, since 0 <= F < 1 - = p_i * a_i / sum_j(p_j * a_j) using F from above + image_batch, label_batch = tf.contrib.data.get_single_element(dataset) ``` - Note that the following constraints hold: - ``` - 0 <= p_i <= 1, sum_i(p_i) = 1 - 0 <= a_i <= 1 - 0 <= t_i <= 1, sum_i(t_i) = 1 - ``` - - - A solution for a_i in terms of the other variabes is the following: - ```a_i = (t_i / p_i) / max_i[t_i / p_i]``` - """ - # Add tiny to initial_probs to avoid divide by zero. - denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny) - ratio_l = target_probs / denom - - # Calculate list of acceptance probabilities. - max_ratio = math_ops.reduce_max(ratio_l) - return ratio_l / max_ratio - - -def _estimate_data_distribution(c, num_examples_per_class_seen): - """Estimate data distribution as labels are seen. - Args: - c: The class labels. Type `int32`, shape `[batch_size]`. - num_examples_per_class_seen: A `ResourceVariable` containing counts. - Type `int64`, shape `[num_classes]`. + dataset: A @{tf.data.Dataset} object containing a single element. Returns: - dist: The updated distribution. Type `float32`, shape `[num_classes]`. - """ - num_classes = num_examples_per_class_seen.get_shape()[0].value - # Update the class-count based on what labels are seen in - # batch. But do this asynchronously to avoid performing a - # cross-device round-trip. Just use the cached value. - num_examples_per_class_seen = num_examples_per_class_seen.assign_add( - math_ops.reduce_sum( - array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) - init_prob_estimate = math_ops.truediv( - num_examples_per_class_seen, - math_ops.reduce_sum(num_examples_per_class_seen)) - return math_ops.cast(init_prob_estimate, dtypes.float32) - - -class _VariantDataset(dataset_ops.Dataset): - """A Dataset wrapper for a tf.variant-typed function argument.""" - - def __init__(self, dataset_variant, output_types, output_shapes): - super(_VariantDataset, self).__init__() - self._dataset_variant = dataset_variant - self._output_types = output_types - self._output_shapes = output_shapes - - def make_dataset_resource(self): - return self._dataset_variant - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types - - -class DenseToSparseBatchDataset(dataset_ops.Dataset): - """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s.""" - - def __init__(self, input_dataset, batch_size, row_shape): - """See `Dataset.dense_to_sparse_batch()` for more details.""" - super(DenseToSparseBatchDataset, self).__init__() - if not isinstance(input_dataset.output_types, dtypes.DType): - raise TypeError("DenseToSparseDataset requires an input whose elements " - "have a single component, whereas the input has %r." % - input_dataset.output_types) - self._input_dataset = input_dataset - self._batch_size = batch_size - # pylint: disable=protected-access - self._row_shape = dataset_ops._partial_shape_to_tensor(row_shape) - # pylint: enable=protected-access - - def make_dataset_resource(self): - return gen_dataset_ops.dense_to_sparse_batch_dataset( - self._input_dataset.make_dataset_resource(), - self._batch_size, - self._row_shape, - output_shapes=self.output_shapes, - output_types=self.output_types) - - @property - def output_shapes(self): - num_elements = tensor_shape.Dimension(None) - return (tensor_shape.matrix(num_elements, self._row_shape.shape[0] + 1), - tensor_shape.vector(num_elements), - tensor_shape.vector(self._row_shape.shape[0] + 1)) - - @property - def output_types(self): - return (dtypes.int64, self._input_dataset.output_types, dtypes.int64) - - -class IgnoreErrorsDataset(dataset_ops.Dataset): - """A `Dataset` that silently ignores errors when computing its input.""" - - def __init__(self, input_dataset): - """See `Dataset.ignore_errors()` for details.""" - super(IgnoreErrorsDataset, self).__init__() - self._input_dataset = input_dataset - - def make_dataset_resource(self): - return gen_dataset_ops.ignore_errors_dataset( - self._input_dataset.make_dataset_resource(), - output_shapes=nest.flatten(self.output_shapes), - output_types=nest.flatten(self.output_types)) - - @property - def output_shapes(self): - return self._input_dataset.output_shapes - - @property - def output_types(self): - return self._input_dataset.output_types - - -def read_batch_features(file_pattern, - batch_size, - features, - reader, - reader_args=None, - randomize_input=True, - num_epochs=None, - capacity=10000): - """Reads batches of Examples. - - Example: - - ``` - serialized_examples = [ - features { - feature { key: "age" value { int64_list { value: [ 0 ] } } } - feature { key: "gender" value { bytes_list { value: [ "f" ] } } } - feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } } - }, - features { - feature { key: "age" value { int64_list { value: [] } } } - feature { key: "gender" value { bytes_list { value: [ "f" ] } } } - feature { key: "kws" value { bytes_list { value: [ "sports" ] } } } - } - ] - ``` - - We can use arguments: - - ``` - features: { - "age": FixedLenFeature([], dtype=tf.int64, default_value=-1), - "gender": FixedLenFeature([], dtype=tf.string), - "kws": VarLenFeature(dtype=tf.string), - } - ``` - - And the expected output is: - - ```python - { - "age": [[0], [-1]], - "gender": [["f"], ["f"]], - "kws": SparseTensor( - indices=[[0, 0], [0, 1], [1, 0]], - values=["code", "art", "sports"] - dense_shape=[2, 2]), - } - ``` - - Args: - file_pattern: List of files or patterns of file paths containing - `Example` records. See `tf.gfile.Glob` for pattern rules. - batch_size: An int representing the number of consecutive elements of this - dataset to combine in a single batch. - features: A `dict` mapping feature keys to `FixedLenFeature` or - `VarLenFeature` values. See `tf.parse_example`. - reader: A function or class that can be called with a `filenames` tensor - and (optional) `reader_args` and returns a `Dataset` of serialized - Examples. - reader_args: Additional arguments to pass to the reader class. - randomize_input: Whether the input should be randomized. - num_epochs: Integer specifying the number of times to read through the - dataset. If None, cycles through the dataset forever. - capacity: Capacity of the ShuffleDataset. A large capacity ensures better - shuffling but would increase memory usage and startup time. - - Returns: - A dict from keys in features to Tensor or SparseTensor objects. - """ - filenames = _get_file_names(file_pattern, randomize_input) - if reader_args: - dataset = reader(filenames, *reader_args) - else: - dataset = reader(filenames) - if dataset.output_types == (dtypes.string, dtypes.string): - dataset = dataset.map(lambda unused_k, v: v) - elif dataset.output_types != dtypes.string: - raise TypeError("`reader` must be a dataset of `tf.string` values, " - "or `(tf.string, tf.string)` key-value pairs.") - if num_epochs != 1: - dataset = dataset.repeat(num_epochs) - if randomize_input: - dataset = dataset.shuffle(capacity) - dataset = dataset.batch(batch_size) - dataset = dataset.map(lambda x: _parse_example(x, features)) - iterator = dataset.make_one_shot_iterator() - outputs = iterator.get_next() - index = 0 - result = {} - for key in sorted(features.keys()): - feature = features[key] - if isinstance(feature, parsing_ops.FixedLenFeature): - result[key] = outputs[index] - index += 1 - else: - result[key] = sparse_tensor_lib.SparseTensor( - indices=outputs[index], - values=outputs[index + 1], - dense_shape=outputs[index + 2]) - index += 3 - return result - - -def _parse_example(serialized, features): - parsed = parsing_ops.parse_example(serialized, features) - result = [] - for key in sorted(features.keys()): - val = parsed[key] - if isinstance(val, sparse_tensor_lib.SparseTensor): - result.extend([val.indices, val.values, val.dense_shape]) - else: - result.append(val) - return tuple(result) - - -def _get_file_names(file_pattern, randomize_input): - """Parse list of file names from pattern, optionally shuffled. - - Args: - file_pattern: File glob pattern, or list of glob patterns. - randomize_input: Whether to shuffle the order of file names. - - Returns: - List of file names matching `file_pattern`. + A nested structure of @{tf.Tensor} objects, corresponding to the single + element of `dataset`. Raises: - ValueError: If `file_pattern` is empty, or pattern matches no files. + TypeError: if `dataset` is not a `tf.data.Dataset` object. + InvalidArgumentError (at runtime): if `dataset` does not contain exactly + one element. """ - if isinstance(file_pattern, list): - if not file_pattern: - raise ValueError("File pattern is empty.") - file_names = [] - for entry in file_pattern: - file_names.extend(gfile.Glob(entry)) - else: - file_names = list(gfile.Glob(file_pattern)) - - if not file_names: - raise ValueError("No files match %s." % file_pattern) - - # Sort files so it will be deterministic for unit tests. - if not randomize_input: - file_names = sorted(file_names) - return file_names - - -class GroupByWindowDataset(dataset_ops.Dataset): - """A `Dataset` that groups its input and performs a windowed reduction.""" - - def __init__(self, input_dataset, key_func, reduce_func, window_size_func): - """See `group_by_window()` for details.""" - super(GroupByWindowDataset, self).__init__() - - self._input_dataset = input_dataset - - self._make_key_func(key_func, input_dataset) - self._make_reduce_func(reduce_func, input_dataset) - self._make_window_size_func(window_size_func) - - def _make_window_size_func(self, window_size_func): - """Make wrapping Defun for window_size_func.""" - - @function.Defun(dtypes.int64) - def tf_window_size_func(key): - key.set_shape([]) - window_size = ops.convert_to_tensor( - window_size_func(key), dtype=dtypes.int64) - if window_size.dtype != dtypes.int64: - raise ValueError( - "`window_size_func` must return a single tf.int64 tensor.") - return window_size - - self._window_size_func = tf_window_size_func - self._window_size_func.add_to_graph(ops.get_default_graph()) - - def _make_key_func(self, key_func, input_dataset): - """Make wrapping Defun for key_func.""" - - @function.Defun(*nest.flatten(input_dataset.output_types)) - def tf_key_func(*args): - """A wrapper for Defun that facilitates shape inference.""" - # Pass in shape information from the input_dataset. - for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)): - arg.set_shape(shape) - nested_args = nest.pack_sequence_as(input_dataset.output_types, args) - # pylint: disable=protected-access - if dataset_ops._should_unpack_args(nested_args): - ret = key_func(*nested_args) - # pylint: enable=protected-access - else: - ret = key_func(nested_args) - ret = ops.convert_to_tensor(ret, dtype=dtypes.int64) - if ret.dtype != dtypes.int64: - raise ValueError("`key_func` must return a single tf.int64 tensor.") - return ret - - self._key_func = tf_key_func - self._key_func.add_to_graph(ops.get_default_graph()) - - def _make_reduce_func(self, reduce_func, input_dataset): - """Make wrapping Defun for reduce_func.""" - - @function.Defun(dtypes.int64, dtypes.variant) - def tf_reduce_func(key, window_dataset_variant): - """A wrapper for Defun that facilitates shape inference.""" - key.set_shape([]) - window_dataset = _VariantDataset(window_dataset_variant, - input_dataset.output_types, - input_dataset.output_shapes) - if not isinstance(window_dataset, dataset_ops.Dataset): - raise TypeError("`window_dataset` must return a `Dataset` object.") - output_dataset = reduce_func(key, window_dataset) - if not isinstance(output_dataset, dataset_ops.Dataset): - raise TypeError("`reduce_func` must return a `Dataset` object.") - self._output_types = output_dataset.output_types - self._output_shapes = output_dataset.output_shapes - return output_dataset.make_dataset_resource() - - self._reduce_func = tf_reduce_func - self._reduce_func.add_to_graph(ops.get_default_graph()) - - @property - def output_shapes(self): - return self._output_shapes - - @property - def output_types(self): - return self._output_types - - def make_dataset_resource(self): - return gen_dataset_ops.group_by_window_dataset( - self._input_dataset.make_dataset_resource(), - self._key_func.captured_inputs, - self._reduce_func.captured_inputs, - self._window_size_func.captured_inputs, - key_func=self._key_func, - reduce_func=self._reduce_func, - window_size_func=self._window_size_func, - output_types=nest.flatten(self.output_types), - output_shapes=nest.flatten(self.output_shapes)) - - -def group_by_window(key_func, - reduce_func, - window_size=None, - window_size_func=None): - """A transformation that groups windows of elements by key and reduces them. - - This transformation maps each consecutive element in a dataset to a key - using `key_func` and groups the elements by key. It then applies - `reduce_func` to at most `window_size_func(key)` elements matching the same - key. All execpt the final window for each key will contain - `window_size_func(key)` elements; the final window may be smaller. - - You may provide either a constant `window_size` or a window size determined by - the key through `window_size_func`. - - Args: - key_func: A function mapping a nested structure of tensors - (having shapes and types defined by `self.output_shapes` and - `self.output_types`) to a scalar `tf.int64` tensor. - reduce_func: A function mapping a key and a dataset of up to `batch_size` - consecutive elements matching that key to another dataset. - window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of - consecutive elements matching the same key to combine in a single - batch, which will be passed to `reduce_func`. Mutually exclusive with - `window_size_func`. - window_size_func: A function mapping a key to a `tf.int64` scalar - `tf.Tensor`, representing the number of consecutive elements matching - the same key to combine in a single batch, which will be passed to - `reduce_func`. Mutually exclusive with `window_size`. - - Returns: - A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. - - Raises: - ValueError: if neither or both of {`window_size`, `window_size_func`} are - passed. - """ - if (window_size is not None and window_size_func or - not (window_size is not None or window_size_func)): - raise ValueError("Must pass either window_size or window_size_func.") - - if window_size is not None: - - def constant_window_func(unused_key): - return ops.convert_to_tensor(window_size, dtype=dtypes.int64) - - window_size_func = constant_window_func - - assert window_size_func is not None - - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - return GroupByWindowDataset(dataset, key_func, reduce_func, - window_size_func) - - return _apply_fn - - -class SqlDataset(dataset_ops.Dataset): - """A `Dataset` consisting of the results from a SQL query.""" - - def __init__(self, driver_name, data_source_name, query, output_types): - """Creates a `SqlDataset`. - - `SqlDataset` allows a user to read data from the result set of a SQL query. - For example: - - ```python - dataset = tf.contrib.data.SqlDataset("sqlite", "/foo/bar.sqlite3", - "SELECT name, age FROM people", - (tf.string, tf.int32)) - iterator = dataset.make_one_shot_iterator() - next_element = iterator.get_next() - # Prints the rows of the result set of the above query. - while True: - try: - print(sess.run(next_element)) - except tf.errors.OutOfRangeError: - break - ``` - - Args: - driver_name: A 0-D `tf.string` tensor containing the database type. - Currently, the only supported value is 'sqlite'. - data_source_name: A 0-D `tf.string` tensor containing a connection string - to connect to the database. - query: A 0-D `tf.string` tensor containing the SQL query to execute. - output_types: A tuple of `tf.DType` objects representing the types of the - columns returned by `query`. - """ - super(SqlDataset, self).__init__() - self._driver_name = ops.convert_to_tensor( - driver_name, dtype=dtypes.string, name="driver_name") - self._data_source_name = ops.convert_to_tensor( - data_source_name, dtype=dtypes.string, name="data_source_name") - self._query = ops.convert_to_tensor( - query, dtype=dtypes.string, name="query") - self._output_types = output_types - - def make_dataset_resource(self): - return gen_dataset_ops.sql_dataset(self._driver_name, - self._data_source_name, self._query, - nest.flatten(self.output_types), - nest.flatten(self.output_shapes)) - - @property - def output_shapes(self): - return nest.map_structure(lambda _: tensor_shape.TensorShape([]), - self._output_types) - - @property - def output_types(self): - return self._output_types - - -class _RestructuredDataset(dataset_ops.Dataset): - """An internal helper for changing the structure and shape of a dataset.""" - - def __init__(self, dataset, output_types, output_shapes=None): - """Creates a new dataset with the given output types and shapes. - - The given `dataset` must have a structure that is convertible: - * `dataset.output_types` must be the same as `output_types` module nesting. - * Each shape in `dataset.output_shapes` must be compatible with each shape - in `output_shapes` (if given). - - Note: This helper permits "unsafe casts" for shapes, equivalent to using - `tf.Tensor.set_shape()` where domain-specific knowledge is available. - - Args: - dataset: A `Dataset` object. - output_types: A nested structure of `tf.DType` objects. - output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects. - If omitted, the shapes will be inherited from `dataset`. - - Raises: - ValueError: If either `output_types` or `output_shapes` is not compatible - with the structure of `dataset`. - """ - super(_RestructuredDataset, self).__init__() - self._dataset = dataset - - # Validate that the types are compatible. - output_types = nest.map_structure(dtypes.as_dtype, output_types) - flat_original_types = nest.flatten(dataset.output_types) - flat_new_types = nest.flatten(output_types) - if flat_original_types != flat_new_types: - raise ValueError( - "Dataset with output types %r cannot be restructured to have output " - "types %r" % (dataset.output_types, output_types)) - - self._output_types = output_types - - if output_shapes is None: - # Inherit shapes from the original `dataset`. - self._output_shapes = nest.pack_sequence_as( - output_types, nest.flatten(dataset.output_shapes)) - else: - # Validate that the shapes are compatible. - nest.assert_same_structure(output_types, output_shapes) - flat_original_shapes = nest.flatten(dataset.output_shapes) - flat_new_shapes = nest.flatten_up_to(output_types, output_shapes) - - for original_shape, new_shape in zip(flat_original_shapes, - flat_new_shapes): - if not original_shape.is_compatible_with(new_shape): - raise ValueError( - "Dataset with output shapes %r cannot be restructured to have " - "incompatible output shapes %r" - % (dataset.output_shapes, output_shapes)) - self._output_shapes = nest.map_structure_up_to( - output_types, tensor_shape.as_shape, output_shapes) - - def make_dataset_resource(self): - return self._dataset.make_dataset_resource() - - @property - def output_types(self): - return self._output_types - - @property - def output_shapes(self): - return self._output_shapes - - -def batch_and_drop_remainder(batch_size): - """A batching transformation that omits the final small batch (if present). - - Like @{tf.contrib.data.Dataset.batch}, this transformation combines - consecutive elements of this dataset into batches. However, if the batch - size does not evenly divide the input dataset size, this transformation will - drop the final smaller element. - - The following example illustrates the difference between this - transformation and `Dataset.batch()`: - - ```python - dataset = tf.contrib.data.Dataset.range(200) - batched = dataset.apply(tf.contrib.data.batch_and_drop_remainder(128)) - print(batched.output_shapes) # ==> "(128,)" (the batch dimension is known) - ``` - - By contrast, `dataset.batch(128)` would yield a two-element dataset with - shapes `(128,)` and `(72,)`, so the batch dimension would not be statically - known. - - Args: - batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of - consecutive elements of this dataset to combine in a single batch. - - Returns: - A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply} - """ - - def _apply_fn(dataset): - """Function from `Dataset` to `Dataset` that applies the transformation.""" - tensor_batch_size = ops.convert_to_tensor( - batch_size, dtype=dtypes.int64, name="batch_size") - - batched = dataset.batch(tensor_batch_size) - flattened = _RestructuredDataset(batched, - tuple(nest.flatten(batched.output_types))) - - def _predicate(*xs): - """Return `True` if this element is a full batch.""" - # Extract the dynamic batch size from the first component of the flattened - # batched element. - first_component = xs[0] - first_component_batch_size = array_ops.shape( - first_component, out_type=dtypes.int64)[0] - - return math_ops.equal(first_component_batch_size, tensor_batch_size) - - filtered = flattened.filter(_predicate) - - maybe_constant_batch_size = tensor_util.constant_value(tensor_batch_size) - - def _set_first_dimension(shape): - return shape.merge_with( - tensor_shape.vector(maybe_constant_batch_size).concatenate(shape[1:])) - - known_shapes = nest.map_structure(_set_first_dimension, - batched.output_shapes) - return _RestructuredDataset(filtered, batched.output_types, known_shapes) - - return _apply_fn + if not isinstance(dataset, dataset_ops.Dataset): + raise TypeError("`dataset` must be a `tf.data.Dataset` object.") + return nest.pack_sequence_as( + dataset.output_types, + gen_dataset_ops.dataset_to_single_element( + dataset._as_variant_tensor(), # pylint: disable=protected-access + output_types=nest.flatten(dataset.output_types), + output_shapes=nest.flatten(dataset.output_shapes))) diff --git a/tensorflow/contrib/data/python/ops/enumerate_ops.py b/tensorflow/contrib/data/python/ops/enumerate_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ac2b386b81532b801139baa00fd5edd4ecd6ef0a --- /dev/null +++ b/tensorflow/contrib/data/python/ops/enumerate_ops.py @@ -0,0 +1,58 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Enumerate dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes + + +def enumerate_dataset(start=0): + """A transformation that enumerate the elements of a dataset. + + It is Similar to python's `enumerate`. + For example: + + ```python + # NOTE: The following examples use `{ ... }` to represent the + # contents of a dataset. + a = { 1, 2, 3 } + b = { (7, 8), (9, 10) } + + # The nested structure of the `datasets` argument determines the + # structure of elements in the resulting dataset. + a.apply(tf.contrib.data.enumerate(start=5)) == { (5, 1), (6, 2), (7, 3) } + b.apply(tf.contrib.data.enumerate()) == { (0, (7, 8)), (1, (9, 10)) } + ``` + + Args: + start: A `tf.int64` scalar `tf.Tensor`, representing the start + value for enumeration. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max + return dataset_ops.Dataset.zip((dataset_ops.Dataset.range(start, max_value), + dataset)) + + return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..238bb52b0205f9ab66f479f1b92e72ab6e38725b --- /dev/null +++ b/tensorflow/contrib/data/python/ops/error_ops.py @@ -0,0 +1,74 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ignore_errors dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.ops import gen_dataset_ops + + +def ignore_errors(): + """Creates a `Dataset` from another `Dataset` and silently ignores any errors. + + Use this transformation to produce a dataset that contains the same elements + as the input, but silently drops any elements that caused an error. For + example: + + ```python + dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.]) + + # Computing `tf.check_numerics(1. / 0.)` will raise an InvalidArgumentError. + dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error")) + + # Using `ignore_errors()` will drop the element that causes an error. + dataset = + dataset.apply(tf.contrib.data.ignore_errors()) # ==> { 1., 0.5, 0.2 } + ``` + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + return IgnoreErrorsDataset(dataset) + + return _apply_fn + + +class IgnoreErrorsDataset(dataset_ops.Dataset): + """A `Dataset` that silently ignores errors when computing its input.""" + + def __init__(self, input_dataset): + """See `Dataset.ignore_errors()` for details.""" + super(IgnoreErrorsDataset, self).__init__() + self._input_dataset = input_dataset + + def _as_variant_tensor(self): + return gen_dataset_ops.ignore_errors_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + output_shapes=nest.flatten(self.output_shapes), + output_types=nest.flatten(self.output_types)) + + @property + def output_shapes(self): + return self._input_dataset.output_shapes + + @property + def output_types(self): + return self._input_dataset.output_types diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py new file mode 100644 index 0000000000000000000000000000000000000000..6df7b22fb69bb14c41a26bd630a825442f67ee23 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/grouping.py @@ -0,0 +1,201 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Grouping dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops + + +def group_by_window(key_func, + reduce_func, + window_size=None, + window_size_func=None): + """A transformation that groups windows of elements by key and reduces them. + + This transformation maps each consecutive element in a dataset to a key + using `key_func` and groups the elements by key. It then applies + `reduce_func` to at most `window_size_func(key)` elements matching the same + key. All execpt the final window for each key will contain + `window_size_func(key)` elements; the final window may be smaller. + + You may provide either a constant `window_size` or a window size determined by + the key through `window_size_func`. + + Args: + key_func: A function mapping a nested structure of tensors + (having shapes and types defined by `self.output_shapes` and + `self.output_types`) to a scalar `tf.int64` tensor. + reduce_func: A function mapping a key and a dataset of up to `batch_size` + consecutive elements matching that key to another dataset. + window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of + consecutive elements matching the same key to combine in a single + batch, which will be passed to `reduce_func`. Mutually exclusive with + `window_size_func`. + window_size_func: A function mapping a key to a `tf.int64` scalar + `tf.Tensor`, representing the number of consecutive elements matching + the same key to combine in a single batch, which will be passed to + `reduce_func`. Mutually exclusive with `window_size`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + + Raises: + ValueError: if neither or both of {`window_size`, `window_size_func`} are + passed. + """ + if (window_size is not None and window_size_func or + not (window_size is not None or window_size_func)): + raise ValueError("Must pass either window_size or window_size_func.") + + if window_size is not None: + + def constant_window_func(unused_key): + return ops.convert_to_tensor(window_size, dtype=dtypes.int64) + + window_size_func = constant_window_func + + assert window_size_func is not None + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + return GroupByWindowDataset(dataset, key_func, reduce_func, + window_size_func) + + return _apply_fn + + +class _VariantDataset(dataset_ops.Dataset): + """A Dataset wrapper for a tf.variant-typed function argument.""" + + def __init__(self, dataset_variant, output_types, output_shapes): + super(_VariantDataset, self).__init__() + self._dataset_variant = dataset_variant + self._output_types = output_types + self._output_shapes = output_shapes + + def _as_variant_tensor(self): + return self._dataset_variant + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + +class GroupByWindowDataset(dataset_ops.Dataset): + """A `Dataset` that groups its input and performs a windowed reduction.""" + + def __init__(self, input_dataset, key_func, reduce_func, window_size_func): + """See `group_by_window()` for details.""" + super(GroupByWindowDataset, self).__init__() + + self._input_dataset = input_dataset + + self._make_key_func(key_func, input_dataset) + self._make_reduce_func(reduce_func, input_dataset) + self._make_window_size_func(window_size_func) + + def _make_window_size_func(self, window_size_func): + """Make wrapping Defun for window_size_func.""" + + @function.Defun(dtypes.int64) + def tf_window_size_func(key): + key.set_shape([]) + window_size = ops.convert_to_tensor( + window_size_func(key), dtype=dtypes.int64) + if window_size.dtype != dtypes.int64: + raise ValueError( + "`window_size_func` must return a single tf.int64 tensor.") + return window_size + + self._window_size_func = tf_window_size_func + self._window_size_func.add_to_graph(ops.get_default_graph()) + + def _make_key_func(self, key_func, input_dataset): + """Make wrapping Defun for key_func.""" + + @function.Defun(*nest.flatten(input_dataset.output_types)) + def tf_key_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + # Pass in shape information from the input_dataset. + for arg, shape in zip(args, nest.flatten(input_dataset.output_shapes)): + arg.set_shape(shape) + nested_args = nest.pack_sequence_as(input_dataset.output_types, args) + # pylint: disable=protected-access + if dataset_ops._should_unpack_args(nested_args): + ret = key_func(*nested_args) + # pylint: enable=protected-access + else: + ret = key_func(nested_args) + ret = ops.convert_to_tensor(ret, dtype=dtypes.int64) + if ret.dtype != dtypes.int64: + raise ValueError("`key_func` must return a single tf.int64 tensor.") + return ret + + self._key_func = tf_key_func + self._key_func.add_to_graph(ops.get_default_graph()) + + def _make_reduce_func(self, reduce_func, input_dataset): + """Make wrapping Defun for reduce_func.""" + + @function.Defun(dtypes.int64, dtypes.variant) + def tf_reduce_func(key, window_dataset_variant): + """A wrapper for Defun that facilitates shape inference.""" + key.set_shape([]) + window_dataset = _VariantDataset(window_dataset_variant, + input_dataset.output_types, + input_dataset.output_shapes) + if not isinstance(window_dataset, dataset_ops.Dataset): + raise TypeError("`window_dataset` must return a `Dataset` object.") + output_dataset = reduce_func(key, window_dataset) + if not isinstance(output_dataset, dataset_ops.Dataset): + raise TypeError("`reduce_func` must return a `Dataset` object.") + self._output_types = output_dataset.output_types + self._output_shapes = output_dataset.output_shapes + return output_dataset._as_variant_tensor() # pylint: disable=protected-access + + self._reduce_func = tf_reduce_func + self._reduce_func.add_to_graph(ops.get_default_graph()) + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + def _as_variant_tensor(self): + return gen_dataset_ops.group_by_window_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access + self._key_func.captured_inputs, + self._reduce_func.captured_inputs, + self._window_size_func.captured_inputs, + key_func=self._key_func, + reduce_func=self._reduce_func, + window_size_func=self._window_size_func, + output_types=nest.flatten(self.output_types), + output_shapes=nest.flatten(self.output_shapes)) diff --git a/tensorflow/contrib/data/python/ops/sloppy_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py similarity index 59% rename from tensorflow/contrib/data/python/ops/sloppy_ops.py rename to tensorflow/contrib/data/python/ops/interleave_ops.py index 375f54193c634424aad78d10cdb7b807c42fd9ba..74a919c1fff62cfa79b0877a3d081077ca6776f0 100644 --- a/tensorflow/contrib/data/python/ops/sloppy_ops.py +++ b/tensorflow/contrib/data/python/ops/interleave_ops.py @@ -23,14 +23,16 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import function from tensorflow.python.framework import ops from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.util import deprecation -class SloppyInterleaveDataset(dataset_ops.Dataset): +class ParallelInterleaveDataset(dataset_ops.Dataset): """A `Dataset` that maps a function over its input and flattens the result.""" - def __init__(self, input_dataset, map_func, cycle_length, block_length): - """See `tf.contrib.data.sloppy_interleave()` for details.""" - super(SloppyInterleaveDataset, self).__init__() + def __init__(self, input_dataset, map_func, cycle_length, block_length, + sloppy): + """See `tf.contrib.data.parallel_interleave()` for details.""" + super(ParallelInterleaveDataset, self).__init__() self._input_dataset = input_dataset @function.Defun(*nest.flatten(input_dataset.output_types)) @@ -53,7 +55,7 @@ class SloppyInterleaveDataset(dataset_ops.Dataset): self._output_types = dataset.output_types self._output_shapes = dataset.output_shapes - return dataset.make_dataset_resource() + return dataset._as_variant_tensor() # pylint: disable=protected-access self._map_func = tf_map_func self._map_func.add_to_graph(ops.get_default_graph()) @@ -62,13 +64,16 @@ class SloppyInterleaveDataset(dataset_ops.Dataset): cycle_length, dtype=dtypes.int64, name="cycle_length") self._block_length = ops.convert_to_tensor( block_length, dtype=dtypes.int64, name="block_length") + self._sloppy = ops.convert_to_tensor( + sloppy, dtype=dtypes.bool, name="sloppy") - def make_dataset_resource(self): - return gen_dataset_ops.sloppy_interleave_dataset( - self._input_dataset.make_dataset_resource(), + def _as_variant_tensor(self): + return gen_dataset_ops.parallel_interleave_dataset( + self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access self._map_func.captured_inputs, self._cycle_length, self._block_length, + self._sloppy, f=self._map_func, output_types=nest.flatten(self.output_types), output_shapes=nest.flatten(self.output_shapes)) @@ -82,7 +87,54 @@ class SloppyInterleaveDataset(dataset_ops.Dataset): return self._output_types -def sloppy_interleave(map_func, cycle_length, block_length): +def parallel_interleave(map_func, cycle_length, block_length=1, sloppy=False): + """A parallel version of the `Dataset.interleave()` transformation. + + `parallel_interleave()` maps `map_func` across its input to produce nested + datasets, and outputs their elements interleaved. Unlike + @{tf.data.Dataset.interleave}, it gets elements from `cycle_length` nested + datasets in parallel, which increases the throughput, especially in the + presence of stragglers. Furthermore, the `sloppy` argument can be used to + improve performance, by relaxing the requirement that the outputs are produced + in a deterministic order, and allowing the implementation to skip over nested + datasets whose elements are not readily available when requested. + + Example usage: + + ```python + # Preprocess 4 files concurrently. + filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords") + dataset = filenames.apply( + tf.contrib.data.parallel_interleave( + lambda filename: tf.data.TFRecordDataset(filename), + cycle_length=4)) + ``` + + WARNING: If `sloppy` is `True`, the order of produced elements is not + deterministic. + + Args: + map_func: A function mapping a nested structure of tensors to a `Dataset`. + cycle_length: The number of threads to interleave from in parallel. + block_length: The number of consecutive elements to pull from a thread + before advancing to the next thread. + sloppy: If false, elements are produced in deterministic order. Otherwise, + the implementation is allowed, for the sake of expediency, to produce + elements in a non-deterministic order. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + def _apply_fn(dataset): + return ParallelInterleaveDataset( + dataset, map_func, cycle_length, block_length, sloppy) + return _apply_fn + + +@deprecation.deprecated( + None, "Use `tf.contrib.data.parallel_interleave(..., sloppy=True)`.") +def sloppy_interleave(map_func, cycle_length, block_length=1): """A non-deterministic version of the `Dataset.interleave()` transformation. `sloppy_interleave()` maps `map_func` across `dataset`, and @@ -102,6 +154,17 @@ def sloppy_interleave(map_func, cycle_length, block_length): strictly obeys), producing an element from a different underlying dataset instead. + Example usage: + + ```python + # Preprocess 4 files concurrently. + filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords") + dataset = filenames.apply( + tf.contrib.data.sloppy_interleave( + lambda filename: tf.data.TFRecordDataset(filename), + cycle_length=4)) + ``` + WARNING: The order of elements in the resulting dataset is not deterministic. Use `Dataset.interleave()` if you want the elements to have a deterministic order. @@ -118,9 +181,9 @@ def sloppy_interleave(map_func, cycle_length, block_length): Returns: A `Dataset` transformation function, which can be passed to - @{tf.contrib.data.Dataset.apply}. + @{tf.data.Dataset.apply}. """ def _apply_fn(dataset): - return SloppyInterleaveDataset( - dataset, map_func, cycle_length, block_length) + return ParallelInterleaveDataset( + dataset, map_func, cycle_length, block_length, sloppy=True) return _apply_fn diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d736029fb035e573b70e8b19570e4e8ceca3c005 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/iterator_ops.py @@ -0,0 +1,77 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Iterator ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.training import saver + + +def make_saveable_from_iterator(iterator): + """Returns a SaveableObject for saving/restore iterator state using Saver. + + Args: + iterator: Iterator. + + For example: + + ```python + with tf.Graph().as_default(): + ds = tf.data.Dataset.range(10) + iterator = ds.make_initializable_iterator() + # Build the iterator SaveableObject. + saveable_obj = tf.contrib.data.make_saveable_from_iterator(iterator) + # Add the SaveableObject to the SAVEABLE_OBJECTS collection so + # it can be automatically saved using Saver. + tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj) + saver = tf.train.Saver() + + while continue_training: + ... Perform training ... + if should_save_checkpoint: + saver.save() + ``` + + Note: When restoring the iterator, the existing iterator state is completely + discarded. This means that any changes you may have made to the Dataset + graph will be discarded as well! This includes the new Dataset graph + that you may have built during validation. So, while running validation, + make sure to run the initializer for the validation input pipeline after + restoring the checkpoint. + + Note: Not all iterators support checkpointing yet. Attempting to save the + state of an unsupported iterator will throw an error. + """ + return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access + + +class _Saveable(saver.BaseSaverBuilder.SaveableObject): + """SaveableObject for saving/restoring iterator state.""" + + def __init__(self, iterator_resource): + serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource) + specs = [ + saver.BaseSaverBuilder.SaveSpec(serialized_iterator, "", + iterator_resource.name + "-state") + ] + super(_Saveable, self).__init__(iterator_resource, specs, + iterator_resource.name) + + def restore(self, restored_tensors, unused_restored_shapes): + with ops.colocate_with(self.op): + return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0]) diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..cfe8012b5657995b78d701528ea35cbb3748adb9 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py @@ -0,0 +1,55 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python wrapper for prefetching_ops.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import gen_prefetching_ops +from tensorflow.contrib.util import loader +from tensorflow.python.platform import resource_loader + +_prefetching_ops = loader.load_op_library( + resource_loader.get_path_to_datafile("../../_prefetching_ops.so")) + + +# TODO(rohanj): Add a python class that constructs resource in the __init__ +# method and provides a get_next() that calls the prefetch op. +def function_buffering_resource(string_arg, + target_device, + shared_name, + f, + buffer_size, + thread_pool_size=1, + container="", + name=None): + return gen_prefetching_ops.function_buffering_resource( + string_arg=string_arg, + target_device=target_device, + shared_name=shared_name, + f=f, + buffer_size=buffer_size, + thread_pool_size=thread_pool_size, + container=container, + name=name) + + +def function_buffering_resource_get_next(function_buffer_resource, + output_types, + name=None): + return gen_prefetching_ops.function_buffering_resource_get_next( + function_buffer_resource=function_buffer_resource, + output_types=output_types, + name=name) diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py new file mode 100644 index 0000000000000000000000000000000000000000..2e1c3153ca78e20e2628e8754b9827b817f8c732 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/readers.py @@ -0,0 +1,309 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python wrappers for reader Datasets.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers +from tensorflow.python.data.util import nest +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_dataset_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.platform import gfile +from tensorflow.python.util import deprecation + + +class TextLineDataset(contrib_dataset_ops.Dataset): + """A `Dataset` comprising lines from one or more text files.""" + + @deprecation.deprecated(None, "Use `tf.data.TextLineDataset`.") + def __init__(self, filenames, compression_type=None, buffer_size=None): + """Creates a `TextLineDataset`. + + Args: + filenames: A `tf.string` tensor containing one or more filenames. + compression_type: (Optional.) A `tf.string` scalar evaluating to one of + `""` (no compression), `"ZLIB"`, or `"GZIP"`. + buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes + to buffer. A value of 0 results in the default buffering values chosen + based on the compression type. + """ + dataset = readers.TextLineDataset(filenames, compression_type, + buffer_size) + super(TextLineDataset, self).__init__(dataset) + + +class TFRecordDataset(contrib_dataset_ops.Dataset): + """A `Dataset` comprising records from one or more TFRecord files.""" + + @deprecation.deprecated(None, "Use `tf.data.TFRecordDataset`.") + def __init__(self, filenames, compression_type=None, buffer_size=None): + """Creates a `TFRecordDataset`. + + Args: + filenames: A `tf.string` tensor containing one or more filenames. + compression_type: (Optional.) A `tf.string` scalar evaluating to one of + `""` (no compression), `"ZLIB"`, or `"GZIP"`. + buffer_size: (Optional.) A `tf.int64` scalar representing the number of + bytes in the read buffer. 0 means no buffering. + """ + dataset = readers.TFRecordDataset(filenames, compression_type, + buffer_size) + super(TFRecordDataset, self).__init__(dataset) + + +class FixedLengthRecordDataset(contrib_dataset_ops.Dataset): + """A `Dataset` of fixed-length records from one or more binary files.""" + + @deprecation.deprecated(None, "Use `tf.data.FixedLengthRecordDataset`.") + def __init__(self, + filenames, + record_bytes, + header_bytes=None, + footer_bytes=None, + buffer_size=None): + """Creates a `FixedLengthRecordDataset`. + + Args: + filenames: A `tf.string` tensor containing one or more filenames. + record_bytes: A `tf.int64` scalar representing the number of bytes in + each record. + header_bytes: (Optional.) A `tf.int64` scalar representing the number of + bytes to skip at the start of a file. + footer_bytes: (Optional.) A `tf.int64` scalar representing the number of + bytes to ignore at the end of a file. + buffer_size: (Optional.) A `tf.int64` scalar representing the number of + bytes to buffer when reading. + """ + dataset = readers.FixedLengthRecordDataset( + filenames, record_bytes, header_bytes, footer_bytes, buffer_size) + super(FixedLengthRecordDataset, self).__init__(dataset) + + +def read_batch_features(file_pattern, + batch_size, + features, + reader, + reader_args=None, + randomize_input=True, + num_epochs=None, + capacity=10000): + """Reads batches of Examples. + + Example: + + ``` + serialized_examples = [ + features { + feature { key: "age" value { int64_list { value: [ 0 ] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } + feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } } + }, + features { + feature { key: "age" value { int64_list { value: [] } } } + feature { key: "gender" value { bytes_list { value: [ "f" ] } } } + feature { key: "kws" value { bytes_list { value: [ "sports" ] } } } + } + ] + ``` + + We can use arguments: + + ``` + features: { + "age": FixedLenFeature([], dtype=tf.int64, default_value=-1), + "gender": FixedLenFeature([], dtype=tf.string), + "kws": VarLenFeature(dtype=tf.string), + } + ``` + + And the expected output is: + + ```python + { + "age": [[0], [-1]], + "gender": [["f"], ["f"]], + "kws": SparseTensor( + indices=[[0, 0], [0, 1], [1, 0]], + values=["code", "art", "sports"] + dense_shape=[2, 2]), + } + ``` + + Args: + file_pattern: List of files or patterns of file paths containing + `Example` records. See `tf.gfile.Glob` for pattern rules. + batch_size: An int representing the number of consecutive elements of this + dataset to combine in a single batch. + features: A `dict` mapping feature keys to `FixedLenFeature` or + `VarLenFeature` values. See `tf.parse_example`. + reader: A function or class that can be called with a `filenames` tensor + and (optional) `reader_args` and returns a `Dataset` of serialized + Examples. + reader_args: Additional arguments to pass to the reader class. + randomize_input: Whether the input should be randomized. + num_epochs: Integer specifying the number of times to read through the + dataset. If None, cycles through the dataset forever. + capacity: Capacity of the ShuffleDataset. A large capacity ensures better + shuffling but would increase memory usage and startup time. + + Returns: + A dict from keys in features to Tensor or SparseTensor objects. + """ + filenames = _get_file_names(file_pattern, randomize_input) + if reader_args: + dataset = reader(filenames, *reader_args) + else: + dataset = reader(filenames) + if dataset.output_types == (dtypes.string, dtypes.string): + dataset = dataset.map(lambda unused_k, v: v) + elif dataset.output_types != dtypes.string: + raise TypeError("`reader` must be a dataset of `tf.string` values, " + "or `(tf.string, tf.string)` key-value pairs.") + if num_epochs != 1: + dataset = dataset.repeat(num_epochs) + if randomize_input: + dataset = dataset.shuffle(capacity) + dataset = dataset.batch(batch_size) + dataset = dataset.map(lambda x: _parse_example(x, features)) + iterator = dataset.make_one_shot_iterator() + outputs = iterator.get_next() + index = 0 + result = {} + for key in sorted(features.keys()): + feature = features[key] + if isinstance(feature, parsing_ops.FixedLenFeature): + result[key] = outputs[index] + index += 1 + else: + result[key] = sparse_tensor_lib.SparseTensor( + indices=outputs[index], + values=outputs[index + 1], + dense_shape=outputs[index + 2]) + index += 3 + return result + + +def _get_file_names(file_pattern, randomize_input): + """Parse list of file names from pattern, optionally shuffled. + + Args: + file_pattern: File glob pattern, or list of glob patterns. + randomize_input: Whether to shuffle the order of file names. + + Returns: + List of file names matching `file_pattern`. + + Raises: + ValueError: If `file_pattern` is empty, or pattern matches no files. + """ + if isinstance(file_pattern, list): + if not file_pattern: + raise ValueError("File pattern is empty.") + file_names = [] + for entry in file_pattern: + file_names.extend(gfile.Glob(entry)) + else: + file_names = list(gfile.Glob(file_pattern)) + + if not file_names: + raise ValueError("No files match %s." % file_pattern) + + # Sort files so it will be deterministic for unit tests. + if not randomize_input: + file_names = sorted(file_names) + return file_names + + +def _parse_example(serialized, features): + parsed = parsing_ops.parse_example(serialized, features) + result = [] + for key in sorted(features.keys()): + val = parsed[key] + if isinstance(val, sparse_tensor_lib.SparseTensor): + result.extend([val.indices, val.values, val.dense_shape]) + else: + result.append(val) + return tuple(result) + + +class SqlDataset(contrib_dataset_ops.Dataset): + + def __init__(self, driver_name, data_source_name, query, output_types): + dataset = _SqlDataset(driver_name, data_source_name, query, output_types) + super(SqlDataset, self).__init__(dataset) + + +class _SqlDataset(dataset_ops.Dataset): + """A `Dataset` consisting of the results from a SQL query.""" + + def __init__(self, driver_name, data_source_name, query, output_types): + """Creates a `SqlDataset`. + + `SqlDataset` allows a user to read data from the result set of a SQL query. + For example: + + ```python + dataset = tf.contrib.data.SqlDataset("sqlite", "/foo/bar.sqlite3", + "SELECT name, age FROM people", + (tf.string, tf.int32)) + iterator = dataset.make_one_shot_iterator() + next_element = iterator.get_next() + # Prints the rows of the result set of the above query. + while True: + try: + print(sess.run(next_element)) + except tf.errors.OutOfRangeError: + break + ``` + + Args: + driver_name: A 0-D `tf.string` tensor containing the database type. + Currently, the only supported value is 'sqlite'. + data_source_name: A 0-D `tf.string` tensor containing a connection string + to connect to the database. + query: A 0-D `tf.string` tensor containing the SQL query to execute. + output_types: A tuple of `tf.DType` objects representing the types of the + columns returned by `query`. + """ + super(_SqlDataset, self).__init__() + self._driver_name = ops.convert_to_tensor( + driver_name, dtype=dtypes.string, name="driver_name") + self._data_source_name = ops.convert_to_tensor( + data_source_name, dtype=dtypes.string, name="data_source_name") + self._query = ops.convert_to_tensor( + query, dtype=dtypes.string, name="query") + self._output_types = output_types + + def _as_variant_tensor(self): + return gen_dataset_ops.sql_dataset(self._driver_name, + self._data_source_name, self._query, + nest.flatten(self.output_types), + nest.flatten(self.output_shapes)) + + @property + def output_shapes(self): + return nest.map_structure(lambda _: tensor_shape.TensorShape([]), + self._output_types) + + @property + def output_types(self): + return self._output_types diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py new file mode 100644 index 0000000000000000000000000000000000000000..56f526a330bfbea7305b0754bfd114c5e97db506 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/resampling.py @@ -0,0 +1,188 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Resampling dataset transformations.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.data.python.ops import batching +from tensorflow.contrib.data.python.ops import scan_ops +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops + + +def rejection_resample(class_func, target_dist, initial_dist=None, seed=None): + """A transformation that resamples a dataset to achieve a target distribution. + + **NOTE** Resampling is performed via rejection sampling; some fraction + of the input values will be dropped. + + Args: + class_func: A function mapping an element of the input dataset to a scalar + `tf.int32` tensor. Values should be in `[0, num_classes)`. + target_dist: A floating point type tensor, shaped `[num_classes]`. + initial_dist: (Optional.) A floating point type tensor, shaped + `[num_classes]`. If not provided, the true class distribution is + estimated live in a streaming fashion. + seed: (Optional.) Python integer seed for the resampler. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.data.Dataset.apply}. + """ + + def _apply_fn(dataset): + """Function from `Dataset` to `Dataset` that applies the transformation.""" + dist_estimation_batch_size = 32 + target_dist_t = ops.convert_to_tensor(target_dist, name="initial_dist") + class_values_ds = dataset.map(class_func) + if initial_dist is not None: + initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist") + acceptance_dist = _calculate_acceptance_probs(initial_dist_t, + target_dist_t) + initial_dist_ds = dataset_ops.Dataset.from_tensors( + initial_dist_t).repeat() + acceptance_dist_ds = dataset_ops.Dataset.from_tensors( + acceptance_dist).repeat() + else: + num_classes = (target_dist_t.shape[0].value or + array_ops.shape(target_dist_t)[0]) + smoothing_constant = 10 + initial_examples_per_class_seen = array_ops.fill( + [num_classes], np.int64(smoothing_constant)) + + def update_estimate_and_tile(num_examples_per_class_seen, c): + updated_examples_per_class_seen, dist = _estimate_data_distribution( + c, num_examples_per_class_seen) + tiled_dist = array_ops.tile( + array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1]) + return updated_examples_per_class_seen, tiled_dist + + initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size) + .apply(scan_ops.scan(initial_examples_per_class_seen, + update_estimate_and_tile)) + .apply(batching.unbatch())) + acceptance_dist_ds = initial_dist_ds.map( + lambda initial: _calculate_acceptance_probs(initial, target_dist_t)) + + def maybe_warn_on_large_rejection(accept_dist, initial_dist): + proportion_rejected = math_ops.reduce_sum( + (1 - accept_dist) * initial_dist) + return control_flow_ops.cond( + math_ops.less(proportion_rejected, .5), + lambda: accept_dist, + lambda: logging_ops.Print( # pylint: disable=g-long-lambda + accept_dist, [proportion_rejected, initial_dist, accept_dist], + message="Proportion of examples rejected by sampler is high: ", + summarize=100, + first_n=10)) + + acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds, + initial_dist_ds)) + .map(maybe_warn_on_large_rejection)) + + current_probabilities_ds = dataset_ops.Dataset.zip( + (acceptance_dist_ds, class_values_ds)).map(array_ops.gather) + filtered_ds = ( + dataset_ops.Dataset.zip((class_values_ds, current_probabilities_ds, + dataset)) + .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p)) + return filtered_ds.map(lambda class_value, _, data: (class_value, data)) + + return _apply_fn + + +def _calculate_acceptance_probs(initial_probs, target_probs): + """Calculate the per-class acceptance rates. + + Args: + initial_probs: The class probabilities of the data. + target_probs: The desired class proportion in minibatches. + Returns: + A list of the per-class acceptance probabilities. + + This method is based on solving the following analysis: + + Let F be the probability of a rejection (on any example). + Let p_i be the proportion of examples in the data in class i (init_probs) + Let a_i is the rate the rejection sampler should *accept* class i + Let t_i is the target proportion in the minibatches for class i (target_probs) + + ``` + F = sum_i(p_i * (1-a_i)) + = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1 + ``` + + An example with class `i` will be accepted if `k` rejections occur, then an + example with class `i` is seen by the rejector, and it is accepted. This can + be written as follows: + + ``` + t_i = sum_k=0^inf(F^k * p_i * a_i) + = p_i * a_j / (1 - F) using geometric series identity, since 0 <= F < 1 + = p_i * a_i / sum_j(p_j * a_j) using F from above + ``` + + Note that the following constraints hold: + ``` + 0 <= p_i <= 1, sum_i(p_i) = 1 + 0 <= a_i <= 1 + 0 <= t_i <= 1, sum_i(t_i) = 1 + ``` + + + A solution for a_i in terms of the other variabes is the following: + ```a_i = (t_i / p_i) / max_i[t_i / p_i]``` + """ + # Add tiny to initial_probs to avoid divide by zero. + denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny) + ratio_l = target_probs / denom + + # Calculate list of acceptance probabilities. + max_ratio = math_ops.reduce_max(ratio_l) + return ratio_l / max_ratio + + +def _estimate_data_distribution(c, num_examples_per_class_seen): + """Estimate data distribution as labels are seen. + + Args: + c: The class labels. Type `int32`, shape `[batch_size]`. + num_examples_per_class_seen: Type `int64`, shape `[num_classes]`, + containing counts. + + Returns: + num_examples_per_lass_seen: Updated counts. Type `int64`, shape + `[num_classes]`. + dist: The updated distribution. Type `float32`, shape `[num_classes]`. + """ + num_classes = num_examples_per_class_seen.get_shape()[0].value + # Update the class-count based on what labels are seen in batch. + num_examples_per_class_seen = math_ops.add( + num_examples_per_class_seen, math_ops.reduce_sum( + array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0)) + init_prob_estimate = math_ops.truediv( + num_examples_per_class_seen, + math_ops.reduce_sum(num_examples_per_class_seen)) + dist = math_ops.cast(init_prob_estimate, dtypes.float32) + return num_examples_per_class_seen, dist diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..5acaed48a3d73e93706bdd0b5b2d614b0c565ab7 --- /dev/null +++ b/tensorflow/contrib/data/python/ops/scan_ops.py @@ -0,0 +1,182 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Scan dataset transformation.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.util import nest +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import gen_dataset_ops + + +class _ScanDataset(dataset_ops.Dataset): + """A dataset that scans a function across its input.""" + + def __init__(self, input_dataset, initial_state, scan_func): + """See `scan()` for details.""" + super(_ScanDataset, self).__init__() + self._input_dataset = input_dataset + + with ops.name_scope("initial_state"): + self._initial_state = nest.pack_sequence_as(initial_state, [ + ops.convert_to_tensor(t, name="component_%d" % i) + for i, t in enumerate(nest.flatten(initial_state)) + ]) + + # Compute initial values for the state shapes and types based on + # the initial state. These will be refined by running + # `tf_scan_func` one or more times below. + self._state_shapes = nest.pack_sequence_as( + self._initial_state, + [t.shape for t in nest.flatten(self._initial_state)]) + self._state_types = nest.pack_sequence_as( + self._initial_state, + [t.dtype for t in nest.flatten(self._initial_state)]) + + # Will be populated by calling `tf_scan_func`. + self._output_shapes = None + self._output_types = None + + # Iteratively rerun the scan function until reaching a fixed pont on + # `self._state_shapes`. + need_to_rerun = True + while need_to_rerun: + + flat_state_shapes = nest.flatten(self._state_shapes) + flat_state_types = nest.flatten(self._state_types) + + # Create a list in which `tf_scan_func` will store the s + flat_new_state_shapes = [] + + @function.Defun( + *(flat_state_types + nest.flatten(input_dataset.output_types))) + def tf_scan_func(*args): + """A wrapper for Defun that facilitates shape inference.""" + # Pass in shape information from the state and input_dataset. + for arg, shape in zip( + args, + flat_state_shapes + nest.flatten(input_dataset.output_shapes)): + arg.set_shape(shape) + + pivot = len(flat_state_shapes) + old_state = nest.pack_sequence_as(self._initial_state, args[:pivot]) + input_value = nest.pack_sequence_as(input_dataset.output_types, + args[pivot:]) + + ret = scan_func(old_state, input_value) + if not isinstance(ret, collections.Sequence) or len(ret) != 2: + raise TypeError("The scan function must return a pair comprising the " + "new state and the output value.") + new_state, output_value = ret + + flat_new_state = [ + ops.convert_to_tensor(t) for t in nest.flatten(new_state) + ] + flat_output_value = [ + ops.convert_to_tensor(t) for t in nest.flatten(output_value) + ] + + # Extract shape information from the returned values. + flat_new_state_shapes.extend([t.shape for t in flat_new_state]) + self._output_shapes = nest.pack_sequence_as( + output_value, [t.shape for t in flat_output_value]) + + # Extract and validate type information from the returned values. + for t, dtype in zip(flat_new_state, flat_state_types): + if t.dtype != dtype: + raise TypeError( + "The element types for the new state must match the initial " + "state. Expected %s; got %s." % + (self._state_types, nest.pack_sequence_as( + self._state_types, [t.dtype for t in flat_new_state]))) + self._output_types = nest.pack_sequence_as( + output_value, [t.dtype for t in flat_output_value]) + + return flat_new_state + flat_output_value + + # Use the private method that will execute `tf_scan_func` but delay + # adding it to the graph in case we need to rerun the function. + tf_scan_func._create_definition_if_needed() # pylint: disable=protected-access + + weakened_state_shapes = [ + original.most_specific_compatible_shape(new) + for original, new in zip(flat_state_shapes, flat_new_state_shapes) + ] + + need_to_rerun = False + for original_shape, weakened_shape in zip(flat_state_shapes, + weakened_state_shapes): + if original_shape.ndims is not None and ( + weakened_shape.ndims is None or + original_shape.as_list() != weakened_shape.as_list()): + need_to_rerun = True + break + + if need_to_rerun: + # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun + # `tf_scan_func`. + self._state_shapes = nest.pack_sequence_as(self._state_shapes, + weakened_state_shapes) + + self._scan_func = tf_scan_func + + def _as_variant_tensor(self): + input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access + return gen_dataset_ops.scan_dataset( + input_t, + nest.flatten(self._initial_state), + self._scan_func.captured_inputs, + f=self._scan_func, + output_types=nest.flatten(self.output_types), + output_shapes=nest.flatten(self.output_shapes)) + + @property + def output_shapes(self): + return self._output_shapes + + @property + def output_types(self): + return self._output_types + + +def scan(initial_state, scan_func): + """A transformation that scans a function across an input dataset. + + This transformation is a stateful relative of @{tf.data.Dataset.map}. + In addition to mapping `scan_func` across the elements of the input dataset, + `scan()` accumulates one or more state tensors, whose initial values are + `initial_state`. + + Args: + initial_state: A nested structure of tensors, representing the initial state + of the accumulator. + scan_func: A function that maps `(old_state, input_element)` to + `(new_state, output_element). It must take two arguments and return a + pair of nested structures of tensors. The `new_state` must match the + structure of `initial_state`. + + Returns: + A `Dataset` transformation function, which can be passed to + @{tf.contrib.data.Dataset.apply}. + """ + def _apply_fn(dataset): + return _ScanDataset(dataset, initial_state, scan_func) + + return _apply_fn diff --git a/tensorflow/contrib/deprecated/__init__.py b/tensorflow/contrib/deprecated/__init__.py index 79d7eb6f6f53fc1c5478f9f71fb558118cb616d3..7aff045de30ef870948f598fd53199a36e1edfca 100644 --- a/tensorflow/contrib/deprecated/__init__.py +++ b/tensorflow/contrib/deprecated/__init__.py @@ -35,14 +35,14 @@ generated protobufs. Previously, the tag was allowed to be any unique string; it had no relation to the summary op generating it, and no relation to the TensorFlow name system. -This behavior made it very difficult to write reusable that would add +This behavior made it very difficult to write reusable that would add summary ops to the graph. If you had a function to add summary ops, you would -need to pass in a `tf.name_scope`, manually, to that function to create deduplicated -tags. Otherwise your program would fail with a runtime error due to tag -collision. +need to pass in a `tf.name_scope`, manually, to that function to create +deduplicated tags. Otherwise your program would fail with a runtime error due +to tag collision. The new summary APIs under `tf.summary` throw away the "tag" as an independent -concept; instead, the first argument is the node name. So summary tags now +concept; instead, the first argument is the node name. So summary tags now automatically inherit the surrounding `tf.name_scope`, and automatically are deduplicated if there is a conflict. Now however, the only allowed characters are alphanumerics, underscores, and forward slashes. To make @@ -98,9 +98,10 @@ from tensorflow.python.ops.logging_ops import image_summary from tensorflow.python.ops.logging_ops import merge_all_summaries from tensorflow.python.ops.logging_ops import merge_summary from tensorflow.python.ops.logging_ops import scalar_summary -# pylint: enable=unused-import from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long + _allowed_symbols = ['audio_summary', 'histogram_summary', 'image_summary', 'merge_all_summaries', 'merge_summary', 'scalar_summary'] diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 6d326a1c2fff394076a329f1272dc8595ba771e1..145b9495ff40f8095b50d00e576333fdf5d7acdf 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -2,12 +2,15 @@ # Contains ops for statistical distributions (with pdf, cdf, sample, etc...). # APIs here are meant to evolve over time. +package(default_visibility = [ + "//learning/brain/contrib/bayesflow:__subpackages__", + "//tensorflow:__subpackages__", +]) + licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) -package(default_visibility = ["//tensorflow:__subpackages__"]) - load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( @@ -18,14 +21,20 @@ py_library( "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", + "//tensorflow/python:clip_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", + "//tensorflow/python:template", "//tensorflow/python:tensor_util", "//tensorflow/python:util", + "//tensorflow/python:variable_scope", "//tensorflow/python/ops/distributions", + "//tensorflow/python/ops/linalg", "//third_party/py/numpy", ], ) @@ -55,7 +64,9 @@ py_library( "//tensorflow/python:tensor_util", "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", "//tensorflow/python/ops/distributions", + "//tensorflow/python/ops/linalg", "//third_party/py/numpy", "@six_archive//:six", ], @@ -129,6 +140,23 @@ cuda_py_test( ], ) +cuda_py_test( + name = "cauchy_test", + size = "medium", + srcs = ["python/kernel_tests/cauchy_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + cuda_py_test( name = "chi2_test", srcs = ["python/kernel_tests/chi2_test.py"], @@ -298,6 +326,19 @@ cuda_py_test( ], ) +cuda_py_test( + name = "mixture_same_family_test", + size = "small", + srcs = ["python/kernel_tests/mixture_same_family_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:client_testlib", + ], +) + cuda_py_test( name = "negative_binomial_test", size = "small", @@ -339,6 +380,34 @@ cuda_py_test( ], ) +cuda_py_test( + name = "sinh_arcsinh_test", + size = "small", + srcs = ["python/kernel_tests/sinh_arcsinh_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( + name = "independent_test", + size = "small", + srcs = ["python/kernel_tests/independent_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "sample_stats_test", size = "medium", @@ -357,6 +426,20 @@ cuda_py_test( tags = ["nomsan"], # disable to avoid false positives from scipy. ) +cuda_py_test( + name = "vector_sinh_arcsinh_diag_test", + size = "medium", + srcs = ["python/kernel_tests/vector_sinh_arcsinh_diag_test.py"], + additional_deps = [ + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "vector_exponential_diag_test", size = "medium", @@ -627,6 +710,24 @@ cuda_py_test( ], ) +cuda_py_test( + name = "absolute_value_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/absolute_value_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "affine_test", size = "large", @@ -724,6 +825,25 @@ cuda_py_test( ], ) +cuda_py_test( + name = "gumbel_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/gumbel_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "@six_archive//:six", + "//tensorflow/contrib/linalg:linalg_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "inline_test", size = "small", @@ -762,6 +882,38 @@ cuda_py_test( ], ) +cuda_py_test( + name = "masked_autoregressive_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/masked_autoregressive_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( + name = "permute_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/permute_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "power_transform_test", size = "small", @@ -781,6 +933,22 @@ cuda_py_test( ], ) +cuda_py_test( + name = "reshape_test", + size = "small", + srcs = ["python/kernel_tests/bijectors/reshape_test.py"], + additional_deps = [ + ":bijectors_py", + ":distributions_py", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], +) + cuda_py_test( name = "sigmoid_test", size = "small", @@ -819,10 +987,12 @@ cuda_py_test( ], ) +# Tests for SinhArcSinh bijector. The file name has the extra "_bijector" to +# avoid BUILD rule name conflicts with the distribution by the same name. cuda_py_test( - name = "sinh_arcsinh_test", + name = "sinh_arcsinh_bijector_test", size = "small", - srcs = ["python/kernel_tests/bijectors/sinh_arcsinh_test.py"], + srcs = ["python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py"], additional_deps = [ ":bijectors_py", ":distributions_py", diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py index ed2a137429768ad4d23c60fd42e2ea45f1b20269..0d12d838932e3a46e07f4a4242b889296c6e13c4 100644 --- a/tensorflow/contrib/distributions/__init__.py +++ b/tensorflow/contrib/distributions/__init__.py @@ -24,17 +24,23 @@ from __future__ import print_function from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops.binomial import * +from tensorflow.contrib.distributions.python.ops.cauchy import * from tensorflow.contrib.distributions.python.ops.chi2 import * from tensorflow.contrib.distributions.python.ops.conditional_distribution import * from tensorflow.contrib.distributions.python.ops.conditional_transformed_distribution import * from tensorflow.contrib.distributions.python.ops.deterministic import * +from tensorflow.contrib.distributions.python.ops.distribution_util import fill_triangular from tensorflow.contrib.distributions.python.ops.distribution_util import matrix_diag_transform +from tensorflow.contrib.distributions.python.ops.distribution_util import reduce_weighted_logsumexp from tensorflow.contrib.distributions.python.ops.distribution_util import softplus_inverse +from tensorflow.contrib.distributions.python.ops.distribution_util import tridiag from tensorflow.contrib.distributions.python.ops.estimator import * from tensorflow.contrib.distributions.python.ops.geometric import * +from tensorflow.contrib.distributions.python.ops.independent import * from tensorflow.contrib.distributions.python.ops.inverse_gamma import * from tensorflow.contrib.distributions.python.ops.logistic import * from tensorflow.contrib.distributions.python.ops.mixture import * +from tensorflow.contrib.distributions.python.ops.mixture_same_family import * from tensorflow.contrib.distributions.python.ops.moving_stats import * from tensorflow.contrib.distributions.python.ops.mvn_diag import * from tensorflow.contrib.distributions.python.ops.mvn_diag_plus_low_rank import * @@ -49,10 +55,12 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import * from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import * from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import * from tensorflow.contrib.distributions.python.ops.sample_stats import * +from tensorflow.contrib.distributions.python.ops.sinh_arcsinh import * from tensorflow.contrib.distributions.python.ops.test_util import * from tensorflow.contrib.distributions.python.ops.vector_diffeomixture import * from tensorflow.contrib.distributions.python.ops.vector_exponential_diag import * from tensorflow.contrib.distributions.python.ops.vector_laplace_diag import * +from tensorflow.contrib.distributions.python.ops.vector_sinh_arcsinh_diag import * from tensorflow.contrib.distributions.python.ops.wishart import * from tensorflow.python.ops.distributions.bernoulli import * from tensorflow.python.ops.distributions.beta import * @@ -76,23 +84,11 @@ from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ 'bijectors', + 'Cauchy', 'ConditionalDistribution', 'ConditionalTransformedDistribution', 'FULLY_REPARAMETERIZED', 'NOT_REPARAMETERIZED', - 'Affine', - 'AffineLinearOperator', - 'Bijector', - 'Chain', - 'CholeskyOuterProduct', - 'Exp', - 'Identity', - 'Inline', - 'Invert', - 'PowerTransform', - 'SigmoidCentered', - 'SoftmaxCentered', - 'Softplus', 'ReparameterizationType', 'Distribution', 'Binomial', @@ -111,6 +107,7 @@ _allowed_symbols = [ 'Gamma', 'GammaWithSoftplusConcentrationRate', 'Geometric', + 'Independent', 'InverseGamma', 'InverseGammaWithSoftplusConcentrationRate', 'Laplace', @@ -121,6 +118,7 @@ _allowed_symbols = [ 'NormalWithSoftplusScale', 'Poisson', 'PoissonLogNormalQuadratureCompound', + 'SinhArcsinh', 'StudentT', 'StudentTWithAbsDfSoftplusScale', 'Uniform', @@ -134,24 +132,27 @@ _allowed_symbols = [ 'Multinomial', 'VectorDiffeomixture', 'VectorLaplaceDiag', + 'VectorSinhArcsinhDiag', 'WishartCholesky', 'WishartFull', 'TransformedDistribution', 'QuantizedDistribution', 'Mixture', + 'MixtureSameFamily', 'ExpRelaxedOneHotCategorical', 'OneHotCategorical', 'RelaxedBernoulli', 'RelaxedOneHotCategorical', 'kl_divergence', 'RegisterKL', - 'matrix_diag_transform', 'fill_triangular', + 'matrix_diag_transform', + 'reduce_weighted_logsumexp', + 'softplus_inverse', + 'tridiag', 'normal_conjugates_known_scale_posterior', 'normal_conjugates_known_scale_predictive', - 'softplus_inverse', 'percentile', - 'reduce_weighted_logsumexp', 'assign_moving_mean_variance', 'assign_log_moving_mean_exp', 'moving_mean_variance', diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e0d65c79b2654c2949de161d6317f218d11cab43 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/absolute_value_test.py @@ -0,0 +1,85 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for AbsoluteValue Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +# pylint: disable=g-importing-member +from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value import AbsoluteValue +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + +# pylint: enable=g-importing-member + + +class AbsoluteValueTest(test.TestCase): + """Tests correctness of the absolute value bijector.""" + + def testBijectorVersusNumpyRewriteOfBasicFunctionsEventNdims0(self): + with self.test_session() as sess: + bijector = AbsoluteValue(event_ndims=0, validate_args=True) + self.assertEqual("absolute_value", bijector.name) + x = array_ops.constant([[0., 1., -1], [0., -5., 3.]]) # Shape [2, 3] + y = math_ops.abs(x) + + y_ = y.eval() + zeros = np.zeros((2, 3)) + + self.assertAllClose(y_, bijector.forward(x).eval()) + self.assertAllClose((-y_, y_), sess.run(bijector.inverse(y))) + self.assertAllClose((zeros, zeros), + sess.run(bijector.inverse_log_det_jacobian(y))) + + # Run things twice to make sure there are no issues in caching the tuples + # returned by .inverse* + self.assertAllClose(y_, bijector.forward(x).eval()) + self.assertAllClose((-y_, y_), sess.run(bijector.inverse(y))) + self.assertAllClose((zeros, zeros), + sess.run(bijector.inverse_log_det_jacobian(y))) + + def testEventNdimsMustBeZeroOrRaiseStatic(self): + with self.test_session(): + with self.assertRaisesRegexp(ValueError, "event_ndims.*was not 0"): + AbsoluteValue(event_ndims=1) + + def testEventNdimsMustBeZeroOrRaiseDynamic(self): + with self.test_session() as sess: + event_ndims = array_ops.placeholder(dtypes.int32) + abs_bijector = AbsoluteValue(event_ndims=event_ndims, validate_args=True) + with self.assertRaisesOpError("event_ndims was not 0"): + sess.run(abs_bijector.inverse_log_det_jacobian([1.]), + feed_dict={event_ndims: 1}) + + def testNegativeYRaisesForInverseIfValidateArgs(self): + with self.test_session() as sess: + bijector = AbsoluteValue(event_ndims=0, validate_args=True) + with self.assertRaisesOpError("y was negative"): + sess.run(bijector.inverse(-1.)) + + def testNegativeYRaisesForILDJIfValidateArgs(self): + with self.test_session() as sess: + bijector = AbsoluteValue(event_ndims=0, validate_args=True) + with self.assertRaisesOpError("y was negative"): + sess.run(bijector.inverse_log_det_jacobian(-1.)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py index 0738754b217e5842bd0fa516915f14926083d321..405ddd292cacd8ace87d6caeebf3e8cfc347c22d 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_linear_operator_test.py @@ -72,7 +72,7 @@ class AffineLinearOperatorTest(test.TestCase): [3, -2, 0], [4, 3, 2]]], dtype=np.float32) - scale = linalg.LinearOperatorTriL(tril, is_non_singular=True) + scale = linalg.LinearOperatorLowerTriangular(tril, is_non_singular=True) affine = AffineLinearOperator( shift=shift, scale=scale, validate_args=True) diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py index 2c4b8277d01c7a2929fdde7babf809f2c16f730b..c9158117f7a982e37047e8dd2b534a30040a87d9 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/affine_test.py @@ -76,7 +76,7 @@ class AffineBijectorTest(test.TestCase): for run in (static_run, dynamic_run): mu = -1. # Corresponds to scale = 2 - bijector = Affine(shift=mu, scale_diag=[2.], event_ndims=0) + bijector = Affine(shift=mu, scale_identity_multiplier=2., event_ndims=0) self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" x = [1., 2, 3] # Three scalar samples (no batches). self.assertAllClose([1., 3, 5], run(bijector.forward, x)) @@ -84,7 +84,7 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose(-np.log(2.), run(bijector.inverse_log_det_jacobian, x)) - def testWeirdSampleNoBatchScalarViaIdentity(self): + def testWeirdSampleNoBatchScalarViaDiagMultiplier(self): with self.test_session() as sess: def static_run(fun, x): @@ -156,7 +156,7 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([np.log(0.5)], run(bijector.inverse_log_det_jacobian, x)) - def testOneBatchScalarViaDiag(self): + def testOneBatchScalarViaDiagMultiplier(self): with self.test_session() as sess: def static_run(fun, x): @@ -171,7 +171,7 @@ class AffineBijectorTest(test.TestCase): mu = [1.] # One batch, scalar. # Corresponds to scale = 1. - bijector = Affine(shift=mu, scale_diag=[1.], event_ndims=0) + bijector = Affine(shift=mu, scale_identity_multiplier=1., event_ndims=0) self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" x = [1.] # One sample from one batches. self.assertAllClose([2.], run(bijector.forward, x)) @@ -200,7 +200,7 @@ class AffineBijectorTest(test.TestCase): self.assertAllClose([0., 2], run(bijector.inverse, x)) self.assertAllClose(0., run(bijector.inverse_log_det_jacobian, x)) - def testTwoBatchScalarIdentityViaDiag(self): + def testTwoBatchScalarIdentityViaDiagMultiplier(self): with self.test_session() as sess: def static_run(fun, x): @@ -215,7 +215,7 @@ class AffineBijectorTest(test.TestCase): mu = [1., -1] # Univariate, two batches. # Corresponds to scale = 1. - bijector = Affine(shift=mu, scale_diag=[1.], event_ndims=0) + bijector = Affine(shift=mu, scale_identity_multiplier=1., event_ndims=0) self.assertEqual(0, bijector.event_ndims.eval()) # "is scalar" x = [1., 1] # One sample from each of two batches. self.assertAllClose([2., 0], run(bijector.forward, x)) @@ -410,13 +410,13 @@ class AffineBijectorTest(test.TestCase): bijector = Affine( shift=mu, scale_identity_multiplier=1., - scale_diag=[1.], - event_ndims=0) - self.assertEqual(0, bijector.event_ndims.eval()) # "is vector" + scale_diag=[1., 1., 1.], + event_ndims=1) + self.assertEqual(1, bijector.event_ndims.eval()) # "is vector" x = [1., 2, 3] # Three scalar samples (no batches). self.assertAllClose([1., 3, 5], run(bijector.forward, x)) self.assertAllClose([1., 1.5, 2.], run(bijector.inverse, x)) - self.assertAllClose(-np.log(2.), + self.assertAllClose(-np.log(2.**3), run(bijector.inverse_log_det_jacobian, x)) def testIdentityWithTriL(self): @@ -668,11 +668,10 @@ class AffineBijectorTest(test.TestCase): with self.assertRaisesOpError("identity_multiplier should be non-zero"): bijector.forward(1.).eval() - # Check Diag matrix with zero scaling. - bijector = Affine( - shift=mu, scale_diag=[0.0], event_ndims=0, validate_args=True) - with self.assertRaisesOpError("diagonal part must be non-zero"): - bijector.forward(1.).eval() + def testScaleDiagAndEventNdimsZeroRaises(self): + # Check Diag matrix with zero scaling. + with self.assertRaisesRegexp(ValueError, "only scale argument"): + Affine(shift=None, scale_diag=[0.0], event_ndims=0, validate_args=True) def testScalarCongruency(self): with self.test_session(): @@ -830,6 +829,15 @@ class AffineBijectorTest(test.TestCase): x=np.array( [1., 2], dtype=np.float32)) + def testScalarEventIdentityScale(self): + with self.test_session() as sess: + doubler = Affine( + scale_identity_multiplier=2., + event_ndims=0) + doubler2 = doubler.inverse_log_det_jacobian(2.) + doubler2_ildj_ = sess.run([doubler2]) + self.assertAllClose([-np.log(2.)], doubler2_ildj_) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9a905980c7581a86bbcda8c6c726da57c09fe4f8 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/gumbel_test.py @@ -0,0 +1,70 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from scipy import stats + +from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import Gumbel +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.ops.distributions.bijector_test_util import assert_scalar_congruency +from tensorflow.python.platform import test + + +class GumbelBijectorTest(test.TestCase): + """Tests correctness of the Gumbel bijector.""" + + def testBijector(self): + with self.test_session(): + loc = 0.3 + scale = 5. + bijector = Gumbel(loc=loc, scale=scale, event_ndims=1, validate_args=True) + self.assertEqual("gumbel", bijector.name) + x = np.array([[[-3.], [0.], [0.5], [4.2], [12.]]], dtype=np.float32) + # Gumbel distribution + gumbel_dist = stats.gumbel_r(loc=loc, scale=scale) + y = gumbel_dist.cdf(x).astype(np.float32) + self.assertAllClose(y, bijector.forward(x).eval()) + self.assertAllClose(x, bijector.inverse(y).eval()) + self.assertAllClose( + # We should lose a dimension from calculating the determinant of the + # jacobian. + np.squeeze(gumbel_dist.logpdf(x), axis=2), + bijector.forward_log_det_jacobian(x).eval()) + self.assertAllClose( + -bijector.inverse_log_det_jacobian(y).eval(), + bijector.forward_log_det_jacobian(x).eval(), + rtol=1e-4, + atol=0.) + + def testScalarCongruency(self): + with self.test_session(): + assert_scalar_congruency( + Gumbel(loc=0.3, scale=20.), lower_x=1., upper_x=100., rtol=0.02) + + def testBijectiveAndFinite(self): + with self.test_session(): + bijector = Gumbel(loc=0., scale=3.0, event_ndims=0, validate_args=True) + x = np.linspace(-10., 10., num=10).astype(np.float32) + y = np.linspace(0.01, 0.99, num=10).astype(np.float32) + assert_bijective_and_finite(bijector, x, y, rtol=1e-3) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py new file mode 100644 index 0000000000000000000000000000000000000000..25a9b6f5fe2ed6d218d6b44650fce17fa89c0664 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/masked_autoregressive_test.py @@ -0,0 +1,153 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MaskedAutoregressiveFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.contrib.distributions.python.ops.bijectors.invert import Invert +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import masked_autoregressive_default_template +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import MaskedAutoregressiveFlow +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive_impl import _gen_mask +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import variables +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.distributions import transformed_distribution as transformed_distribution_lib +from tensorflow.python.platform import test + + +class GenMaskTest(test.TestCase): + + def test346Exclusive(self): + expected_mask = np.array( + [[0, 0, 0, 0], + [0, 0, 0, 0], + [1, 0, 0, 0], + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 0, 0]]) + mask = _gen_mask(num_blocks=3, n_in=4, n_out=6, mask_type="exclusive") + self.assertAllEqual(expected_mask, mask) + + def test346Inclusive(self): + expected_mask = np.array( + [[1, 0, 0, 0], + [1, 0, 0, 0], + [1, 1, 0, 0], + [1, 1, 0, 0], + [1, 1, 1, 0], + [1, 1, 1, 0]]) + mask = _gen_mask(num_blocks=3, n_in=4, n_out=6, mask_type="inclusive") + self.assertAllEqual(expected_mask, mask) + + +class MaskedAutoregressiveFlowTest(test_util.VectorDistributionTestHelpers, + test.TestCase): + + @property + def _autoregressive_flow_kwargs(self): + return { + "shift_and_log_scale_fn": masked_autoregressive_default_template( + hidden_layers=[2], shift_only=False), + "is_constant_jacobian": False, + } + + def testBijector(self): + x_ = np.arange(3 * 4 * 2).astype(np.float32).reshape(3, 4, 2) + with self.test_session() as sess: + ma = MaskedAutoregressiveFlow( + validate_args=True, + **self._autoregressive_flow_kwargs) + x = constant_op.constant(x_) + forward_x = ma.forward(x) + # Use identity to invalidate cache. + inverse_y = ma.inverse(array_ops.identity(forward_x)) + fldj = ma.forward_log_det_jacobian(x) + # Use identity to invalidate cache. + ildj = ma.inverse_log_det_jacobian(array_ops.identity(forward_x)) + variables.global_variables_initializer().run() + [ + forward_x_, + inverse_y_, + ildj_, + fldj_, + ] = sess.run([ + forward_x, + inverse_y, + ildj, + fldj, + ]) + self.assertEqual("masked_autoregressive_flow", ma.name) + self.assertAllClose(forward_x_, forward_x_, rtol=1e-6, atol=0.) + self.assertAllClose(x_, inverse_y_, rtol=1e-5, atol=0.) + self.assertAllClose(ildj_, -fldj_, rtol=1e-6, atol=0.) + + def testMutuallyConsistent(self): + dims = 4 + with self.test_session() as sess: + ma = MaskedAutoregressiveFlow( + validate_args=True, + **self._autoregressive_flow_kwargs) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=0., scale=1.), + bijector=ma, + event_shape=[dims], + validate_args=True) + self.run_test_sample_consistent_log_prob( + sess_run_fn=sess.run, + dist=dist, + num_samples=int(1e5), + radius=1., + center=0., + rtol=0.02) + + def testInvertMutuallyConsistent(self): + dims = 4 + with self.test_session() as sess: + ma = Invert(MaskedAutoregressiveFlow( + validate_args=True, + **self._autoregressive_flow_kwargs)) + dist = transformed_distribution_lib.TransformedDistribution( + distribution=normal_lib.Normal(loc=0., scale=1.), + bijector=ma, + event_shape=[dims], + validate_args=True) + self.run_test_sample_consistent_log_prob( + sess_run_fn=sess.run, + dist=dist, + num_samples=int(1e5), + radius=1., + center=0., + rtol=0.02) + + +class MaskedAutoregressiveFlowShiftOnlyTest(MaskedAutoregressiveFlowTest): + + @property + def _autoregressive_flow_kwargs(self): + return { + "shift_and_log_scale_fn": masked_autoregressive_default_template( + hidden_layers=[2], shift_only=True), + "is_constant_jacobian": True, + } + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py new file mode 100644 index 0000000000000000000000000000000000000000..54590de373441c32cc3214cb04d45cfc2d1807ed --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/permute_test.py @@ -0,0 +1,87 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Permute bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.permute import Permute +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.platform import test + + +class PermuteBijectorTest(test.TestCase): + """Tests correctness of the Permute bijector.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testBijector(self): + expected_permutation = np.int32([2, 0, 1]) + expected_x = np.random.randn(4, 2, 3) + expected_y = expected_x[..., expected_permutation] + + with self.test_session() as sess: + permutation_ph = array_ops.placeholder(dtype=dtypes.int32) + bijector = Permute( + permutation=permutation_ph, + validate_args=True) + [ + permutation_, + x_, + y_, + fldj, + ildj, + ] = sess.run([ + bijector.permutation, + bijector.inverse(expected_y), + bijector.forward(expected_x), + bijector.forward_log_det_jacobian(expected_x), + bijector.inverse_log_det_jacobian(expected_y), + ], feed_dict={permutation_ph: expected_permutation}) + self.assertEqual("permute", bijector.name) + self.assertAllEqual(expected_permutation, permutation_) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + self.assertAllClose(0., fldj, rtol=1e-6, atol=0) + self.assertAllClose(0., ildj, rtol=1e-6, atol=0) + + def testRaisesOpError(self): + with self.test_session() as sess: + with self.assertRaisesOpError("Permutation over `d` must contain"): + permutation_ph = array_ops.placeholder(dtype=dtypes.int32) + bijector = Permute( + permutation=permutation_ph, + validate_args=True) + sess.run(bijector.inverse([1.]), + feed_dict={permutation_ph: [1, 2]}) + + def testBijectiveAndFinite(self): + permutation = np.int32([2, 0, 1]) + x = np.random.randn(4, 2, 3) + y = x[..., permutation] + with self.test_session(): + bijector = Permute( + permutation=permutation, + validate_args=True) + assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py new file mode 100644 index 0000000000000000000000000000000000000000..38b3a23c2d684a6f89b7c4be4a763c649bf4de15 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py @@ -0,0 +1,242 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Reshape Bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops.bijectors.reshape import Reshape +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops.distributions.bijector_test_util import assert_bijective_and_finite +from tensorflow.python.platform import test + + +class ReshapeBijectorTest(test.TestCase): + """Tests correctness of the reshape transformation.""" + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testBijector(self): + """Do a basic sanity check of forward, inverse, jacobian.""" + expected_x = np.random.randn(4, 3, 2) + expected_y = np.reshape(expected_x, [4, 6]) + + with self.test_session() as sess: + bijector = Reshape( + event_shape_out=[6,], + event_shape_in=[3, 2], + validate_args=True) + (x_, + y_, + fldj_, + ildj_) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + bijector.forward_log_det_jacobian(expected_x), + bijector.inverse_log_det_jacobian(expected_y), + )) + self.assertEqual("reshape", bijector.name) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + self.assertAllClose(0., fldj_, rtol=1e-6, atol=0) + self.assertAllClose(0., ildj_, rtol=1e-6, atol=0) + + def testEventShapeDynamicNdims(self): + """Check forward/inverse shape methods with dynamic ndims.""" + + shape_in = tensor_shape.TensorShape([6,]) + shape_in_ph = array_ops.placeholder(dtype=dtypes.int32) + + shape_out = tensor_shape.TensorShape([2, 3]) + shape_out_ph = array_ops.placeholder(dtype=dtypes.int32) + + bijector = Reshape( + event_shape_out=shape_out_ph, + event_shape_in=shape_in_ph, validate_args=True) + + # using the _tensor methods, we should always get a fully-specified + # result since these are evaluated at graph runtime. + with self.test_session() as sess: + (shape_out_, + shape_in_) = sess.run(( + bijector.forward_event_shape_tensor(shape_in), + bijector.inverse_event_shape_tensor(shape_out), + ), feed_dict={ + shape_in_ph: shape_in, + shape_out_ph: shape_out, + }) + self.assertAllEqual(shape_out, shape_out_) + self.assertAllEqual(shape_in, shape_in_) + + def testEventShapeDynamic(self): + """Check shape methods with static ndims but dynamic shape.""" + + shape_in = tensor_shape.TensorShape([6,]) + shape_in_partial = tensor_shape.TensorShape([None,]) + shape_in_ph = array_ops.placeholder( + shape=[1,], dtype=dtypes.int32) + + shape_out = tensor_shape.TensorShape([2, 3]) + shape_out_partial = tensor_shape.TensorShape([None, None]) + shape_out_ph = array_ops.placeholder( + shape=[2,], dtype=dtypes.int32) + + bijector = Reshape( + event_shape_out=shape_out_ph, + event_shape_in=shape_in_ph, + validate_args=True) + + # if event shapes are not statically available, should + # return partially-specified TensorShapes. + self.assertAllEqual( + bijector.forward_event_shape(shape_in).as_list(), + shape_out_partial.as_list()) + self.assertAllEqual( + bijector.inverse_event_shape(shape_out).as_list(), + shape_in_partial.as_list()) + + # using the _tensor methods, we should always get a fully-specified + # result since these are evaluated at graph runtime. + with self.test_session() as sess: + (shape_out_, + shape_in_) = sess.run(( + bijector.forward_event_shape_tensor(shape_in), + bijector.inverse_event_shape_tensor(shape_out), + ), feed_dict={ + shape_in_ph: shape_in, + shape_out_ph: shape_out, + }) + self.assertAllEqual(shape_out, shape_out_) + self.assertAllEqual(shape_in, shape_in_) + + def testEventShapeStatic(self): + """Check shape methods when shape is statically known.""" + + shape_in = tensor_shape.TensorShape([6,]) + shape_out = tensor_shape.TensorShape([2, 3]) + + bijector_static = Reshape( + event_shape_out=shape_out, + event_shape_in=shape_in, + validate_args=True) + + # test that forward_ and inverse_event_shape do sensible things + # when shapes are statically known. + self.assertEqual( + bijector_static.forward_event_shape(shape_in), + shape_out) + self.assertEqual( + bijector_static.inverse_event_shape(shape_out), + shape_in) + + with self.test_session() as sess: + (shape_out_static_, + shape_in_static_, + ) = sess.run(( + bijector_static.forward_event_shape_tensor(shape_in), + bijector_static.inverse_event_shape_tensor(shape_out), + )) + self.assertAllEqual(shape_out, shape_out_static_) + self.assertAllEqual(shape_in, shape_in_static_) + + def testScalarReshape(self): + """Test reshaping to and from a scalar shape ().""" + + expected_x = np.random.randn(4, 3, 1) + expected_y = np.reshape(expected_x, [4, 3]) + + expected_x_scalar = np.random.randn(1,) + expected_y_scalar = expected_x_scalar[0] + + with self.test_session() as sess: + bijector = Reshape( + event_shape_out=[], + event_shape_in=[1,], validate_args=True) + + (x_, + y_, + x_scalar_, + y_scalar_ + ) = sess.run(( + bijector.inverse(expected_y), + bijector.forward(expected_x), + bijector.inverse(expected_y_scalar), + bijector.forward(expected_x_scalar), + )) + self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0) + self.assertAllClose(expected_y_scalar, y_scalar_, rtol=1e-6, atol=0) + self.assertAllClose(expected_x_scalar, x_scalar_, rtol=1e-6, atol=0) + + def testRaisesOpError(self): + x1 = np.random.randn(4, 2, 3) + x2 = np.random.randn(4, 3, 2) + x3 = np.random.randn(4, 5, 1, 1) + + with self.test_session() as sess: + shape_in_ph = array_ops.placeholder(shape=[2,], dtype=dtypes.int32) + shape_out_ph = array_ops.placeholder(shape=[3,], dtype=dtypes.int32) + bijector = Reshape( + event_shape_out=shape_out_ph, + event_shape_in=shape_in_ph, + validate_args=True) + + with self.assertRaisesOpError( + "Input `event_shape` does not match `event_shape_in`."): + sess.run(bijector.forward(x2), + feed_dict={shape_out_ph: [1, 6, 1], + shape_in_ph: [2, 3]}) + + with self.assertRaisesOpError( + "event_shape_out entries must be positive."): + sess.run(bijector.forward(x1), + feed_dict={shape_out_ph: [-1, -1, 6], + shape_in_ph: [2, 3]}) + + # test that *all* methods check basic assertions + fd_mismatched = {shape_out_ph: [1, 1, 5], shape_in_ph: [2, 3]} + with self.assertRaisesOpError( + "Input/output `event_size`s do not match."): + sess.run(bijector.forward(x1), feed_dict=fd_mismatched) + with self.assertRaisesOpError( + "Input/output `event_size`s do not match."): + sess.run(bijector.inverse(x3), feed_dict=fd_mismatched) + with self.assertRaisesOpError( + "Input/output `event_size`s do not match."): + sess.run(bijector.inverse_log_det_jacobian(x3), + feed_dict=fd_mismatched) + with self.assertRaisesOpError( + "Input/output `event_size`s do not match."): + sess.run(bijector.forward_log_det_jacobian(x1), + feed_dict=fd_mismatched) + + def testBijectiveAndFinite(self): + x = np.random.randn(4, 2, 3) + y = np.reshape(x, [4, 1, 2, 3]) + with self.test_session(): + bijector = Reshape( + event_shape_in=[2, 3], + event_shape_out=[1, 2, 3], + validate_args=True) + assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0) + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py similarity index 96% rename from tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py rename to tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py index 230dd93a2a807cc14394e3c747c208c1f95b194d..172c180a44229089f06f250a872bc47a89991cf0 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_bijector_test.py @@ -41,7 +41,7 @@ class SinhArcsinhBijectorTest(test.TestCase): tailweight=tailweight, event_ndims=1, validate_args=True) - self.assertEqual("sinh_arcsinh", bijector.name) + self.assertEqual("SinhArcsinh", bijector.name) x = np.array([[[-2.01], [2.], [1e-4]]]).astype(np.float32) y = np.sinh((np.arcsinh(x) + skewness) * tailweight) self.assertAllClose(y, bijector.forward(x).eval()) @@ -170,6 +170,12 @@ class SinhArcsinhBijectorTest(test.TestCase): with self.assertRaisesOpError("not positive"): SinhArcsinh(tailweight=0., validate_args=True).forward(1.0).eval() + def testDefaultDtypeIsFloat32(self): + with self.test_session(): + bijector = SinhArcsinh() + self.assertEqual(bijector.tailweight.dtype, np.float32) + self.assertEqual(bijector.skewness.dtype, np.float32) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py b/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7f7697357ce7c77b2a50b87271d4ba7b49cbe05e --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/cauchy_test.py @@ -0,0 +1,437 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Cauchy.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import numpy as np + +from tensorflow.contrib.distributions.python.ops import cauchy as cauchy_lib +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + +stats = try_import("scipy.stats") + + +class CauchyTest(test.TestCase): + + def setUp(self): + self._rng = np.random.RandomState(123) + + def assertAllFinite(self, tensor): + is_finite = np.isfinite(tensor.eval()) + all_true = np.ones_like(is_finite, dtype=np.bool) + self.assertAllEqual(all_true, is_finite) + + def _testParamShapes(self, sample_shape, expected): + with self.test_session(): + param_shapes = cauchy_lib.Cauchy.param_shapes(sample_shape) + loc_shape, scale_shape = param_shapes["loc"], param_shapes["scale"] + self.assertAllEqual(expected, loc_shape.eval()) + self.assertAllEqual(expected, scale_shape.eval()) + loc = array_ops.zeros(loc_shape) + scale = array_ops.ones(scale_shape) + self.assertAllEqual( + expected, + array_ops.shape(cauchy_lib.Cauchy(loc, scale).sample()).eval()) + + def _testParamStaticShapes(self, sample_shape, expected): + param_shapes = cauchy_lib.Cauchy.param_static_shapes(sample_shape) + loc_shape, scale_shape = param_shapes["loc"], param_shapes["scale"] + self.assertEqual(expected, loc_shape) + self.assertEqual(expected, scale_shape) + + def testParamShapes(self): + sample_shape = [10, 3, 4] + self._testParamShapes(sample_shape, sample_shape) + self._testParamShapes(constant_op.constant(sample_shape), sample_shape) + + def testParamStaticShapes(self): + sample_shape = [10, 3, 4] + self._testParamStaticShapes(sample_shape, sample_shape) + self._testParamStaticShapes( + tensor_shape.TensorShape(sample_shape), sample_shape) + + def testCauchyLogPDF(self): + with self.test_session(): + batch_size = 6 + loc = constant_op.constant([3.0] * batch_size) + scale = constant_op.constant([np.sqrt(10.0)] * batch_size) + x = np.array([-2.5, 2.5, 4.0, 0.0, -1.0, 2.0], dtype=np.float32) + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + log_pdf = cauchy.log_prob(x) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), + log_pdf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), + log_pdf.eval().shape) + self.assertAllEqual(cauchy.batch_shape, log_pdf.shape) + self.assertAllEqual(cauchy.batch_shape, log_pdf.eval().shape) + + pdf = cauchy.prob(x) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), pdf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), pdf.eval().shape) + self.assertAllEqual(cauchy.batch_shape, pdf.shape) + self.assertAllEqual(cauchy.batch_shape, pdf.eval().shape) + + if not stats: + return + expected_log_pdf = stats.cauchy(loc.eval(), scale.eval()).logpdf(x) + self.assertAllClose(expected_log_pdf, log_pdf.eval()) + self.assertAllClose(np.exp(expected_log_pdf), pdf.eval()) + + def testCauchyLogPDFMultidimensional(self): + with self.test_session(): + batch_size = 6 + loc = constant_op.constant([[3.0, -3.0]] * batch_size) + scale = constant_op.constant([[np.sqrt(10.0), np.sqrt(15.0)]] * + batch_size) + x = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + log_pdf = cauchy.log_prob(x) + log_pdf_values = log_pdf.eval() + self.assertEqual(log_pdf.shape, (6, 2)) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), + log_pdf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), + log_pdf.eval().shape) + self.assertAllEqual(cauchy.batch_shape, log_pdf.shape) + self.assertAllEqual(cauchy.batch_shape, log_pdf.eval().shape) + + pdf = cauchy.prob(x) + pdf_values = pdf.eval() + self.assertEqual(pdf.shape, (6, 2)) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), pdf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), pdf_values.shape) + self.assertAllEqual(cauchy.batch_shape, pdf.shape) + self.assertAllEqual(cauchy.batch_shape, pdf_values.shape) + + if not stats: + return + expected_log_pdf = stats.cauchy(loc.eval(), scale.eval()).logpdf(x) + self.assertAllClose(expected_log_pdf, log_pdf_values) + self.assertAllClose(np.exp(expected_log_pdf), pdf_values) + + def testCauchyCDF(self): + with self.test_session(): + batch_size = 50 + loc = self._rng.randn(batch_size) + scale = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) + + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + cdf = cauchy.cdf(x) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), cdf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), cdf.eval().shape) + self.assertAllEqual(cauchy.batch_shape, cdf.shape) + self.assertAllEqual(cauchy.batch_shape, cdf.eval().shape) + if not stats: + return + expected_cdf = stats.cauchy(loc, scale).cdf(x) + self.assertAllClose(expected_cdf, cdf.eval(), atol=0) + + def testCauchySurvivalFunction(self): + with self.test_session(): + batch_size = 50 + loc = self._rng.randn(batch_size) + scale = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-8.0, 8.0, batch_size).astype(np.float64) + + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + sf = cauchy.survival_function(x) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), sf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), sf.eval().shape) + self.assertAllEqual(cauchy.batch_shape, sf.shape) + self.assertAllEqual(cauchy.batch_shape, sf.eval().shape) + if not stats: + return + expected_sf = stats.cauchy(loc, scale).sf(x) + self.assertAllClose(expected_sf, sf.eval(), atol=0) + + def testCauchyLogCDF(self): + with self.test_session(): + batch_size = 50 + loc = self._rng.randn(batch_size) + scale = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-100.0, 10.0, batch_size).astype(np.float64) + + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + cdf = cauchy.log_cdf(x) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), cdf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), cdf.eval().shape) + self.assertAllEqual(cauchy.batch_shape, cdf.shape) + self.assertAllEqual(cauchy.batch_shape, cdf.eval().shape) + + if not stats: + return + expected_cdf = stats.cauchy(loc, scale).logcdf(x) + self.assertAllClose(expected_cdf, cdf.eval(), atol=0, rtol=1e-5) + + def testFiniteGradientAtDifficultPoints(self): + for dtype in [np.float32, np.float64]: + g = ops.Graph() + with g.as_default(): + loc = variables.Variable(dtype(0.0)) + scale = variables.Variable(dtype(1.0)) + dist = cauchy_lib.Cauchy(loc=loc, scale=scale) + x = np.array([-100., -20., -5., 0., 5., 20., 100.]).astype(dtype) + for func in [ + dist.cdf, dist.log_cdf, dist.survival_function, + dist.log_survival_function, dist.log_prob, dist.prob + ]: + value = func(x) + grads = gradients_impl.gradients(value, [loc, scale]) + with self.test_session(graph=g): + variables.global_variables_initializer().run() + self.assertAllFinite(value) + self.assertAllFinite(grads[0]) + self.assertAllFinite(grads[1]) + + def testCauchyLogSurvivalFunction(self): + with self.test_session(): + batch_size = 50 + loc = self._rng.randn(batch_size) + scale = self._rng.rand(batch_size) + 1.0 + x = np.linspace(-10.0, 100.0, batch_size).astype(np.float64) + + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + sf = cauchy.log_survival_function(x) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), sf.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), sf.eval().shape) + self.assertAllEqual(cauchy.batch_shape, sf.shape) + self.assertAllEqual(cauchy.batch_shape, sf.eval().shape) + + if not stats: + return + expected_sf = stats.cauchy(loc, scale).logsf(x) + self.assertAllClose(expected_sf, sf.eval(), atol=0, rtol=1e-5) + + def testCauchyEntropy(self): + with self.test_session(): + loc = np.array([1.0, 1.0, 1.0]) + scale = np.array([[1.0, 2.0, 3.0]]) + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + entropy = cauchy.entropy() + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), + entropy.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), + entropy.eval().shape) + self.assertAllEqual(cauchy.batch_shape, entropy.shape) + self.assertAllEqual(cauchy.batch_shape, entropy.eval().shape) + + if not stats: + return + expected_entropy = stats.cauchy(loc, scale).entropy() + self.assertAllClose(expected_entropy, entropy.eval()) + + def testCauchyMode(self): + with self.test_session(): + # Mu will be broadcast to [7, 7, 7]. + loc = [7.] + scale = [11., 12., 13.] + + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + self.assertAllEqual((3,), cauchy.mode().shape) + self.assertAllEqual([7., 7, 7], cauchy.mode().eval()) + + def testCauchyMean(self): + with self.test_session(): + loc = [1., 2., 3.] + scale = [7.] + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + self.assertAllEqual((3,), cauchy.mean().shape) + self.assertAllEqual([np.nan] * 3, cauchy.mean().eval()) + + def testCauchyNanMean(self): + with self.test_session(): + loc = [1., 2., 3.] + scale = [7.] + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale, allow_nan_stats=False) + + with self.assertRaises(ValueError): + cauchy.mean().eval() + + def testCauchyQuantile(self): + with self.test_session(): + batch_size = 50 + loc = self._rng.randn(batch_size) + scale = self._rng.rand(batch_size) + 1.0 + p = np.linspace(0.000001, 0.999999, batch_size).astype(np.float64) + + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + x = cauchy.quantile(p) + + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), x.shape) + self.assertAllEqual(cauchy.batch_shape_tensor().eval(), x.eval().shape) + self.assertAllEqual(cauchy.batch_shape, x.shape) + self.assertAllEqual(cauchy.batch_shape, x.eval().shape) + + if not stats: + return + expected_x = stats.cauchy(loc, scale).ppf(p) + self.assertAllClose(expected_x, x.eval(), atol=0.) + + def testCauchyVariance(self): + with self.test_session(): + # scale will be broadcast to [7, 7, 7] + loc = [1., 2., 3.] + scale = [7.] + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + self.assertAllEqual((3,), cauchy.variance().shape) + self.assertAllEqual([np.nan] * 3, cauchy.variance().eval()) + + def testCauchyNanVariance(self): + with self.test_session(): + # scale will be broadcast to [7, 7, 7] + loc = [1., 2., 3.] + scale = [7.] + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale, allow_nan_stats=False) + + with self.assertRaises(ValueError): + cauchy.variance().eval() + + def testCauchyStandardDeviation(self): + with self.test_session(): + # scale will be broadcast to [7, 7, 7] + loc = [1., 2., 3.] + scale = [7.] + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + self.assertAllEqual((3,), cauchy.stddev().shape) + self.assertAllEqual([np.nan] * 3, cauchy.stddev().eval()) + + def testCauchyNanStandardDeviation(self): + with self.test_session(): + # scale will be broadcast to [7, 7, 7] + loc = [1., 2., 3.] + scale = [7.] + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale, allow_nan_stats=False) + + with self.assertRaises(ValueError): + cauchy.stddev().eval() + + def testCauchySample(self): + with self.test_session(): + loc = constant_op.constant(3.0) + scale = constant_op.constant(1.0) + loc_v = 3.0 + n = constant_op.constant(100000) + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + samples = cauchy.sample(n) + sample_values = samples.eval() + + self.assertEqual(sample_values.shape, (100000,)) + self.assertAllClose(np.median(sample_values), loc_v, atol=1e-1) + + expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate( + tensor_shape.TensorShape(cauchy.batch_shape_tensor().eval())) + + self.assertAllEqual(expected_shape, samples.shape) + self.assertAllEqual(expected_shape, sample_values.shape) + + expected_shape = (tensor_shape.TensorShape( + [n.eval()]).concatenate(cauchy.batch_shape)) + + self.assertAllEqual(expected_shape, samples.shape) + self.assertAllEqual(expected_shape, sample_values.shape) + + def testCauchySampleMultiDimensional(self): + with self.test_session(): + batch_size = 2 + loc = constant_op.constant([[3.0, -3.0]] * batch_size) + scale = constant_op.constant([[0.5, 1.0]] * batch_size) + loc_v = [3.0, -3.0] + n = constant_op.constant(100000) + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + samples = cauchy.sample(n) + sample_values = samples.eval() + self.assertEqual(samples.shape, (100000, batch_size, 2)) + self.assertAllClose(np.median(sample_values[:, 0, 0]), + loc_v[0], atol=1e-1) + self.assertAllClose(np.median(sample_values[:, 0, 1]), + loc_v[1], atol=1e-1) + + expected_shape = tensor_shape.TensorShape([n.eval()]).concatenate( + tensor_shape.TensorShape(cauchy.batch_shape_tensor().eval())) + self.assertAllEqual(expected_shape, samples.shape) + self.assertAllEqual(expected_shape, sample_values.shape) + + expected_shape = (tensor_shape.TensorShape( + [n.eval()]).concatenate(cauchy.batch_shape)) + self.assertAllEqual(expected_shape, samples.shape) + self.assertAllEqual(expected_shape, sample_values.shape) + + def testCauchyNegativeLocFails(self): + with self.test_session(): + cauchy = cauchy_lib.Cauchy(loc=[1.], scale=[-5.], validate_args=True) + with self.assertRaisesOpError("Condition x > 0 did not hold"): + cauchy.mode().eval() + + def testCauchyShape(self): + with self.test_session(): + loc = constant_op.constant([-3.0] * 5) + scale = constant_op.constant(11.0) + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + self.assertEqual(cauchy.batch_shape_tensor().eval(), [5]) + self.assertEqual(cauchy.batch_shape, tensor_shape.TensorShape([5])) + self.assertAllEqual(cauchy.event_shape_tensor().eval(), []) + self.assertEqual(cauchy.event_shape, tensor_shape.TensorShape([])) + + def testCauchyShapeWithPlaceholders(self): + loc = array_ops.placeholder(dtype=dtypes.float32) + scale = array_ops.placeholder(dtype=dtypes.float32) + cauchy = cauchy_lib.Cauchy(loc=loc, scale=scale) + + with self.test_session() as sess: + # get_batch_shape should return an "" tensor. + self.assertEqual(cauchy.batch_shape, tensor_shape.TensorShape(None)) + self.assertEqual(cauchy.event_shape, ()) + self.assertAllEqual(cauchy.event_shape_tensor().eval(), []) + self.assertAllEqual( + sess.run(cauchy.batch_shape_tensor(), + feed_dict={loc: 5.0, + scale: [1.0, 2.0]}), [2]) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py index cc7d6fd5ddda8fcdfdf6c8a3f80feeda7a42541e..2d74aa1f320149d0f7ef9e9c52b8c7053c2f74d7 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/distribution_util_test.py @@ -23,11 +23,11 @@ import itertools import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util -from tensorflow.contrib.linalg.python.ops import linear_operator_diag from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops.linalg import linear_operator_diag import tensorflow.python.ops.nn_grad # pylint: disable=unused-import from tensorflow.python.platform import test @@ -287,6 +287,26 @@ class ShapesFromLocAndScaleTest(test.TestCase): self.assertAllEqual([3], event_shape) +class GetBroadcastShapeTest(test.TestCase): + + def test_all_static_shapes_work(self): + x = array_ops.ones((2, 1, 3)) + y = array_ops.ones((1, 5, 3)) + z = array_ops.ones(()) + self.assertAllEqual([2, 5, 3], + distribution_util.get_broadcast_shape(x, y, z)) + + def test_with_some_dynamic_shapes_works(self): + x = array_ops.ones((2, 1, 3)) + y = array_ops.placeholder(x.dtype) + z = array_ops.ones(()) + with self.test_session() as sess: + bcast_shape = sess.run( + distribution_util.get_broadcast_shape(x, y, z), + feed_dict={y: np.ones((1, 5, 3)).astype(np.float32)}) + self.assertAllEqual([2, 5, 3], bcast_shape) + + class TridiagTest(test.TestCase): def testWorksCorrectlyNoBatches(self): @@ -374,5 +394,6 @@ class MixtureStddevTest(test.TestCase): self.assertAllClose(actual_devs, expected_devs) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py new file mode 100644 index 0000000000000000000000000000000000000000..06318ca09dec851cf025fa35c83732b85824cbee --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/independent_test.py @@ -0,0 +1,184 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the Independent distribution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import importlib +import numpy as np + +from tensorflow.contrib.distributions.python.ops import independent as independent_lib +from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bernoulli as bernoulli_lib +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.platform import test +from tensorflow.python.platform import tf_logging + + +def try_import(name): # pylint: disable=invalid-name + module = None + try: + module = importlib.import_module(name) + except ImportError as e: + tf_logging.warning("Could not import %s: %s" % (name, str(e))) + return module + +stats = try_import("scipy.stats") + + +class ProductDistributionTest(test.TestCase): + + def setUp(self): + self._rng = np.random.RandomState(42) + + def testSampleAndLogProbUnivariate(self): + loc = np.float32([-1., 1]) + scale = np.float32([0.1, 0.5]) + with self.test_session() as sess: + ind = independent_lib.Independent( + distribution=normal_lib.Normal(loc=loc, scale=scale), + reinterpreted_batch_ndims=1) + + x = ind.sample([4, 5], seed=42) + log_prob_x = ind.log_prob(x) + x_, actual_log_prob_x = sess.run([x, log_prob_x]) + + self.assertEqual([], ind.batch_shape) + self.assertEqual([2], ind.event_shape) + self.assertEqual([4, 5, 2], x.shape) + self.assertEqual([4, 5], log_prob_x.shape) + + expected_log_prob_x = stats.norm(loc, scale).logpdf(x_).sum(-1) + self.assertAllClose(expected_log_prob_x, actual_log_prob_x, + rtol=1e-5, atol=0.) + + def testSampleAndLogProbMultivariate(self): + loc = np.float32([[-1., 1], [1, -1]]) + scale = np.float32([1., 0.5]) + with self.test_session() as sess: + ind = independent_lib.Independent( + distribution=mvn_diag_lib.MultivariateNormalDiag( + loc=loc, + scale_identity_multiplier=scale), + reinterpreted_batch_ndims=1) + + x = ind.sample([4, 5], seed=42) + log_prob_x = ind.log_prob(x) + x_, actual_log_prob_x = sess.run([x, log_prob_x]) + + self.assertEqual([], ind.batch_shape) + self.assertEqual([2, 2], ind.event_shape) + self.assertEqual([4, 5, 2, 2], x.shape) + self.assertEqual([4, 5], log_prob_x.shape) + + expected_log_prob_x = stats.norm(loc, scale[:, None]).logpdf( + x_).sum(-1).sum(-1) + self.assertAllClose(expected_log_prob_x, actual_log_prob_x, + rtol=1e-6, atol=0.) + + def testSampleConsistentStats(self): + loc = np.float32([[-1., 1], [1, -1]]) + scale = np.float32([1., 0.5]) + n_samp = 1e4 + with self.test_session() as sess: + ind = independent_lib.Independent( + distribution=mvn_diag_lib.MultivariateNormalDiag( + loc=loc, + scale_identity_multiplier=scale), + reinterpreted_batch_ndims=1) + + x = ind.sample(int(n_samp), seed=42) + sample_mean = math_ops.reduce_mean(x, axis=0) + sample_var = math_ops.reduce_mean( + math_ops.squared_difference(x, sample_mean), axis=0) + sample_std = math_ops.sqrt(sample_var) + sample_entropy = -math_ops.reduce_mean(ind.log_prob(x), axis=0) + + [ + sample_mean_, sample_var_, sample_std_, sample_entropy_, + actual_mean_, actual_var_, actual_std_, actual_entropy_, + actual_mode_, + ] = sess.run([ + sample_mean, sample_var, sample_std, sample_entropy, + ind.mean(), ind.variance(), ind.stddev(), ind.entropy(), ind.mode(), + ]) + + self.assertAllClose(sample_mean_, actual_mean_, rtol=0.02, atol=0.) + self.assertAllClose(sample_var_, actual_var_, rtol=0.04, atol=0.) + self.assertAllClose(sample_std_, actual_std_, rtol=0.02, atol=0.) + self.assertAllClose(sample_entropy_, actual_entropy_, rtol=0.01, atol=0.) + self.assertAllClose(loc, actual_mode_, rtol=1e-6, atol=0.) + + def _testMnistLike(self, static_shape): + sample_shape = [4, 5] + batch_shape = [10] + image_shape = [28, 28, 1] + logits = 3 * self._rng.random_sample( + batch_shape + image_shape).astype(np.float32) - 1 + + def expected_log_prob(x, logits): + return (x * logits - np.log1p(np.exp(logits))).sum(-1).sum(-1).sum(-1) + + with self.test_session() as sess: + logits_ph = array_ops.placeholder( + dtypes.float32, shape=logits.shape if static_shape else None) + ind = independent_lib.Independent( + distribution=bernoulli_lib.Bernoulli(logits=logits_ph)) + x = ind.sample(sample_shape, seed=42) + log_prob_x = ind.log_prob(x) + [ + x_, + actual_log_prob_x, + ind_batch_shape, + ind_event_shape, + x_shape, + log_prob_x_shape, + ] = sess.run([ + x, + log_prob_x, + ind.batch_shape_tensor(), + ind.event_shape_tensor(), + array_ops.shape(x), + array_ops.shape(log_prob_x), + ], feed_dict={logits_ph: logits}) + + if static_shape: + ind_batch_shape = ind.batch_shape + ind_event_shape = ind.event_shape + x_shape = x.shape + log_prob_x_shape = log_prob_x.shape + + self.assertAllEqual(batch_shape, ind_batch_shape) + self.assertAllEqual(image_shape, ind_event_shape) + self.assertAllEqual(sample_shape + batch_shape + image_shape, x_shape) + self.assertAllEqual(sample_shape + batch_shape, log_prob_x_shape) + self.assertAllClose(expected_log_prob(x_, logits), + actual_log_prob_x, + rtol=1e-6, atol=0.) + + def testMnistLikeStaticShape(self): + self._testMnistLike(static_shape=True) + + def testMnistLikeDynamicShape(self): + self._testMnistLike(static_shape=False) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py new file mode 100644 index 0000000000000000000000000000000000000000..ece6bc077d9e21502fdfd01300a9d3e9f2c9c380 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_same_family_test.py @@ -0,0 +1,138 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for MixtureSameFamily distribution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.distributions.python.ops import mixture_same_family as mixture_same_family_lib +from tensorflow.contrib.distributions.python.ops import mvn_diag as mvn_diag_lib +from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bernoulli as bernoulli_lib +from tensorflow.python.ops.distributions import categorical as categorical_lib +from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.platform import test + + +class MixtureSameFamilyTest(test_util.VectorDistributionTestHelpers, + test.TestCase): + + def testSampleAndLogProbUnivariateShapes(self): + with self.test_session(): + gm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), + components_distribution=normal_lib.Normal( + loc=[-1., 1], scale=[0.1, 0.5])) + x = gm.sample([4, 5], seed=42) + log_prob_x = gm.log_prob(x) + self.assertEqual([4, 5], x.shape) + self.assertEqual([4, 5], log_prob_x.shape) + + def testSampleAndLogProbShapesBroadcastMix(self): + mix_probs = np.float32([.3, .7]) + bern_probs = np.float32([[.4, .6], [.25, .75]]) + with self.test_session(): + bm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=mix_probs), + components_distribution=bernoulli_lib.Bernoulli(probs=bern_probs)) + x = bm.sample([4, 5], seed=42) + log_prob_x = bm.log_prob(x) + x_ = x.eval() + self.assertEqual([4, 5, 2], x.shape) + self.assertEqual([4, 5, 2], log_prob_x.shape) + self.assertAllEqual( + np.ones_like(x_, dtype=np.bool), np.logical_or(x_ == 0., x_ == 1.)) + + def testSampleAndLogProbMultivariateShapes(self): + with self.test_session(): + gm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), + components_distribution=mvn_diag_lib.MultivariateNormalDiag( + loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])) + x = gm.sample([4, 5], seed=42) + log_prob_x = gm.log_prob(x) + self.assertEqual([4, 5, 2], x.shape) + self.assertEqual([4, 5], log_prob_x.shape) + + def testSampleAndLogProbBatchMultivariateShapes(self): + with self.test_session(): + gm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), + components_distribution=mvn_diag_lib.MultivariateNormalDiag( + loc=[[[-1., 1], + [1, -1]], + [[0., 1], + [1, 0]]], + scale_identity_multiplier=[1., 0.5])) + x = gm.sample([4, 5], seed=42) + log_prob_x = gm.log_prob(x) + self.assertEqual([4, 5, 2, 2], x.shape) + self.assertEqual([4, 5, 2], log_prob_x.shape) + + def testSampleConsistentLogProb(self): + with self.test_session() as sess: + gm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), + components_distribution=mvn_diag_lib.MultivariateNormalDiag( + loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])) + # Ball centered at component0's mean. + self.run_test_sample_consistent_log_prob( + sess.run, gm, radius=1., center=[-1., 1], rtol=0.02) + # Larger ball centered at component1's mean. + self.run_test_sample_consistent_log_prob( + sess.run, gm, radius=1., center=[1., -1], rtol=0.02) + + def testLogCdf(self): + with self.test_session() as sess: + gm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), + components_distribution=normal_lib.Normal( + loc=[-1., 1], scale=[0.1, 0.5])) + x = gm.sample(10, seed=42) + actual_log_cdf = gm.log_cdf(x) + expected_log_cdf = math_ops.reduce_logsumexp( + (gm.mixture_distribution.logits + + gm.components_distribution.log_cdf(x[..., array_ops.newaxis])), + axis=1) + actual_log_cdf_, expected_log_cdf_ = sess.run([ + actual_log_cdf, expected_log_cdf]) + self.assertAllClose(actual_log_cdf_, expected_log_cdf_, + rtol=1e-6, atol=0.0) + + def testSampleConsistentMeanCovariance(self): + with self.test_session() as sess: + gm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), + components_distribution=mvn_diag_lib.MultivariateNormalDiag( + loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])) + self.run_test_sample_consistent_mean_covariance(sess.run, gm) + + def testVarianceConsistentCovariance(self): + with self.test_session() as sess: + gm = mixture_same_family_lib.MixtureSameFamily( + mixture_distribution=categorical_lib.Categorical(probs=[0.3, 0.7]), + components_distribution=mvn_diag_lib.MultivariateNormalDiag( + loc=[[-1., 1], [1, -1]], scale_identity_multiplier=[1., 0.5])) + cov_, var_ = sess.run([gm.covariance(), gm.variance()]) + self.assertAllClose(cov_.diagonal(), var_, atol=0.) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py index bd8f405e5b4f26466886791dbb0a6ee1eea0e888..1e514fe0ff21cd53c8c235da417890773db50c37 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py @@ -38,7 +38,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import test from tensorflow.python.platform import tf_logging as logging -distributions_py = distributions +ds = distributions def _swap_first_last_axes(array): @@ -71,35 +71,40 @@ def _mixture_stddev_np(pi_vector, mu_vector, sigma_vector): @contextlib.contextmanager def _test_capture_mvndiag_sample_outputs(): - """Use monkey-patching to capture the output of an MVNDiag _sample_n.""" + """Use monkey-patching to capture the output of an MVNDiag _call_sample_n.""" data_container = [] - true_mvndiag_sample_n = distributions_py.MultivariateNormalDiag._sample_n + true_mvndiag_call_sample_n = ( + ds.MultivariateNormalDiag._call_sample_n) - def _capturing_mvndiag_sample_n(self, n, seed=None): - samples = true_mvndiag_sample_n(self, n=n, seed=seed) + def _capturing_mvndiag_call_sample_n( + self, sample_shape, seed, name, **kwargs): + samples = true_mvndiag_call_sample_n( + self, sample_shape, seed, name, **kwargs) data_container.append(samples) return samples - distributions_py.MultivariateNormalDiag._sample_n = ( - _capturing_mvndiag_sample_n) + ds.MultivariateNormalDiag._call_sample_n = ( + _capturing_mvndiag_call_sample_n) yield data_container - distributions_py.MultivariateNormalDiag._sample_n = true_mvndiag_sample_n + ds.MultivariateNormalDiag._call_sample_n = ( + true_mvndiag_call_sample_n) @contextlib.contextmanager def _test_capture_normal_sample_outputs(): - """Use monkey-patching to capture the output of an Normal _sample_n.""" + """Use monkey-patching to capture the output of an Normal _call_sample_n.""" data_container = [] - true_normal_sample_n = distributions_py.Normal._sample_n + true_normal_call_sample_n = ds.Normal._call_sample_n - def _capturing_normal_sample_n(self, n, seed=None): - samples = true_normal_sample_n(self, n=n, seed=seed) + def _capturing_normal_call_sample_n(self, sample_shape, seed, name, **kwargs): + samples = true_normal_call_sample_n( + self, sample_shape, seed, name, **kwargs) data_container.append(samples) return samples - distributions_py.Normal._sample_n = _capturing_normal_sample_n + ds.Normal._call_sample_n = _capturing_normal_call_sample_n yield data_container - distributions_py.Normal._sample_n = true_normal_sample_n + ds.Normal._call_sample_n = true_normal_call_sample_n def make_univariate_mixture(batch_shape, num_components): @@ -108,13 +113,13 @@ def make_univariate_mixture(batch_shape, num_components): array_ops.concat((batch_shape, [num_components]), axis=0), -1, 1, dtype=dtypes.float32) - 50. components = [ - distributions_py.Normal( + ds.Normal( loc=random_ops.random_normal(batch_shape), scale=10 * random_ops.random_uniform(batch_shape)) for _ in range(num_components) ] - cat = distributions_py.Categorical(logits, dtype=dtypes.int32) - return distributions_py.Mixture(cat, components) + cat = ds.Categorical(logits, dtype=dtypes.int32) + return ds.Mixture(cat, components) def make_multivariate_mixture(batch_shape, num_components, event_shape, @@ -136,11 +141,11 @@ def make_multivariate_mixture(batch_shape, num_components, event_shape, scale_diag = 10 * random_ops.random_uniform(batch_and_event_shape) loc.set_shape(static_batch_and_event_shape) scale_diag.set_shape(static_batch_and_event_shape) - return distributions_py.MultivariateNormalDiag( + return ds.MultivariateNormalDiag( loc=loc, scale_diag=scale_diag) components = [create_component() for _ in range(num_components)] - cat = distributions_py.Categorical(logits, dtype=dtypes.int32) - return distributions_py.Mixture(cat, components) + cat = ds.Categorical(logits, dtype=dtypes.int32) + return ds.Mixture(cat, components) class MixtureTest(test.TestCase): @@ -165,37 +170,37 @@ class MixtureTest(test.TestCase): def testBrokenShapesStatic(self): with self.assertRaisesWithPredicateMatch(ValueError, r"cat.num_classes != len"): - distributions_py.Mixture( - distributions_py.Categorical([0.1, 0.5]), # 2 classes - [distributions_py.Normal(loc=1.0, scale=2.0)]) + ds.Mixture( + ds.Categorical([0.1, 0.5]), # 2 classes + [ds.Normal(loc=1.0, scale=2.0)]) with self.assertRaisesWithPredicateMatch( ValueError, r"\(\) and \(2,\) are not compatible"): # The value error is raised because the batch shapes of the # Normals are not equal. One is a scalar, the other is a # vector of size (2,). - distributions_py.Mixture( - distributions_py.Categorical([-0.5, 0.5]), # scalar batch + ds.Mixture( + ds.Categorical([-0.5, 0.5]), # scalar batch [ - distributions_py.Normal( + ds.Normal( loc=1.0, scale=2.0), # scalar dist - distributions_py.Normal( + ds.Normal( loc=[1.0, 1.0], scale=[2.0, 2.0]) ]) with self.assertRaisesWithPredicateMatch(ValueError, r"Could not infer"): cat_logits = array_ops.placeholder(shape=[1, None], dtype=dtypes.float32) - distributions_py.Mixture( - distributions_py.Categorical(cat_logits), - [distributions_py.Normal( + ds.Mixture( + ds.Categorical(cat_logits), + [ds.Normal( loc=[1.0], scale=[2.0])]) def testBrokenShapesDynamic(self): with self.test_session(): d0_param = array_ops.placeholder(dtype=dtypes.float32) d1_param = array_ops.placeholder(dtype=dtypes.float32) - d = distributions_py.Mixture( - distributions_py.Categorical([0.1, 0.2]), [ - distributions_py.Normal( - loc=d0_param, scale=d0_param), distributions_py.Normal( + d = ds.Mixture( + ds.Categorical([0.1, 0.2]), [ + ds.Normal( + loc=d0_param, scale=d0_param), ds.Normal( loc=d1_param, scale=d1_param) ], validate_args=True) @@ -206,21 +211,21 @@ class MixtureTest(test.TestCase): def testBrokenTypes(self): with self.assertRaisesWithPredicateMatch(TypeError, "Categorical"): - distributions_py.Mixture(None, []) - cat = distributions_py.Categorical([0.3, 0.2]) + ds.Mixture(None, []) + cat = ds.Categorical([0.3, 0.2]) # components must be a list of distributions with self.assertRaisesWithPredicateMatch( TypeError, "all .* must be Distribution instances"): - distributions_py.Mixture(cat, [None]) + ds.Mixture(cat, [None]) with self.assertRaisesWithPredicateMatch(TypeError, "same dtype"): - distributions_py.Mixture( + ds.Mixture( cat, [ - distributions_py.Normal(loc=[1.0], scale=[2.0]), - distributions_py.Normal(loc=[np.float16(1.0)], - scale=[np.float16(2.0)]), + ds.Normal(loc=[1.0], scale=[2.0]), + ds.Normal(loc=[np.float16(1.0)], + scale=[np.float16(2.0)]), ]) with self.assertRaisesWithPredicateMatch(ValueError, "non-empty list"): - distributions_py.Mixture(distributions_py.Categorical([0.3, 0.2]), None) + ds.Mixture(ds.Categorical([0.3, 0.2]), None) # TODO(ebrevdo): once distribution Domains have been added, add a # test to ensure that the domains of the distributions in a @@ -359,13 +364,13 @@ class MixtureTest(test.TestCase): component_devs = np.array([0.05, 2.33]) ground_truth_stddev = 5.3120805 - mixture_dist = distributions_py.Mixture( - cat=distributions_py.Categorical(probs=cat_probs), + mixture_dist = ds.Mixture( + cat=ds.Categorical(probs=cat_probs), components=[ - distributions_py.Normal(loc=component_means[0], - scale=component_devs[0]), - distributions_py.Normal(loc=component_means[1], - scale=component_devs[1]), + ds.Normal(loc=component_means[0], + scale=component_devs[0]), + ds.Normal(loc=component_means[1], + scale=component_devs[1]), ]) mix_dev = mixture_dist.stddev() with self.test_session() as sess: @@ -512,22 +517,22 @@ class MixtureTest(test.TestCase): random_seed.set_random_seed(654321) components = [ - distributions_py.Normal( + ds.Normal( loc=mu, scale=sigma) for mu, sigma in zip(mus, sigmas) ] - cat = distributions_py.Categorical( + cat = ds.Categorical( logits, dtype=dtypes.int32, name="cat1") - dist1 = distributions_py.Mixture(cat, components, name="mixture1") + dist1 = ds.Mixture(cat, components, name="mixture1") samples1 = dist1.sample(n, seed=123456).eval() random_seed.set_random_seed(654321) components2 = [ - distributions_py.Normal( + ds.Normal( loc=mu, scale=sigma) for mu, sigma in zip(mus, sigmas) ] - cat2 = distributions_py.Categorical( + cat2 = ds.Categorical( logits, dtype=dtypes.int32, name="cat2") - dist2 = distributions_py.Mixture(cat2, components2, name="mixture2") + dist2 = ds.Mixture(cat2, components2, name="mixture2") samples2 = dist2.sample(n, seed=123456).eval() self.assertAllClose(samples1, samples2) @@ -660,15 +665,15 @@ class MixtureTest(test.TestCase): e_x = np.exp(x - np.max(x)) return e_x / e_x.sum() - # Construct the distributions_py.Mixture object. + # Construct the ds.Mixture object. mixture_weights = _scalar_univariate_softmax(mixture_weight_logits) means = [np.random.uniform(low=-10, high=10, size=()).astype(np.float32) for _ in range(n_components)] sigmas = [np.ones(shape=(), dtype=np.float32) for _ in range(n_components)] - cat_tf = distributions_py.Categorical(probs=mixture_weights) - components_tf = [distributions_py.Normal(loc=mu, scale=sigma) + cat_tf = ds.Categorical(probs=mixture_weights) + components_tf = [ds.Normal(loc=mu, scale=sigma) for (mu, sigma) in zip(means, sigmas)] - mixture_tf = distributions_py.Mixture(cat=cat_tf, components=components_tf) + mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf) x_tensor = array_ops.placeholder(shape=(), dtype=dtypes.float32) @@ -713,10 +718,10 @@ class MixtureTest(test.TestCase): for _ in range(n_components)] sigmas = [np.ones(shape=psize, dtype=np.float32) for _ in range(n_components)] - cat_tf = distributions_py.Categorical(probs=mixture_weights) - components_tf = [distributions_py.Normal(loc=mu, scale=sigma) + cat_tf = ds.Categorical(probs=mixture_weights) + components_tf = [ds.Normal(loc=mu, scale=sigma) for (mu, sigma) in zip(means, sigmas)] - mixture_tf = distributions_py.Mixture(cat=cat_tf, components=components_tf) + mixture_tf = ds.Mixture(cat=cat_tf, components=components_tf) x_tensor = array_ops.placeholder(shape=psize, dtype=dtypes.float32) xs_to_check = [ @@ -745,6 +750,20 @@ class MixtureTest(test.TestCase): self.assertAllClose(x_cdf_tf_result, scipy_cdf_result) self.assertAllClose(np.exp(x_log_cdf_tf_result), scipy_cdf_result) + def testSampleBimixGamma(self): + """Tests a bug in the underlying tf.Gamma op. + + Mixture's use of dynamic partition requires `random_gamma` correctly returns + an empty `Tensor`. + """ + with self.test_session(): + gm = ds.Mixture( + cat=ds.Categorical(probs=[.3, .7]), + components=[ds.Gamma(1., 2.), + ds.Gamma(2., 1.)]) + x_ = gm.sample().eval() + self.assertAllEqual([], x_.shape) + class MixtureBenchmark(test.Benchmark): @@ -779,7 +798,7 @@ class MixtureBenchmark(test.Benchmark): 2, "mvn_diag\tuse_gpu\tcomponents\tbatch\tfeatures\tsample\twall_time") def create_distribution(batch_size, num_components, num_features): - cat = distributions_py.Categorical( + cat = ds.Categorical( logits=np.random.randn(batch_size, num_components)) mus = [ variables.Variable(np.random.randn(batch_size, num_features)) @@ -790,9 +809,9 @@ class MixtureBenchmark(test.Benchmark): for _ in range(num_components) ] components = list( - distributions_py.MultivariateNormalDiag( + ds.MultivariateNormalDiag( loc=mu, scale_diag=sigma) for (mu, sigma) in zip(mus, sigmas)) - return distributions_py.Mixture(cat, components) + return ds.Mixture(cat, components) for use_gpu in False, True: if use_gpu and not test.is_gpu_available(): @@ -819,7 +838,7 @@ class MixtureBenchmark(test.Benchmark): return np.stack([np.dot(np.transpose(z), z) for z in x]) def create_distribution(batch_size, num_components, num_features): - cat = distributions_py.Categorical( + cat = ds.Categorical( logits=np.random.randn(batch_size, num_components)) mus = [ variables.Variable(np.random.randn(batch_size, num_features)) @@ -831,10 +850,10 @@ class MixtureBenchmark(test.Benchmark): for _ in range(num_components) ] components = list( - distributions_py.MultivariateNormalTriL( + ds.MultivariateNormalTriL( loc=mu, scale_tril=linalg_ops.cholesky(sigma)) for (mu, sigma) in zip(mus, sigmas)) - return distributions_py.Mixture(cat, components) + return ds.Mixture(cat, components) for use_gpu in False, True: if use_gpu and not test.is_gpu_available(): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py index 43e302475b49ef5245ba324c35ca294b51a566b6..933756aa8e12cca4c42eb98d9193512bbf2ad585 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py @@ -289,6 +289,18 @@ class MultivariateNormalDiagTest(test.TestCase): self.assertListEqual(mvn.batch_shape.as_list(), [2, 3]) self.assertListEqual(mvn.event_shape.as_list(), [None]) + def testKLDivIdenticalGradientDefined(self): + dims = 3 + with self.test_session() as sess: + loc = array_ops.zeros([dims], dtype=dtypes.float32) + mvn = ds.MultivariateNormalDiag( + loc=loc, + scale_diag=np.ones([dims], dtype=np.float32)) + g = gradients_impl.gradients(ds.kl_divergence(mvn, mvn), loc) + g_ = sess.run(g) + self.assertAllEqual(np.ones_like(g_, dtype=np.bool), + np.isfinite(g_)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py b/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py index c1a74c6483b9843c609ac94054a8c27476f7d7ff..37edaa42cdc202cda4aa173752a3639792f96daf 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py @@ -241,6 +241,28 @@ class NegativeBinomialTest(test.TestCase): atol=0., rtol=.02) + def testLogProbOverflow(self): + with self.test_session() as sess: + logits = np.float32([20., 30., 40.]) + total_count = np.float32(1.) + x = np.float32(0.) + nb = negative_binomial.NegativeBinomial( + total_count=total_count, logits=logits) + log_prob_ = sess.run(nb.log_prob(x)) + self.assertAllEqual(np.ones_like(log_prob_, dtype=np.bool), + np.isfinite(log_prob_)) + + def testLogProbUnderflow(self): + with self.test_session() as sess: + logits = np.float32([-90, -100, -110]) + total_count = np.float32(1.) + x = np.float32(0.) + nb = negative_binomial.NegativeBinomial( + total_count=total_count, logits=logits) + log_prob_ = sess.run(nb.log_prob(x)) + self.assertAllEqual(np.ones_like(log_prob_, dtype=np.bool), + np.isfinite(log_prob_)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py index 7cb46bb2367658518c98baaa14947b5ad837ff12..3c0147b8cf6e1b6a2791e85c0c0997992445fa7e 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_lognormal_test.py @@ -18,8 +18,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.distributions.python.ops import poisson_lognormal from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -32,60 +36,80 @@ class PoissonLogNormalQuadratureCompoundTest( pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=-2., scale=1.1, - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_log_prob( - sess, pln, rtol=0.1) + sess.run, pln, rtol=0.1) def testMeanVariance(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=0., scale=1., - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_mean_variance( - sess, pln, rtol=0.02) + sess.run, pln, rtol=0.02) def testSampleProbConsistentBroadcastScalar(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=[0., -0.5], scale=1., - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_log_prob( - sess, pln, rtol=0.1, atol=0.01) + sess.run, pln, rtol=0.1, atol=0.01) def testMeanVarianceBroadcastScalar(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=[0., -0.5], scale=1., - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_mean_variance( - sess, pln, rtol=0.1, atol=0.01) + sess.run, pln, rtol=0.1, atol=0.01) def testSampleProbConsistentBroadcastBoth(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=[[0.], [-0.5]], scale=[[1., 0.9]], - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_log_prob( - sess, pln, rtol=0.1, atol=0.08) + sess.run, pln, rtol=0.1, atol=0.08) def testMeanVarianceBroadcastBoth(self): with self.test_session() as sess: pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( loc=[[0.], [-0.5]], scale=[[1., 0.9]], - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) self.run_test_sample_consistent_mean_variance( - sess, pln, rtol=0.1, atol=0.01) + sess.run, pln, rtol=0.1, atol=0.01) + + def testSampleProbConsistentDynamicQuadrature(self): + with self.test_session() as sess: + qgrid = array_ops.placeholder(dtype=dtypes.float32) + qprobs = array_ops.placeholder(dtype=dtypes.float32) + g, p = np.polynomial.hermite.hermgauss(deg=10) + pln = poisson_lognormal.PoissonLogNormalQuadratureCompound( + loc=-2., + scale=1.1, + quadrature_grid_and_probs=(g, p), + validate_args=True) + self.run_test_sample_consistent_log_prob( + lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p}), + pln, rtol=0.1) if __name__ == "__main__": diff --git a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py index f157c0d3edd6e56083b7914d89dcd1e5b9420f78..d9c9008417cdb20b62390630cf887d3bd888a0d3 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py @@ -24,15 +24,19 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops from tensorflow.python.platform import test class PoissonTest(test.TestCase): + def _make_poisson(self, rate, validate_args=False): + return poisson_lib.Poisson(rate=rate, validate_args=validate_args) + def testPoissonShape(self): with self.test_session(): lam = constant_op.constant([3.0] * 5) - poisson = poisson_lib.Poisson(rate=lam) + poisson = self._make_poisson(rate=lam) self.assertEqual(poisson.batch_shape_tensor().eval(), (5,)) self.assertEqual(poisson.batch_shape, tensor_shape.TensorShape([5])) @@ -40,11 +44,11 @@ class PoissonTest(test.TestCase): self.assertEqual(poisson.event_shape, tensor_shape.TensorShape([])) def testInvalidLam(self): - invalid_lams = [-.01, 0, -2.] + invalid_lams = [-.01, 0., -2.] for lam in invalid_lams: with self.test_session(): with self.assertRaisesOpError("Condition x > 0"): - poisson = poisson_lib.Poisson(rate=lam, validate_args=True) + poisson = self._make_poisson(rate=lam, validate_args=True) poisson.rate.eval() def testPoissonLogPmf(self): @@ -53,7 +57,7 @@ class PoissonTest(test.TestCase): lam = constant_op.constant([3.0] * batch_size) lam_v = 3.0 x = [2., 3., 4., 5., 6., 7.] - poisson = poisson_lib.Poisson(rate=lam) + poisson = self._make_poisson(rate=lam) log_pmf = poisson.log_prob(x) self.assertEqual(log_pmf.get_shape(), (6,)) self.assertAllClose(log_pmf.eval(), stats.poisson.logpmf(x, lam_v)) @@ -68,7 +72,7 @@ class PoissonTest(test.TestCase): lam = constant_op.constant([3.0] * batch_size) x = array_ops.placeholder(dtypes.float32, shape=[6]) feed_dict = {x: [2.5, 3.2, 4.3, 5.1, 6., 7.]} - poisson = poisson_lib.Poisson(rate=lam, validate_args=True) + poisson = self._make_poisson(rate=lam, validate_args=True) # Non-integer with self.assertRaisesOpError("cannot contain fractional components"): @@ -79,7 +83,7 @@ class PoissonTest(test.TestCase): log_pmf = poisson.log_prob([-1.]) log_pmf.eval(feed_dict=feed_dict) - poisson = poisson_lib.Poisson(rate=lam, validate_args=False) + poisson = self._make_poisson(rate=lam, validate_args=False) log_pmf = poisson.log_prob(x) self.assertEqual(log_pmf.get_shape(), (6,)) pmf = poisson.prob(x) @@ -92,7 +96,7 @@ class PoissonTest(test.TestCase): lam_v = [2.0, 4.0, 5.0] x = np.array([[2., 3., 4., 5., 6., 7.]], dtype=np.float32).T - poisson = poisson_lib.Poisson(rate=lam) + poisson = self._make_poisson(rate=lam) log_pmf = poisson.log_prob(x) self.assertEqual(log_pmf.get_shape(), (6, 3)) self.assertAllClose(log_pmf.eval(), stats.poisson.logpmf(x, lam_v)) @@ -108,7 +112,7 @@ class PoissonTest(test.TestCase): lam_v = 3.0 x = [2.2, 3.1, 4., 5.5, 6., 7.] - poisson = poisson_lib.Poisson(rate=lam) + poisson = self._make_poisson(rate=lam) log_cdf = poisson.log_cdf(x) self.assertEqual(log_cdf.get_shape(), (6,)) self.assertAllClose(log_cdf.eval(), stats.poisson.logcdf(x, lam_v)) @@ -124,7 +128,7 @@ class PoissonTest(test.TestCase): lam_v = [2.0, 4.0, 5.0] x = np.array([[2.2, 3.1, 4., 5.5, 6., 7.]], dtype=np.float32).T - poisson = poisson_lib.Poisson(rate=lam) + poisson = self._make_poisson(rate=lam) log_cdf = poisson.log_cdf(x) self.assertEqual(log_cdf.get_shape(), (6, 3)) self.assertAllClose(log_cdf.eval(), stats.poisson.logcdf(x, lam_v)) @@ -136,7 +140,7 @@ class PoissonTest(test.TestCase): def testPoissonMean(self): with self.test_session(): lam_v = [1.0, 3.0, 2.5] - poisson = poisson_lib.Poisson(rate=lam_v) + poisson = self._make_poisson(rate=lam_v) self.assertEqual(poisson.mean().get_shape(), (3,)) self.assertAllClose(poisson.mean().eval(), stats.poisson.mean(lam_v)) self.assertAllClose(poisson.mean().eval(), lam_v) @@ -144,7 +148,7 @@ class PoissonTest(test.TestCase): def testPoissonVariance(self): with self.test_session(): lam_v = [1.0, 3.0, 2.5] - poisson = poisson_lib.Poisson(rate=lam_v) + poisson = self._make_poisson(rate=lam_v) self.assertEqual(poisson.variance().get_shape(), (3,)) self.assertAllClose(poisson.variance().eval(), stats.poisson.var(lam_v)) self.assertAllClose(poisson.variance().eval(), lam_v) @@ -152,7 +156,7 @@ class PoissonTest(test.TestCase): def testPoissonStd(self): with self.test_session(): lam_v = [1.0, 3.0, 2.5] - poisson = poisson_lib.Poisson(rate=lam_v) + poisson = self._make_poisson(rate=lam_v) self.assertEqual(poisson.stddev().get_shape(), (3,)) self.assertAllClose(poisson.stddev().eval(), stats.poisson.std(lam_v)) self.assertAllClose(poisson.stddev().eval(), np.sqrt(lam_v)) @@ -160,14 +164,14 @@ class PoissonTest(test.TestCase): def testPoissonMode(self): with self.test_session(): lam_v = [1.0, 3.0, 2.5, 3.2, 1.1, 0.05] - poisson = poisson_lib.Poisson(rate=lam_v) + poisson = self._make_poisson(rate=lam_v) self.assertEqual(poisson.mode().get_shape(), (6,)) self.assertAllClose(poisson.mode().eval(), np.floor(lam_v)) def testPoissonMultipleMode(self): with self.test_session(): lam_v = [1.0, 3.0, 2.0, 4.0, 5.0, 10.0] - poisson = poisson_lib.Poisson(rate=lam_v) + poisson = self._make_poisson(rate=lam_v) # For the case where lam is an integer, the modes are: lam and lam - 1. # In this case, we get back the larger of the two modes. self.assertEqual((6,), poisson.mode().get_shape()) @@ -180,7 +184,7 @@ class PoissonTest(test.TestCase): # Choosing `n >= (k/rtol)**2, roughly ensures our sample mean should be # within `k` std. deviations of actual up to rtol precision. n = int(100e3) - poisson = poisson_lib.Poisson(rate=lam) + poisson = self._make_poisson(rate=lam) samples = poisson.sample(n, seed=123456) sample_values = samples.eval() self.assertEqual(samples.get_shape(), (n,)) @@ -193,7 +197,7 @@ class PoissonTest(test.TestCase): def testPoissonSampleMultidimensionalMean(self): with self.test_session(): lam_v = np.array([np.arange(1, 51, dtype=np.float32)]) # 1 x 50 - poisson = poisson_lib.Poisson(rate=lam_v) + poisson = self._make_poisson(rate=lam_v) # Choosing `n >= (k/rtol)**2, roughly ensures our sample mean should be # within `k` std. deviations of actual up to rtol precision. n = int(100e3) @@ -210,7 +214,7 @@ class PoissonTest(test.TestCase): def testPoissonSampleMultidimensionalVariance(self): with self.test_session(): lam_v = np.array([np.arange(5, 15, dtype=np.float32)]) # 1 x 10 - poisson = poisson_lib.Poisson(rate=lam_v) + poisson = self._make_poisson(rate=lam_v) # Choosing `n >= 2 * lam * (k/rtol)**2, roughly ensures our sample # variance should be within `k` std. deviations of actual up to rtol # precision. @@ -224,5 +228,18 @@ class PoissonTest(test.TestCase): sample_values.var(axis=0), stats.poisson.var(lam_v), rtol=.03, atol=0) +class PoissonLogRateTest(PoissonTest): + + def _make_poisson(self, rate, validate_args=False): + return poisson_lib.Poisson( + log_rate=math_ops.log(rate), + validate_args=validate_args) + + def testInvalidLam(self): + # No need to worry about the non-negativity of `rate` when using the + # `log_rate` parameterization. + pass + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py b/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py index 8c8363fe3f5159ed4def82472df8cb8ff518b05c..faae9da6ad812c629a2bdbb985fdd6f78a0860e1 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py @@ -164,6 +164,14 @@ class RelaxedOneHotCategoricalTest(test.TestCase): self.assertAllEqual([5, 3], dist.sample(5).eval(feed_dict=feed_dict).shape) + def testDTypes(self): + # check that sampling and log_prob work for a range of dtypes + with self.test_session(): + for dtype in (dtypes.float16, dtypes.float32, dtypes.float64): + logits = random_ops.random_uniform(shape=[3, 3], dtype=dtype) + dist = relaxed_onehot_categorical.RelaxedOneHotCategorical( + temperature=0.5, logits=logits) + dist.log_prob(dist.sample()) if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/sinh_arcsinh_test.py b/tensorflow/contrib/distributions/python/kernel_tests/sinh_arcsinh_test.py new file mode 100644 index 0000000000000000000000000000000000000000..88b48736dd55270fb4e149ae1560911179e446e9 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/sinh_arcsinh_test.py @@ -0,0 +1,221 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for SinhArcsinh.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib import distributions +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + +ds = distributions +rng = np.random.RandomState(123) + + +class SinhArcsinhTest(test.TestCase): + + def test_default_is_same_as_normal(self): + b = 10 + scale = rng.rand(b) + 0.5 + loc = rng.randn(b) + with self.test_session() as sess: + norm = ds.Normal( + loc=loc, + scale=scale, + validate_args=True) + sasnorm = ds.SinhArcsinh( + loc=loc, + scale=scale, + validate_args=True) + + x = rng.randn(5, b) + norm_pdf, sasnorm_pdf = sess.run([norm.prob(x), sasnorm.prob(x)]) + self.assertAllClose(norm_pdf, sasnorm_pdf) + + norm_samps, sasnorm_samps = sess.run( + [norm.sample(10000, seed=0), + sasnorm.sample(10000, seed=0)]) + self.assertAllClose(loc, sasnorm_samps.mean(axis=0), atol=0.1) + self.assertAllClose( + norm_samps.mean(axis=0), sasnorm_samps.mean(axis=0), atol=0.1) + self.assertAllClose( + norm_samps.std(axis=0), sasnorm_samps.std(axis=0), atol=0.1) + + def test_broadcast_params_dynamic(self): + with self.test_session() as sess: + loc = array_ops.placeholder(dtypes.float64) + scale = array_ops.placeholder(dtypes.float64) + skewness = array_ops.placeholder(dtypes.float64) + sasnorm = ds.SinhArcsinh( + loc=loc, + scale=scale, + skewness=skewness, + validate_args=True) + + samp = sess.run(sasnorm.sample(), + feed_dict={loc: rng.rand(5), + scale: np.float64(rng.rand()), # Scalar + skewness: rng.rand(5)}) + self.assertAllEqual((5,), samp.shape) + + def test_passing_in_laplace_plus_defaults_is_same_as_laplace(self): + b = 10 + scale = rng.rand(b) + 0.5 + loc = rng.randn(b) + with self.test_session() as sess: + lap = ds.Laplace( + loc=loc, + scale=scale, + validate_args=True) + saslap = ds.SinhArcsinh( + loc=loc, + scale=scale, + distribution=ds.Laplace(np.float64(0), np.float64(1)), + validate_args=True) + + x = rng.randn(5, b) + lap_pdf, saslap_pdf = sess.run([lap.prob(x), saslap.prob(x)]) + self.assertAllClose(lap_pdf, saslap_pdf) + + lap_samps, saslap_samps = sess.run( + [lap.sample(10000, seed=0), + saslap.sample(10000, seed=0)]) + self.assertAllClose(loc, saslap_samps.mean(axis=0), atol=0.1) + self.assertAllClose( + lap_samps.mean(axis=0), saslap_samps.mean(axis=0), atol=0.1) + self.assertAllClose( + lap_samps.std(axis=0), saslap_samps.std(axis=0), atol=0.1) + + def test_tailweight_small_gives_fewer_outliers_than_normal(self): + batch_size = 10 + scale = rng.rand(batch_size) + 0.5 + loc = 0.1 * rng.randn(batch_size) + with self.test_session() as sess: + norm = ds.Normal( + loc=loc, + scale=scale, + validate_args=True) + sasnorm = ds.SinhArcsinh( + loc=loc, + scale=scale, + tailweight=0.1, + validate_args=True) + + # sasnorm.pdf(x) is smaller on outliers (+-10 are outliers) + x = np.float64([[-10] * batch_size, [10] * batch_size]) # Shape [2, 10] + norm_lp, sasnorm_lp = sess.run([norm.log_prob(x), sasnorm.log_prob(x)]) + np.testing.assert_array_less(sasnorm_lp, norm_lp) + + # 0.1% quantile and 99.9% quantile are outliers, and should be more + # extreme in the normal. The 97.772% quantiles should be the same. + norm_samps, sasnorm_samps = sess.run( + [norm.sample(int(5e5), seed=1), + sasnorm.sample(int(5e5), seed=1)]) + np.testing.assert_array_less( + np.percentile(norm_samps, 0.1, axis=0), + np.percentile(sasnorm_samps, 0.1, axis=0)) + np.testing.assert_array_less( + np.percentile(sasnorm_samps, 99.9, axis=0), + np.percentile(norm_samps, 99.9, axis=0)) + # 100. * sp.stats.norm.cdf(2.) + q = 100 * 0.97724986805182079 + self.assertAllClose( + np.percentile(sasnorm_samps, q, axis=0), + np.percentile(norm_samps, q, axis=0), + rtol=0.03) + self.assertAllClose( + np.percentile(sasnorm_samps, 100 - q, axis=0), + np.percentile(norm_samps, 100 - q, axis=0), + rtol=0.03) + + def test_tailweight_large_gives_more_outliers_than_normal(self): + batch_size = 10 + scale = rng.rand(batch_size) + 0.5 + loc = np.float64(0.) + with self.test_session() as sess: + norm = ds.Normal( + loc=loc, + scale=scale, + validate_args=True) + sasnorm = ds.SinhArcsinh( + loc=loc, + scale=scale, + tailweight=3., + validate_args=True) + + # norm.pdf(x) is smaller on outliers (+-10 are outliers) + x = np.float64([[-10] * batch_size, [10] * batch_size]) # Shape [2, 10] + norm_lp, sasnorm_lp = sess.run([norm.log_prob(x), sasnorm.log_prob(x)]) + np.testing.assert_array_less(norm_lp, sasnorm_lp) + + # 0.1% quantile and 99.9% quantile are outliers, and should be more + # extreme in the sasnormal. The 97.772% quantiles should be the same. + norm_samps, sasnorm_samps = sess.run( + [norm.sample(int(5e5), seed=2), + sasnorm.sample(int(5e5), seed=2)]) + np.testing.assert_array_less( + np.percentile(sasnorm_samps, 0.1, axis=0), + np.percentile(norm_samps, 0.1, axis=0)) + np.testing.assert_array_less( + np.percentile(norm_samps, 99.9, axis=0), + np.percentile(sasnorm_samps, 99.9, axis=0)) + # 100. * sp.stats.norm.cdf(2.) + q = 100 * 0.97724986805182079 + self.assertAllClose( + np.percentile(sasnorm_samps, q, axis=0), + np.percentile(norm_samps, q, axis=0), + rtol=0.03) + self.assertAllClose( + np.percentile(sasnorm_samps, 100 - q, axis=0), + np.percentile(norm_samps, 100 - q, axis=0), + rtol=0.03) + + def test_positive_skewness_moves_mean_to_the_right(self): + batch_size = 10 + scale = rng.rand(batch_size) + 0.5 + loc = rng.randn(batch_size) + with self.test_session() as sess: + sasnorm = ds.SinhArcsinh( + loc=loc, + scale=scale, + skewness=3.0, + validate_args=True) + + sasnorm_samps = sess.run(sasnorm.sample(10000, seed=4)) + np.testing.assert_array_less(loc, sasnorm_samps.mean(axis=0)) + + def test_pdf_reflected_for_negative_skewness(self): + with self.test_session() as sess: + sas_pos_skew = ds.SinhArcsinh( + loc=0., + scale=1., + skewness=2., + validate_args=True) + sas_neg_skew = ds.SinhArcsinh( + loc=0., + scale=1., + skewness=-2., + validate_args=True) + x = np.linspace(-2, 2, num=5).astype(np.float32) + self.assertAllClose( + *sess.run([sas_pos_skew.prob(x), sas_neg_skew.prob(x[::-1])])) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py index 4e0deb83aa90a45d1ba6344c64064074f58e368f..103d8e186221e879d1734a097114708429f725bd 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py @@ -41,6 +41,11 @@ class TransformedDistributionTest(test.TestCase): def _cls(self): return ds.TransformedDistribution + def _make_unimplemented(self, name): + def _unimplemented(self, *args): # pylint: disable=unused-argument + raise NotImplementedError("{} not implemented".format(name)) + return _unimplemented + def testTransformedDistribution(self): g = ops.Graph() with g.as_default(): @@ -75,20 +80,105 @@ class TransformedDistributionTest(test.TestCase): with self.test_session(graph=g): self.assertAllClose(expected, actual.eval(), atol=0, rtol=0.01) - def testCachedSamplesWithoutInverse(self): + def testNonInjectiveTransformedDistribution(self): + g = ops.Graph() + with g.as_default(): + mu = 1. + sigma = 2.0 + abs_normal = self._cls()( + distribution=ds.Normal(loc=mu, scale=sigma), + bijector=bs.AbsoluteValue(event_ndims=0)) + sp_normal = stats.norm(mu, sigma) + + # sample + sample = abs_normal.sample(100000, seed=235) + self.assertAllEqual([], abs_normal.event_shape) + with self.test_session(graph=g): + sample_ = sample.eval() + self.assertAllEqual([], abs_normal.event_shape_tensor().eval()) + + # Abs > 0, duh! + np.testing.assert_array_less(0, sample_) + + # Let X ~ Normal(mu, sigma), Y := |X|, then + # P[Y < 0.77] = P[-0.77 < X < 0.77] + self.assertAllClose( + sp_normal.cdf(0.77) - sp_normal.cdf(-0.77), + (sample_ < 0.77).mean(), rtol=0.01) + + # p_Y(y) = p_X(-y) + p_X(y), + self.assertAllClose( + sp_normal.pdf(1.13) + sp_normal.pdf(-1.13), + abs_normal.prob(1.13).eval()) + + # Log[p_Y(y)] = Log[p_X(-y) + p_X(y)] + self.assertAllClose( + np.log(sp_normal.pdf(2.13) + sp_normal.pdf(-2.13)), + abs_normal.log_prob(2.13).eval()) + + def testQuantile(self): + with self.test_session() as sess: + logit_normal = self._cls()( + distribution=ds.Normal(loc=0., scale=1.), + bijector=bs.Sigmoid(), + validate_args=True) + grid = [0., 0.25, 0.5, 0.75, 1.] + q = logit_normal.quantile(grid) + cdf = logit_normal.cdf(q) + cdf_ = sess.run(cdf) + self.assertAllClose(grid, cdf_, rtol=1e-6, atol=0.) + + def testCachedSamples(self): + exp_forward_only = bs.Exp(event_ndims=0) + exp_forward_only._inverse = self._make_unimplemented( + "inverse") + exp_forward_only._inverse_event_shape_tensor = self._make_unimplemented( + "inverse_event_shape_tensor ") + exp_forward_only._inverse_event_shape = self._make_unimplemented( + "inverse_event_shape ") + exp_forward_only._inverse_log_det_jacobian = self._make_unimplemented( + "inverse_log_det_jacobian ") + with self.test_session() as sess: mu = 3.0 sigma = 0.02 log_normal = self._cls()( distribution=ds.Normal(loc=mu, scale=sigma), - bijector=bs.Exp(event_ndims=0)) + bijector=exp_forward_only) - sample = log_normal.sample(1) + sample = log_normal.sample([2, 3], seed=42) sample_val, log_pdf_val = sess.run([sample, log_normal.log_prob(sample)]) - self.assertAllClose( - stats.lognorm.logpdf(sample_val, s=sigma, scale=np.exp(mu)), - log_pdf_val, - atol=1e-2) + expected_log_pdf = stats.lognorm.logpdf( + sample_val, s=sigma, scale=np.exp(mu)) + self.assertAllClose(expected_log_pdf, log_pdf_val, rtol=1e-4, atol=0.) + + def testCachedSamplesInvert(self): + exp_inverse_only = bs.Exp(event_ndims=0) + exp_inverse_only._forward = self._make_unimplemented( + "forward") + exp_inverse_only._forward_event_shape_tensor = self._make_unimplemented( + "forward_event_shape_tensor ") + exp_inverse_only._forward_event_shape = self._make_unimplemented( + "forward_event_shape ") + exp_inverse_only._forward_log_det_jacobian = self._make_unimplemented( + "forward_log_det_jacobian ") + + log_forward_only = bs.Invert(exp_inverse_only) + + with self.test_session() as sess: + # The log bijector isn't defined over the whole real line, so we make + # sigma sufficiently small so that the draws are positive. + mu = 2. + sigma = 1e-2 + exp_normal = self._cls()( + distribution=ds.Normal(loc=mu, scale=sigma), + bijector=log_forward_only) + + sample = exp_normal.sample([2, 3], seed=42) + sample_val, log_pdf_val = sess.run([sample, exp_normal.log_prob(sample)]) + expected_log_pdf = sample_val + stats.norm.logpdf( + np.exp(sample_val), loc=mu, scale=sigma) + self.assertAllClose(expected_log_pdf, log_pdf_val, atol=0.) def testShapeChangingBijector(self): with self.test_session(): @@ -130,6 +220,19 @@ class TransformedDistributionTest(test.TestCase): self.assertAllClose(actual_mvn_entropy, fake_mvn.entropy().eval()) + def testScalarBatchScalarEventIdentityScale(self): + with self.test_session() as sess: + exp2 = self._cls()( + ds.Exponential(rate=0.25), + bijector=ds.bijectors.Affine( + scale_identity_multiplier=2., + event_ndims=0)) + log_prob = exp2.log_prob(1.) + log_prob_ = sess.run(log_prob) + base_log_prob = -0.5 * 0.25 + np.log(0.25) + ildj = np.log(2.) + self.assertAllClose(base_log_prob - ildj, log_prob_, rtol=1e-6, atol=0.) + class ScalarToMultiTest(test.TestCase): diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py index 070ee61be314905239e11e8ed3b39f6ffa7510a7..de4a221f7badca8267a81d612a57137c676ff052 100644 --- a/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_diffeomixture_test.py @@ -22,9 +22,11 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import test_util from tensorflow.contrib.distributions.python.ops import vector_diffeomixture as vector_diffeomixture_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_diag as linop_diag_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_identity as linop_identity_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops from tensorflow.python.ops.distributions import normal as normal_lib +from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib +from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib from tensorflow.python.platform import test @@ -55,10 +57,10 @@ class VectorDiffeomixtureTest( validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=2., center=0., rtol=0.005) + sess.run, vdm, radius=2., center=0., rtol=0.005) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=4., center=2., rtol=0.005) + sess.run, vdm, radius=4., center=2., rtol=0.005) def testSampleProbConsistentBroadcastMixNonStandardBase(self): with self.test_session() as sess: @@ -83,10 +85,10 @@ class VectorDiffeomixtureTest( validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=2., center=1., rtol=0.006) + sess.run, vdm, radius=2., center=1., rtol=0.006) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=4., center=3., rtol=0.009) + sess.run, vdm, radius=4., center=3., rtol=0.009) def testSampleProbConsistentBroadcastMixBatch(self): with self.test_session() as sess: @@ -114,10 +116,10 @@ class VectorDiffeomixtureTest( validate_args=True) # Ball centered at component0's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=2., center=0., rtol=0.005) + sess.run, vdm, radius=2., center=0., rtol=0.005) # Larger ball centered at component1's mean. self.run_test_sample_consistent_log_prob( - sess, vdm, radius=4., center=2., rtol=0.005) + sess.run, vdm, radius=4., center=2., rtol=0.005) def testMeanCovarianceNoBatch(self): with self.test_session() as sess: @@ -141,7 +143,7 @@ class VectorDiffeomixtureTest( ], validate_args=True) self.run_test_sample_consistent_mean_covariance( - sess, vdm, rtol=0.02, cov_rtol=0.06) + sess.run, vdm, rtol=0.02, cov_rtol=0.06) def testMeanCovarianceNoBatchUncenteredNonStandardBase(self): with self.test_session() as sess: @@ -165,7 +167,7 @@ class VectorDiffeomixtureTest( ], validate_args=True) self.run_test_sample_consistent_mean_covariance( - sess, vdm, num_samples=int(1e6), rtol=0.01, cov_atol=0.025) + sess.run, vdm, num_samples=int(1e6), rtol=0.01, cov_atol=0.025) def testMeanCovarianceBatch(self): with self.test_session() as sess: @@ -192,7 +194,40 @@ class VectorDiffeomixtureTest( ], validate_args=True) self.run_test_sample_consistent_mean_covariance( - sess, vdm, rtol=0.02, cov_rtol=0.06) + sess.run, vdm, rtol=0.02, cov_rtol=0.06) + + def testSampleProbConsistentDynamicQuadrature(self): + with self.test_session() as sess: + qgrid = array_ops.placeholder(dtype=dtypes.float32) + qprobs = array_ops.placeholder(dtype=dtypes.float32) + g, p = np.polynomial.hermite.hermgauss(deg=8) + dims = 4 + vdm = vector_diffeomixture_lib.VectorDiffeomixture( + mix_loc=[[0.], [1.]], + mix_scale=[1.], + distribution=normal_lib.Normal(0., 1.), + loc=[ + None, + np.float32([2.]*dims), + ], + scale=[ + linop_identity_lib.LinearOperatorScaledIdentity( + num_rows=dims, + multiplier=np.float32(1.1), + is_positive_definite=True), + linop_diag_lib.LinearOperatorDiag( + diag=np.linspace(2.5, 3.5, dims, dtype=np.float32), + is_positive_definite=True), + ], + quadrature_grid_and_probs=(g, p), + validate_args=True) + # Ball centered at component0's mean. + sess_run_fn = lambda x: sess.run(x, feed_dict={qgrid: g, qprobs: p}) + self.run_test_sample_consistent_log_prob( + sess_run_fn, vdm, radius=2., center=0., rtol=0.005) + # Larger ball centered at component1's mean. + self.run_test_sample_consistent_log_prob( + sess_run_fn, vdm, radius=4., center=2., rtol=0.005) # TODO(jvdillon): We've tested that (i) .sample and .log_prob are consistent, # (ii) .mean, .stddev etc... and .sample are consistent. However, we haven't diff --git a/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py b/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2bc6a926dd66fd2b5796576c723345ca2014aad6 --- /dev/null +++ b/tensorflow/contrib/distributions/python/kernel_tests/vector_sinh_arcsinh_diag_test.py @@ -0,0 +1,272 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for VectorSinhArcsinhDiag.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib import distributions +from tensorflow.contrib.distributions.python.ops import test_util +from tensorflow.python.platform import test + +ds = distributions +rng = np.random.RandomState(123) + + +class VectorSinhArcsinhDiagTest(test_util.VectorDistributionTestHelpers, + test.TestCase): + + def test_default_is_same_as_normal(self): + d = 10 + scale_diag = rng.rand(d) + scale_identity_multiplier = np.float64(1.0) + loc = rng.randn(d) + with self.test_session() as sess: + norm = ds.MultivariateNormalDiag( + loc=loc, + scale_diag=scale_diag, + scale_identity_multiplier=scale_identity_multiplier, + validate_args=True) + sasnorm = ds.VectorSinhArcsinhDiag( + loc=loc, + scale_diag=scale_diag, + scale_identity_multiplier=scale_identity_multiplier, + validate_args=True) + + x = rng.randn(5, d) + norm_pdf, sasnorm_pdf = sess.run([norm.prob(x), sasnorm.prob(x)]) + self.assertAllClose(norm_pdf, sasnorm_pdf) + + norm_samps, sasnorm_samps = sess.run( + [norm.sample(10000, seed=0), + sasnorm.sample(10000, seed=0)]) + self.assertAllClose(loc, sasnorm_samps.mean(axis=0), atol=0.1) + self.assertAllClose( + norm_samps.mean(axis=0), sasnorm_samps.mean(axis=0), atol=0.1) + self.assertAllClose( + norm_samps.std(axis=0), sasnorm_samps.std(axis=0), atol=0.1) + + def test_passing_in_laplace_plus_defaults_is_same_as_laplace(self): + d = 10 + scale_diag = rng.rand(d) + scale_identity_multiplier = np.float64(1.2) + loc = rng.randn(d) + with self.test_session() as sess: + vlap = ds.VectorLaplaceDiag( + loc=loc, + scale_diag=scale_diag, + scale_identity_multiplier=scale_identity_multiplier, + validate_args=True) + sasvlap = ds.VectorSinhArcsinhDiag( + loc=loc, + scale_diag=scale_diag, + scale_identity_multiplier=scale_identity_multiplier, + distribution=ds.Laplace(np.float64(0.), np.float64(1.)), + validate_args=True) + + x = rng.randn(5, d) + vlap_pdf, sasvlap_pdf = sess.run([vlap.prob(x), sasvlap.prob(x)]) + self.assertAllClose(vlap_pdf, sasvlap_pdf) + + vlap_samps, sasvlap_samps = sess.run( + [vlap.sample(10000, seed=0), + sasvlap.sample(10000, seed=0)]) + self.assertAllClose(loc, sasvlap_samps.mean(axis=0), atol=0.1) + self.assertAllClose( + vlap_samps.mean(axis=0), sasvlap_samps.mean(axis=0), atol=0.1) + self.assertAllClose( + vlap_samps.std(axis=0), sasvlap_samps.std(axis=0), atol=0.1) + + def test_tailweight_small_gives_fewer_outliers_than_normal(self): + d = 10 + scale_diag = rng.rand(d) + scale_identity_multiplier = np.float64(0.9) + loc = rng.randn(d) + with self.test_session() as sess: + norm = ds.MultivariateNormalDiag( + loc=loc, + scale_diag=scale_diag, + scale_identity_multiplier=scale_identity_multiplier, + validate_args=True) + sasnorm = ds.VectorSinhArcsinhDiag( + loc=loc, + scale_diag=scale_diag, + scale_identity_multiplier=scale_identity_multiplier, + tailweight=0.1, + validate_args=True) + + # sasnorm.pdf(x) is smaller on outliers (+-10 are outliers) + x = np.float64([[-10] * d, [10] * d]) # Shape [2, 10] + norm_lp, sasnorm_lp = sess.run([norm.log_prob(x), sasnorm.log_prob(x)]) + np.testing.assert_array_less(sasnorm_lp, norm_lp) + + # 0.1% quantile and 99.9% quantile are outliers, and should be more + # extreme in the normal. The 97.772% quantiles should be the same. + norm_samps, sasnorm_samps = sess.run( + [norm.sample(int(5e5), seed=1), + sasnorm.sample(int(5e5), seed=1)]) + np.testing.assert_array_less( + np.percentile(norm_samps, 0.1, axis=0), + np.percentile(sasnorm_samps, 0.1, axis=0)) + np.testing.assert_array_less( + np.percentile(sasnorm_samps, 99.9, axis=0), + np.percentile(norm_samps, 99.9, axis=0)) + # 100. * sp.stats.norm.cdf(2.) + q = 100 * 0.97724986805182079 + self.assertAllClose( + np.percentile(sasnorm_samps, q, axis=0), + np.percentile(norm_samps, q, axis=0), + rtol=0.03) + self.assertAllClose( + np.percentile(sasnorm_samps, 100 - q, axis=0), + np.percentile(norm_samps, 100 - q, axis=0), + rtol=0.03) + + def test_tailweight_large_gives_more_outliers_than_normal(self): + d = 10 + scale_diag = rng.rand(d) + scale_identity_multiplier = np.float64(1.0) + loc = rng.randn(d) + with self.test_session() as sess: + norm = ds.MultivariateNormalDiag( + loc=loc, + scale_diag=scale_diag, + scale_identity_multiplier=scale_identity_multiplier, + validate_args=True) + sasnorm = ds.VectorSinhArcsinhDiag( + loc=loc, + scale_diag=scale_diag, + scale_identity_multiplier=scale_identity_multiplier, + tailweight=3., + validate_args=True) + + # norm.pdf(x) is smaller on outliers (+-10 are outliers) + x = np.float64([[-10] * d, [10] * d]) # Shape [2, 10] + norm_lp, sasnorm_lp = sess.run([norm.log_prob(x), sasnorm.log_prob(x)]) + np.testing.assert_array_less(norm_lp, sasnorm_lp) + + # 0.1% quantile and 99.9% quantile are outliers, and should be more + # extreme in the sasnormal. The 97.772% quantiles should be the same. + norm_samps, sasnorm_samps = sess.run( + [norm.sample(int(5e5), seed=2), + sasnorm.sample(int(5e5), seed=2)]) + np.testing.assert_array_less( + np.percentile(sasnorm_samps, 0.1, axis=0), + np.percentile(norm_samps, 0.1, axis=0)) + np.testing.assert_array_less( + np.percentile(norm_samps, 99.9, axis=0), + np.percentile(sasnorm_samps, 99.9, axis=0)) + # 100. * sp.stats.norm.cdf(2.) + q = 100 * 0.97724986805182079 + self.assertAllClose( + np.percentile(sasnorm_samps, q, axis=0), + np.percentile(norm_samps, q, axis=0), + rtol=0.03) + self.assertAllClose( + np.percentile(sasnorm_samps, 100 - q, axis=0), + np.percentile(norm_samps, 100 - q, axis=0), + rtol=0.03) + + def test_positive_skewness_moves_mean_to_the_right(self): + d = 10 + scale_diag = rng.rand(d) + scale_identity_multiplier = np.float64(1.0) + loc = rng.randn(d) + with self.test_session() as sess: + sasnorm = ds.VectorSinhArcsinhDiag( + loc=loc, + scale_diag=scale_diag, + scale_identity_multiplier=scale_identity_multiplier, + skewness=3.0, + validate_args=True) + + sasnorm_samps = sess.run(sasnorm.sample(10000, seed=4)) + np.testing.assert_array_less(loc, sasnorm_samps.mean(axis=0)) + + def test_consistency_random_parameters_with_batch_dim(self): + b, d = 5, 2 + scale_diag = rng.rand(b, d) + scale_identity_multiplier = np.float64(1.1) + with self.test_session() as sess: + sasnorm = ds.VectorSinhArcsinhDiag( + scale_diag=scale_diag, + scale_identity_multiplier=scale_identity_multiplier, + skewness=rng.randn(d) * 0.5, + tailweight=rng.rand(b, d) + 0.7, + validate_args=True) + + self.run_test_sample_consistent_log_prob( + sess.run, sasnorm, radius=1.0, center=0., rtol=0.1) + self.run_test_sample_consistent_log_prob( + sess.run, + sasnorm, + radius=1.0, + center=-0.15, + rtol=0.1) + self.run_test_sample_consistent_log_prob( + sess.run, + sasnorm, + radius=1.0, + center=0.15, + rtol=0.1) + + def test_consistency_random_parameters_no_batch_dims(self): + d = 3 + scale_diag = rng.rand(d) + scale_identity_multiplier = np.float64(1.1) + with self.test_session() as sess: + sasnorm = ds.VectorSinhArcsinhDiag( + scale_diag=scale_diag, + scale_identity_multiplier=scale_identity_multiplier, + skewness=rng.randn(d) * 0.5, + tailweight=rng.rand(d) + 0.7, + validate_args=True) + + self.run_test_sample_consistent_log_prob( + sess.run, sasnorm, radius=1.0, center=0., rtol=0.1) + self.run_test_sample_consistent_log_prob( + sess.run, + sasnorm, + radius=1.0, + center=-0.15, + rtol=0.1) + self.run_test_sample_consistent_log_prob( + sess.run, + sasnorm, + radius=1.0, + center=0.15, + rtol=0.1) + + def test_pdf_reflected_for_negative_skewness(self): + with self.test_session() as sess: + sas_pos_skew = ds.VectorSinhArcsinhDiag( + loc=[0.], + scale_identity_multiplier=1., + skewness=2., + validate_args=True) + sas_neg_skew = ds.VectorSinhArcsinhDiag( + loc=[0.], + scale_identity_multiplier=1., + skewness=-2., + validate_args=True) + x = np.linspace(-2, 2, num=5).astype(np.float32).reshape(5, 1) + self.assertAllClose( + *sess.run([sas_pos_skew.prob(x), sas_neg_skew.prob(x[::-1])])) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py index 5196954aea2b48964b9a89ef217d74c7b6dd88df..bc0ec7f195af009c87020ce8c4ea18f2e713759a 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/__init__.py @@ -14,6 +14,7 @@ # ============================================================================== """Bijector Ops. +@@AbsoluteValue @@Affine @@AffineLinearOperator @@Bijector @@ -21,16 +22,23 @@ @@CholeskyOuterProduct @@ConditionalBijector @@Exp +@@Gumbel @@Identity @@Inline @@Invert +@@MaskedAutoregressiveFlow +@@Permute @@PowerTransform +@@Reshape @@Sigmoid @@SigmoidCentered @@SinhArcsinh @@SoftmaxCentered @@Softplus @@Weibull + +@@masked_autoregressive_default_template +@@masked_dense """ from __future__ import absolute_import @@ -39,15 +47,20 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member +from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value import * from tensorflow.contrib.distributions.python.ops.bijectors.affine import * from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import * from tensorflow.contrib.distributions.python.ops.bijectors.chain import * from tensorflow.contrib.distributions.python.ops.bijectors.cholesky_outer_product import * from tensorflow.contrib.distributions.python.ops.bijectors.conditional_bijector import * from tensorflow.contrib.distributions.python.ops.bijectors.exp import * +from tensorflow.contrib.distributions.python.ops.bijectors.gumbel import * from tensorflow.contrib.distributions.python.ops.bijectors.inline import * from tensorflow.contrib.distributions.python.ops.bijectors.invert import * +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive import * +from tensorflow.contrib.distributions.python.ops.bijectors.permute import * from tensorflow.contrib.distributions.python.ops.bijectors.power_transform import * +from tensorflow.contrib.distributions.python.ops.bijectors.reshape import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid import * from tensorflow.contrib.distributions.python.ops.bijectors.sigmoid_centered import * from tensorflow.contrib.distributions.python.ops.bijectors.sinh_arcsinh import * diff --git a/tensorflow/contrib/bayesflow/python/ops/variational_inference.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py similarity index 74% rename from tensorflow/contrib/bayesflow/python/ops/variational_inference.py rename to tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py index 6316361da2accf39dfe2e77902eec06813ca7036..6049419818e18c54209f0be95d41fcecf6627b7e 100644 --- a/tensorflow/contrib/bayesflow/python/ops/variational_inference.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value.py @@ -1,4 +1,4 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Variational inference. - -See the ${@python/contrib.bayesflow.variational_inference} guide. -""" +"""AbsoluteValue bijector.""" from __future__ import absolute_import from __future__ import division @@ -23,12 +20,10 @@ from __future__ import print_function # go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.variational_inference_impl import * +from tensorflow.contrib.distributions.python.ops.bijectors.absolute_value_impl import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented -_allowed_symbols = [ - "elbo", "elbo_with_log_joint", "ELBOForms", "register_prior" -] +_allowed_symbols = ["AbsoluteValue"] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..b84502003ab6c0c4ffdda21eea162f441509e1fa --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py @@ -0,0 +1,132 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""AbsoluteValue bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + +__all__ = [ + "AbsoluteValue", +] + + +class AbsoluteValue(bijector.Bijector): + """Computes `Y = g(X) = Abs(X)`, element-wise. + + This non-injective bijector allows for transformations of scalar distributions + with the absolute value function, which maps `(-inf, inf)` to `[0, inf)`. + + * For `y in (0, inf)`, `AbsoluteValue.inverse(y)` returns the set inverse + `{x in (-inf, inf) : |x| = y}` as a tuple, `-y, y`. + * `AbsoluteValue.inverse(0)` returns `0, 0`, which is not the set inverse + (the set inverse is the singleton `{0}`), but "works" in conjunction with + `TransformedDistribution` to produce a left semi-continuous pdf. + * For `y < 0`, `AbsoluteValue.inverse(y)` happily returns the + wrong thing, `-y, y`. This is done for efficiency. If + `validate_args == True`, `y < 0` will raise an exception. + + + ```python + abs = ds.bijectors.AbsoluteValue() + + abs.forward([-1., 0., 1.]) + ==> [1., 0., 1.] + + abs.inverse(1.) + ==> [-1., 1.] + + # The |dX/dY| is constant, == 1. So Log|dX/dY| == 0. + abs.inverse_log_det_jacobian(1.) + ==> [0., 0.] + + # Special case handling of 0. + abs.inverse(0.) + ==> [0., 0.] + + abs.inverse_log_det_jacobian(0.) + ==> [0., 0.] + ``` + + """ + + def __init__(self, event_ndims=0, validate_args=False, name="absolute_value"): + """Instantiates the `AbsoluteValue` bijector. + + Args: + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. Currently only zero is + supported. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness, in particular whether inputs to `inverse` and + `inverse_log_det_jacobian` are non-negative. + name: Python `str` name given to ops managed by this object. + + Raises: + ValueError: If `event_ndims` is not zero. + """ + self._graph_parents = [] + self._name = name + + event_ndims = ops.convert_to_tensor(event_ndims, name="event_ndims") + event_ndims_const = tensor_util.constant_value(event_ndims) + if event_ndims_const is not None and event_ndims_const not in (0,): + raise ValueError("event_ndims(%s) was not 0" % event_ndims_const) + else: + if validate_args: + event_ndims = control_flow_ops.with_dependencies( + [check_ops.assert_equal( + event_ndims, 0, message="event_ndims was not 0")], + event_ndims) + + with self._name_scope("init"): + super(AbsoluteValue, self).__init__( + event_ndims=event_ndims, + validate_args=validate_args, + name=name) + + def _forward(self, x): + return math_ops.abs(x) + + def _inverse(self, y): + if self.validate_args: + y = control_flow_ops.with_dependencies( + [check_ops.assert_non_negative(y, message="Argument y was negative")], + y) + return -y, y + + def _inverse_log_det_jacobian(self, y): + # If event_ndims = 2, + # F^{-1}(y) = (-y, y), so DF^{-1}(y) = (-1, 1), + # so Log|DF^{-1}(y)| = Log[1, 1] = [0, 0]. + batch_shape = array_ops.shape(y)[:array_ops.rank(y) - self.event_ndims] + zeros = array_ops.zeros(batch_shape, dtype=y.dtype) + if self.validate_args: + zeros = control_flow_ops.with_dependencies( + [check_ops.assert_non_negative(y, message="Argument y was negative")], + zeros) + return zeros, zeros + + @property + def _is_injective(self): + return False diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py index d8698788c141328e72651e958d9e6368d33f6aaf..05bb9c2f9bdf35e222c94db3491157893da64ebd 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_impl.py @@ -199,6 +199,11 @@ class Affine(bijector.Bijector): event_ndims, 2, message="event_ndims must be 0 or 1")], event_ndims) + if event_ndims_const == 0 and not self._is_only_identity_multiplier: + raise ValueError( + "If event_ndims == 0, the only scale argument you can pass is " + "scale_identity_multiplier. All others operate on vectors.") + # In the absence of `loc` and `scale`, we'll assume `dtype` is `float32`. dtype = dtypes.float32 @@ -321,7 +326,7 @@ class Affine(bijector.Bijector): shape_hint=shape_hint) if perturb_factor is not None: - return linalg.LinearOperatorUDVHUpdate( + return linalg.LinearOperatorLowRankUpdate( scale, u=perturb_factor, diag_update=perturb_diag, @@ -383,10 +388,11 @@ class Affine(bijector.Bijector): if self._is_only_identity_multiplier: # We don't pad in this case and instead let the fldj be applied # via broadcast. - d = math_ops.cast(array_ops.shape(x)[-1], dtype=self._scale.dtype) - one = ops.convert_to_tensor(1., self._scale.dtype) - return math_ops.log(math_ops.abs(self._scale)) * array_ops.where( - math_ops.equal(self._shaper.event_ndims, 0), one, d) + event_size = distribution_util.pick_vector( + math_ops.equal(self._shaper.event_ndims, 0), + [1], array_ops.shape(x))[-1] + event_size = math_ops.cast(event_size, dtype=self._scale.dtype) + return math_ops.log(math_ops.abs(self._scale)) * event_size return self.scale.log_abs_determinant() def _maybe_check_scale(self): diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py index ae380b5cb2bc39e06aa1e187c134d7e92f6cd92f..89043b1410370074f11f2cfa59b6b6663fa62521 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/affine_linear_operator_impl.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops.shape import _DistributionShape -from tensorflow.contrib.linalg.python.ops import linear_operator from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -27,6 +26,7 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops.distributions import bijector +from tensorflow.python.ops.linalg import linear_operator __all__ = [ @@ -66,7 +66,7 @@ class AffineLinearOperator(bijector.Bijector): Example Use: ```python - linalg = tf.contrib.linalg + linalg = tf.linalg x = [1., 2, 3] @@ -82,7 +82,7 @@ class AffineLinearOperator(bijector.Bijector): tril = [[1., 0, 0], [2, 1, 0], [3, 2, 1]] - scale = linalg.LinearOperatorTriL(tril) + scale = linalg.LinearOperatorLowerTriangular(tril) affine = AffineLinearOperator(shift, scale) # In this case, `forward` is equivalent to: # np.squeeze(np.matmul(tril, np.expand_dims(x, -1)), -1) + shift diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py index defa36a14048d35c6264c7227840ed70dcc77cbb..3ce7c26213034c7345a20faa803c94a1bfa8d579 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/chain_impl.py @@ -81,6 +81,13 @@ class Chain(bijector.Bijector): if bijectors is None: bijectors = () self._bijectors = bijectors + + for a_bijector in bijectors: + if not a_bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError( + "Invert is not implemented for non-injective bijector ({})".format( + a_bijector.name)) + dtype = list(set([b.dtype for b in bijectors])) if len(dtype) > 2: raise ValueError("incompatible dtypes: %s" % dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py index dc05b2f611a52dc29717c69df77a1576aa6b5693..cbd60f92a60612c6cf791b2c7708a3310c6e2b6b 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/cholesky_outer_product_impl.py @@ -43,6 +43,24 @@ class CholeskyOuterProduct(bijector.Bijector): Note: the upper-triangular part of X is ignored (whether or not its zero). + The surjectivity of g as a map from the set of n x n positive-diagonal + lower-triangular matrices to the set of SPD matrices follows immediately from + executing the Cholesky factorization algorithm on an SPD matrix A to produce a + positive-diagonal lower-triangular matrix L such that `A = L @ L.T`. + + To prove the injectivity of g, suppose that L_1 and L_2 are lower-triangular + with positive diagonals and satisfy `A = L_1 @ L_1.T = L_2 @ L_2.T`. Then + `inv(L_1) @ A @ inv(L_1).T = [inv(L_1) @ L_2] @ [inv(L_1) @ L_2].T = I`. + Setting `L_3 := inv(L_1) @ L_2`, that L_3 is a positive-diagonal + lower-triangular matrix follows from `inv(L_1)` being positive-diagonal + lower-triangular (which follows from the diagonal of a triangular matrix being + its spectrum), and that the product of two positive-diagonal lower-triangular + matrices is another positive-diagonal lower-triangular matrix. + + A simple inductive argument (proceding one column of L_3 at a time) shows + that, if `I = L_3 @ L_3.T`, with L_3 being lower-triangular with positive- + diagonal, then `L_3 = I`. Thus, `L_1 = L_2`, proving injectivity of g. + Examples: ```python diff --git a/tensorflow/contrib/boosted_trees/python/ops/ensemble_optimizer_ops.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py similarity index 76% rename from tensorflow/contrib/boosted_trees/python/ops/ensemble_optimizer_ops.py rename to tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py index f7c2e4fe5a8c52de3a4430f321f713128a6027a6..cf37aa51115ed98ab263bc03bcb297a03432a7ae 100644 --- a/tensorflow/contrib/boosted_trees/python/ops/ensemble_optimizer_ops.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Split handler custom ops.""" +"""Gumbel bijector.""" + from __future__ import absolute_import from __future__ import division from __future__ import print_function -# pylint: disable=unused-import -from tensorflow.contrib.boosted_trees.python.ops import boosted_trees_ops_loader -# pylint: enable=unused-import +# go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.contrib.boosted_trees.python.ops.gen_ensemble_optimizer_ops import * +from tensorflow.contrib.distributions.python.ops.bijectors.gumbel_impl import * # pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ["Gumbel"] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..67f39785563255be0fe154aca3cbcf01c6a01e73 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/gumbel_impl.py @@ -0,0 +1,124 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Gumbel bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector + +__all__ = [ + "Gumbel", +] + + +class Gumbel(bijector.Bijector): + """Compute `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + + This bijector maps inputs from `[-inf, inf]` to [0, 1]`. The inverse of the + bijector applied to a uniform random variable `X ~ U(0, 1) gives back a + random variable with the + [Gumbel distribution](https://en.wikipedia.org/wiki/Gumbel_distribution): + + ```none + Y ~ Gumbel(loc, scale) + pdf(y; loc, scale) = exp( + -( (y - loc) / scale + exp(- (y - loc) / scale) ) ) / scale + ``` + """ + + def __init__(self, + loc=0., + scale=1., + event_ndims=0, + validate_args=False, + name="gumbel"): + """Instantiates the `Gumbel` bijector. + + Args: + loc: Float-like `Tensor` that is the same dtype and is + broadcastable with `scale`. + This is `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + scale: Positive Float-like `Tensor` that is the same dtype and is + broadcastable with `loc`. + This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`. + event_ndims: Python scalar indicating the number of dimensions associated + with a particular draw from the distribution. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str` name given to ops managed by this object. + """ + self._graph_parents = [] + self._name = name + self._validate_args = validate_args + with self._name_scope("init", values=[loc, scale]): + self._loc = ops.convert_to_tensor(loc, name="loc") + self._scale = ops.convert_to_tensor(scale, name="scale") + check_ops.assert_same_float_dtype([self._loc, self._scale]) + if validate_args: + self._scale = control_flow_ops.with_dependencies([ + check_ops.assert_positive( + self._scale, message="Argument scale was not positive") + ], self._scale) + + super(Gumbel, self).__init__( + event_ndims=event_ndims, validate_args=validate_args, name=name) + + @property + def loc(self): + """The `loc` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" + return self._loc + + @property + def scale(self): + """This is `scale` in `Y = g(X) = exp(-exp(-(X - loc) / scale))`.""" + return self._scale + + def _forward(self, x): + z = (x - self.loc) / self.scale + return math_ops.exp(-math_ops.exp(-z)) + + def _inverse(self, y): + y = self._maybe_assert_valid_y(y) + return self.loc - self.scale * math_ops.log(-math_ops.log(y)) + + def _inverse_log_det_jacobian(self, y): + y = self._maybe_assert_valid_y(y) + event_dims = self._event_dims_tensor(y) + return math_ops.reduce_sum( + math_ops.log(self.scale / (-math_ops.log(y) * y)), axis=event_dims) + + def _forward_log_det_jacobian(self, x): + event_dims = self._event_dims_tensor(x) + z = (x - self.loc) / self.scale + return math_ops.reduce_sum( + -z - math_ops.exp(-z) - math_ops.log(self.scale), axis=event_dims) + + def _maybe_assert_valid_y(self, y): + if not self.validate_args: + return y + is_positive = check_ops.assert_non_negative( + y, message="Inverse transformation input must be greater than 0.") + less_than_one = check_ops.assert_less_equal( + y, + constant_op.constant(1., y.dtype), + message="Inverse transformation input must be less than or equal to 1.") + return control_flow_ops.with_dependencies([is_positive, less_than_one], y) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py index 1d0719e6a4574864ba64019b122562819606435c..2c603fe61f36dd27f4984fe6c13c11f2fb534321 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/invert_impl.py @@ -60,6 +60,10 @@ class Invert(bijector_lib.Bijector): name: Python `str`, name given to ops managed by this object. """ + if not bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError( + "Invert is not implemented for non-injective bijectors.") + self._bijector = bijector super(Invert, self).__init__( event_ndims=bijector.event_ndims, diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py similarity index 76% rename from tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py rename to tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py index b8e38b6f9bf86aef42627cf127a93ce2edd42451..132dc570f94719b6c71fb269866c943774481b7e 100644 --- a/tensorflow/contrib/bayesflow/python/ops/stochastic_graph.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive.py @@ -1,4 +1,4 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,12 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Support for Stochastic Computation Graphs. - -See the @{$python/contrib.bayesflow.stochastic_graph} guide. - -@@surrogate_loss -""" +"""MaskedAutoregressiveFlow bijector.""" from __future__ import absolute_import from __future__ import division @@ -25,13 +20,14 @@ from __future__ import print_function # go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.stochastic_graph_impl import * +from tensorflow.contrib.distributions.python.ops.bijectors.masked_autoregressive_impl import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented - _allowed_symbols = [ - "surrogate_loss" + "MaskedAutoregressiveFlow", + "masked_dense", + "masked_autoregressive_default_template", ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..ae142883931274b594dbbafbe86bd71e75c621bc --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py @@ -0,0 +1,473 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""MaskedAutoregressiveFlow bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.layers import core as layers +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import template as template_ops +from tensorflow.python.ops import variable_scope as variable_scope_lib +from tensorflow.python.ops.distributions import bijector as bijector_lib + + +__all__ = [ + "MaskedAutoregressiveFlow", + "masked_autoregressive_default_template", + "masked_dense", +] + + +class MaskedAutoregressiveFlow(bijector_lib.Bijector): + """Affine MaskedAutoregressiveFlow bijector for vector-valued events. + + The affine autoregressive flow [1] provides a relatively simple framework for + user-specified (deep) architectures to learn a distribution over vector-valued + events. Regarding terminology, + + "Autoregressive models decompose the joint density as a product of + conditionals, and model each conditional in turn. Normalizing flows + transform a base density (e.g. a standard Gaussian) into the target density + by an invertible transformation with tractable Jacobian." [1] + + In other words, the "autoregressive property" is equivalent to the + decomposition, `p(x) = prod{ p(x[i] | x[0:i]) : i=0, ..., d }`. The provided + `shift_and_log_scale_fn`, `masked_autoregressive_default_template`, achieves + this property by zeroing out weights in its `masked_dense` layers. + + In the `tf.distributions` framework, a "normalizing flow" is implemented as a + `tf.distributions.bijectors.Bijector`. The `forward` "autoregression" + is implemented using a `tf.while_loop` and a deep neural network (DNN) with + masked weights such that the autoregressive property is automatically met in + the `inverse`. + + A `TransformedDistribution` using `MaskedAutoregressiveFlow(...)` uses the + (expensive) forward-mode calculation to draw samples and the (cheap) + reverse-mode calculation to compute log-probabilities. Conversely, a + `TransformedDistribution` using `Invert(MaskedAutoregressiveFlow(...))` uses + the (expensive) forward-mode calculation to compute log-probabilities and the + (cheap) reverse-mode calculation to compute samples. See "Example Use" + [below] for more details. + + Given a `shift_and_log_scale_fn`, the forward and inverse transformations are + (a sequence of) affine transformations. A "valid" `shift_and_log_scale_fn` + must compute each `shift` (aka `loc` or "mu" [2]) and `log(scale)` (aka + "alpha" [2]) such that each are broadcastable with the arguments to `forward` + and `inverse`, i.e., such that the calculations in `forward`, `inverse` + [below] are possible. + + For convenience, `masked_autoregressive_default_template` is offered as a + possible `shift_and_log_scale_fn` function. It implements the MADE + architecture [2]. MADE is a feed-forward network that computes a `shift` and + `log(scale)` using `masked_dense` layers in a deep neural network. Weights are + masked to ensure the autoregressive property. It is possible that this + architecture is suboptimal for your task. To build alternative networks, + either change the arguments to `masked_autoregressive_default_template`, use + the `masked_dense` function to roll-out your own, or use some other + architecture, e.g., using `tf.layers`. + + Warning: no attempt is made to validate that the `shift_and_log_scale_fn` + enforces the "autoregressive property". + + Assuming `shift_and_log_scale_fn` has valid shape and autoregressive + semantics, the forward transformation is, + + ```python + def forward(x): + y = zeros_like(x) + event_size = x.shape[-1] + for _ in range(event_size): + shift, log_scale = shift_and_log_scale_fn(y) + y = x * math_ops.exp(log_scale) + shift + return y + ``` + + and the inverse transformation is, + + ```python + def inverse(y): + shift, log_scale = shift_and_log_scale_fn(y) + return (y - shift) / math_ops.exp(log_scale) + ``` + + Notice that the `inverse` does not need a for-loop. This is because in the + forward pass each calculation of `shift` and `log_scale` is based on the `y` + calculated so far (not `x`). In the `inverse`, the `y` is fully known, thus is + equivalent to the scaling used in `forward` after `event_size` passes, i.e., + the "last" `y` used to compute `shift`, `log_scale`. (Roughly speaking, this + also proves the transform is bijective.) + + #### Example Use + + ```python + ds = tf.contrib.distributions + bs = tf.contrib.distributions.bijectors + + dims = 5 + + # A common choice for a normalizing flow is to use a Gaussian for the base + # distribution. (However, any continuous distribution would work.) E.g., + maf = ds.TransformedDistribution( + distribution=ds.Normal(loc=0., scale=1.), + bijector=bs.MaskedAutoregressiveFlow( + shift_and_log_scale_fn=bs.masked_autoregressive_default_template( + hidden_layers=[512, 512])), + event_shape=[dims]) + + x = maf.sample() # Expensive; uses `tf.while_loop`, no Bijector caching. + maf.log_prob(x) # Almost free; uses Bijector caching. + maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching. + + # [1] also describes an "Inverse Autoregressive Flow", e.g., + iaf = ds.TransformedDistribution( + distribution=ds.Normal(loc=0., scale=1.), + bijector=bs.Invert(bs.MaskedAutoregressiveFlow( + shift_and_log_scale_fn=bs.masked_autoregressive_default_template( + hidden_layers=[512, 512]))), + event_shape=[dims]) + + x = iaf.sample() # Cheap; no `tf.while_loop` despite no Bijector caching. + iaf.log_prob(x) # Almost free; uses Bijector caching. + iaf.log_prob(0.) # Expensive; uses `tf.while_loop`, no Bijector caching. + + # In many (if not most) cases the default `shift_and_log_scale_fn` will be a + # poor choice. Here's an example of using a "shift only" version and with a + # different number/depth of hidden layers. + shift_only = True + maf_no_scale_hidden2 = ds.TransformedDistribution( + distribution=ds.Normal(loc=0., scale=1.), + bijector=bs.MaskedAutoregressiveFlow( + bs.masked_autoregressive_default_template( + hidden_layers=[32], + shift_only=shift_only), + is_constant_jacobian=shift_only), + event_shape=[dims]) + ``` + + [1]: "Masked Autoregressive Flow for Density Estimation." + George Papamakarios, Theo Pavlakou, Iain Murray. Arxiv. 2017. + https://arxiv.org/abs/1705.07057 + + [2]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + """ + + def __init__(self, + shift_and_log_scale_fn, + is_constant_jacobian=False, + validate_args=False, + name=None): + """Creates the MaskedAutoregressiveFlow bijector. + + Args: + shift_and_log_scale_fn: Python `callable` which computes `shift` and + `log_scale` from both the forward domain (`x`) and the inverse domain + (`y`). Calculation must respect the "autoregressive property" (see class + docstring). Suggested default + `masked_autoregressive_default_template(hidden_layers=...)`. + Typically the function contains `tf.Variables` and is wrapped using + `tf.make_template`. Returning `None` for either (both) `shift`, + `log_scale` is equivalent to (but more efficient than) returning zero. + is_constant_jacobian: Python `bool`. Default: `False`. When `True` the + implementation assumes `log_scale` does not depend on the forward domain + (`x`) or inverse domain (`y`) values. (No validation is made; + `is_constant_jacobian=False` is always safe but possibly computationally + inefficient.) + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + """ + name = name or "masked_autoregressive_flow" + self._shift_and_log_scale_fn = shift_and_log_scale_fn + super(MaskedAutoregressiveFlow, self).__init__( + is_constant_jacobian=is_constant_jacobian, + validate_args=validate_args, + name=name) + + def _forward(self, x): + event_size = array_ops.shape(x)[-1] + def _loop_body(index, y0): + """While-loop body for autoregression calculation.""" + # Set caching device to avoid re-getting the tf.Variable for every while + # loop iteration. + with variable_scope_lib.variable_scope( + variable_scope_lib.get_variable_scope()) as vs: + if vs.caching_device is None: + vs.set_caching_device(lambda op: op.device) + shift, log_scale = self._shift_and_log_scale_fn(y0) + y = x + if log_scale is not None: + y *= math_ops.exp(log_scale) + if shift is not None: + y += shift + return index + 1, y + _, y = control_flow_ops.while_loop( + cond=lambda index, _: index < event_size, + body=_loop_body, + loop_vars=[0, array_ops.zeros_like(x, name="y0")]) + return y + + def _inverse(self, y): + shift, log_scale = self._shift_and_log_scale_fn(y) + x = y + if shift is not None: + x -= shift + if log_scale is not None: + x *= math_ops.exp(-log_scale) + return x + + def _inverse_log_det_jacobian(self, y): + _, log_scale = self._shift_and_log_scale_fn(y) + if log_scale is None: + return constant_op.constant(0., dtype=y.dtype, name="ildj") + return -math_ops.reduce_sum(log_scale, axis=-1) + + +MASK_INCLUSIVE = "inclusive" +MASK_EXCLUSIVE = "exclusive" + + +def _gen_slices(num_blocks, n_in, n_out, mask_type=MASK_EXCLUSIVE): + """Generate the slices for building an autoregressive mask.""" + # TODO(b/67594795): Better support of dynamic shape. + slices = [] + col = 0 + d_in = n_in // num_blocks + d_out = n_out // num_blocks + row = d_out if mask_type == MASK_EXCLUSIVE else 0 + for _ in range(num_blocks): + row_slice = slice(row, None) + col_slice = slice(col, col + d_in) + slices.append([row_slice, col_slice]) + col += d_in + row += d_out + return slices + + +def _gen_mask(num_blocks, + n_in, + n_out, + mask_type=MASK_EXCLUSIVE, + dtype=dtypes.float32): + """Generate the mask for building an autoregressive dense layer.""" + # TODO(b/67594795): Better support of dynamic shape. + mask = np.zeros([n_out, n_in], dtype=dtype.as_numpy_dtype()) + slices = _gen_slices(num_blocks, n_in, n_out, mask_type=mask_type) + for [row_slice, col_slice] in slices: + mask[row_slice, col_slice] = 1 + return mask + + +def masked_dense(inputs, + units, + num_blocks=None, + exclusive=False, + kernel_initializer=None, + reuse=None, + name=None, + *args, + **kwargs): + """A autoregressively masked dense layer. Analogous to `tf.layers.dense`. + + See [1] for detailed explanation. + + [1]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + Arguments: + inputs: Tensor input. + units: Python `int` scalar representing the dimensionality of the output + space. + num_blocks: Python `int` scalar representing the number of blocks for the + MADE masks. + exclusive: Python `bool` scalar representing whether to zero the diagonal of + the mask, used for the first layer of a MADE. + kernel_initializer: Initializer function for the weight matrix. + If `None` (default), weights are initialized using the + `tf.glorot_random_initializer`. + reuse: Python `bool` scalar representing whether to reuse the weights of a + previous layer by the same name. + name: Python `str` used to describe ops managed by this function. + *args: `tf.layers.dense` arguments. + **kwargs: `tf.layers.dense` keyword arguments. + + Returns: + Output tensor. + + Raises: + NotImplementedError: if rightmost dimension of `inputs` is unknown prior to + graph execution. + """ + # TODO(b/67594795): Better support of dynamic shape. + input_depth = inputs.shape.with_rank_at_least(1)[-1].value + if input_depth is None: + raise NotImplementedError( + "Rightmost dimension must be known prior to graph execution.") + + mask = _gen_mask(num_blocks, input_depth, units, + MASK_EXCLUSIVE if exclusive else MASK_INCLUSIVE).T + + if kernel_initializer is None: + kernel_initializer = init_ops.glorot_normal_initializer() + + def masked_initializer(shape, dtype=None, partition_info=None): + return mask * kernel_initializer(shape, dtype, partition_info) + + with ops.name_scope(name, "masked_dense", [inputs, units, num_blocks]): + layer = layers.Dense( + units, + kernel_initializer=masked_initializer, + kernel_constraint=lambda x: mask * x, + name=name, + dtype=inputs.dtype.base_dtype, + _scope=name, + _reuse=reuse, + *args, + **kwargs) + return layer.apply(inputs) + + +def masked_autoregressive_default_template( + hidden_layers, + shift_only=False, + activation=nn_ops.relu, + log_scale_min_clip=-5., + log_scale_max_clip=3., + log_scale_clip_gradient=False, + name=None, + *args, + **kwargs): + """Build the MADE Model [1]. + + This will be wrapped in a make_template to ensure the variables are only + created once. It takes the input and returns the `loc` ("mu" [1]) and + `log_scale` ("alpha" [1]) from the MADE network. + + Warning: This function uses `masked_dense` to create randomly initialized + `tf.Variables`. It is presumed that these will be fit, just as you would any + other neural architecture which uses `tf.layers.dense`. + + #### About Hidden Layers: + + Each element of `hidden_layers` should be greater than the `input_depth` + (i.e., `input_depth = tf.shape(input)[-1]` where `input` is the input to the + neural network). This is necessary to ensure the autoregressivity property. + + #### About Clipping: + + This function also optionally clips the `log_scale` (but possibly not its + gradient). This is useful because if `log_scale` is too small/large it might + underflow/overflow making it impossible for the `MaskedAutoregressiveFlow` + bijector to implement a bijection. Additionally, the `log_scale_clip_gradient` + `bool` indicates whether the gradient should also be clipped. The default does + not clip the gradient; this is useful because it still provides gradient + information (for fitting) yet solves the numerical stability problem. I.e., + `log_scale_clip_gradient = False` means + `grad[exp(clip(x))] = grad[x] exp(clip(x))` rather than the usual + `grad[clip(x)] exp(clip(x))`. + + [1]: "MADE: Masked Autoencoder for Distribution Estimation." + Mathieu Germain, Karol Gregor, Iain Murray, Hugo Larochelle. ICML. 2015. + https://arxiv.org/abs/1502.03509 + + Arguments: + hidden_layers: Python `list`-like of non-negative integer, scalars + indicating the number of units in each hidden layer. Default: `[512, 512]. + shift_only: Python `bool` indicating if only the `shift` term shall be + computed. Default: `False`. + activation: Activation function (callable). Explicitly setting to `None` + implies a linear activation. + log_scale_min_clip: `float`-like scalar `Tensor`, or a `Tensor` with the + same shape as `log_scale`. The minimum value to clip by. Default: -5. + log_scale_max_clip: `float`-like scalar `Tensor`, or a `Tensor` with the + same shape as `log_scale`. The maximum value to clip by. Default: 3. + log_scale_clip_gradient: Python `bool` indicating that the gradient of + `tf.clip_by_value` should be preserved. Default: `False`. + name: A name for ops managed by this function. Default: + "masked_autoregressive_default_template". + *args: `tf.layers.dense` arguments. + **kwargs: `tf.layers.dense` keyword arguments. + + Returns: + shift: `Float`-like `Tensor` of shift terms (the "mu" in [2]). + log_scale: `Float`-like `Tensor` of log(scale) terms (the "alpha" in [2]). + + Raises: + NotImplementedError: if rightmost dimension of `inputs` is unknown prior to + graph execution. + """ + + with ops.name_scope(name, "masked_autoregressive_default_template", + values=[log_scale_min_clip, log_scale_max_clip]): + def _fn(x): + """MADE parameterized via `masked_autoregressive_default_template`.""" + # TODO(b/67594795): Better support of dynamic shape. + input_depth = x.shape.with_rank_at_least(1)[-1].value + if input_depth is None: + raise NotImplementedError( + "Rightmost dimension must be known prior to graph execution.") + input_shape = (np.int32(x.shape.as_list()) if x.shape.is_fully_defined() + else array_ops.shape(x)) + for i, units in enumerate(hidden_layers): + x = masked_dense( + inputs=x, + units=units, + num_blocks=input_depth, + exclusive=True if i == 0 else False, + activation=activation, + *args, + **kwargs) + x = masked_dense( + inputs=x, + units=(1 if shift_only else 2) * input_depth, + num_blocks=input_depth, + activation=None, + *args, + **kwargs) + if shift_only: + x = array_ops.reshape(x, shape=input_shape) + return x, None + x = array_ops.reshape( + x, shape=array_ops.concat([input_shape, [2]], axis=0)) + shift, log_scale = array_ops.unstack(x, num=2, axis=-1) + which_clip = (math_ops.clip_by_value if log_scale_clip_gradient + else _clip_by_value_preserve_grad) + log_scale = which_clip(log_scale, log_scale_min_clip, log_scale_max_clip) + return shift, log_scale + return template_ops.make_template( + "masked_autoregressive_default_template", _fn) + + +def _clip_by_value_preserve_grad(x, clip_value_min, clip_value_max, name=None): + """Clips input while leaving gradient unaltered.""" + with ops.name_scope(name, "clip_by_value_preserve_grad", + [x, clip_value_min, clip_value_max]): + clip_x = clip_ops.clip_by_value(x, clip_value_min, clip_value_max) + return x + array_ops.stop_gradient(clip_x - x) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py new file mode 100644 index 0000000000000000000000000000000000000000..a187ce22d686ee1203802ae2bfe64b0e1a3ea850 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute.py @@ -0,0 +1,29 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Permute bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.distributions.python.ops.bijectors.permute_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ["Permute"] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d8f2f41b28a88208a19824377f93882b767f03 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py @@ -0,0 +1,138 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Permutation bijectors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib + + +__all__ = [ + "Permute", +] + + +class Permute(bijector_lib.Bijector): + """Permutes the rightmost dimension of a `Tensor`. + + ```python + bs = tf.contrib.distributions.bijectors + + reverse = bs.Permute(permutation=[2, 1, 0]) + + reverse.forward([-1., 0., 1.]) + # ==> [1., 0., -1] + + reverse.inverse([1., 0., -1]) + # ==> [-1., 0., 1.] + + reverse.forward_log_det_jacobian(any_value) + # ==> 0. + + reverse.inverse_log_det_jacobian(any_value) + # ==> 0. + ``` + + Warning: `tf.estimator` may repeatedly build the graph thus + `Permute(np.random.permutation(event_size)).astype("int32"))` is not a + reliable parameterization (nor would it be even if using `tf.constant`). A + safe alternative is to use `tf.get_variable` to achieve "init once" behavior, + i.e., + + ```python + def init_once(x, name): + return tf.get_variable(name, initializer=x, trainable=False) + + Permute(permutation=init_once( + np.random.permutation(event_size).astype("int32"), + name="permutation")) + ``` + + """ + + def __init__(self, permutation, validate_args=False, name=None): + """Creates the `Permute` bijector. + + Args: + permutation: An `int`-like vector-shaped `Tensor` representing the + permutation to apply to the rightmost dimension of the transformed + `Tensor`. + validate_args: Python `bool` indicating whether arguments should be + checked for correctness. + name: Python `str`, name given to ops managed by this object. + + Raises: + TypeError: if `not permutation.dtype.is_integer`. + ValueError: if `permutation` does not contain exactly one of each of + `{0, 1, ..., d}`. + """ + with ops.name_scope(name, "permute", values=[permutation]): + permutation = ops.convert_to_tensor( + permutation, + name="permutation") + if not permutation.dtype.is_integer: + raise TypeError("permutation.dtype ({}) should be `int`-like.".format( + permutation.dtype.name)) + p = tensor_util.constant_value(permutation) + if p is not None: + if set(p) != set(np.arange(p.size)): + raise ValueError("Permutation over `d` must contain exactly one of " + "each of `{0, 1, ..., d}`.") + elif validate_args: + p, _ = nn_ops.top_k(-permutation, + k=array_ops.shape(permutation)[-1], + sorted=True) + permutation = control_flow_ops.with_dependencies([ + check_ops.assert_equal( + -p, math_ops.range(array_ops.size(p)), + message=("Permutation over `d` must contain exactly one of " + "each of `{0, 1, ..., d}`.")), + ], permutation) + self._permutation = permutation + super(Permute, self).__init__( + is_constant_jacobian=True, + validate_args=validate_args, + name=name or "permute") + + @property + def permutation(self): + return self._permutation + + def _forward(self, x): + return array_ops.gather(x, self.permutation, axis=-1) + + def _inverse(self, y): + return array_ops.gather( + y, + array_ops.invert_permutation(self.permutation), + axis=-1) + + def _inverse_log_det_jacobian(self, y): + return constant_op.constant(0., dtype=y.dtype) + + def _forward_log_det_jacobian(self, x): + return constant_op.constant(0., dtype=x.dtype) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py new file mode 100644 index 0000000000000000000000000000000000000000..8997f7ab6929745275edb38712a5bbb0a9b25ddb --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape.py @@ -0,0 +1,29 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reshape bijector.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# go/tf-wildcard-import +# pylint: disable=wildcard-import +from tensorflow.contrib.distributions.python.ops.bijectors.reshape_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ["Reshape"] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..93682639aa3be3b8f59a369dedb6ee773c468130 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py @@ -0,0 +1,297 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Reshape bijectors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bijector as bijector_lib + + +__all__ = [ + "Reshape", +] + + +class Reshape(bijector_lib.Bijector): + """Reshapes the `event_shape` of a `Tensor`. + + The semantics generally follow that of `tf.reshape()`, with + a few differences: + * The user must provide both the input and output shape, so that + the transformation can be inverted. + * The `Reshape` bijector automatically broadcasts over the leftmost + dimensions of its input (`sample_shape` and `batch_shape`); only + the rightmost `event_ndims_in` dimensions are reshaped. The + number of dimensions to reshape is inferred from the provided + `event_shape_in` (`event_ndims_in = len(event_shape_in)`). + * The `Reshape` bijector does not currently support + partially-specified shapes, i.e., those with a dimension + implicitly specified by `-1`. + + Example usage: + ```python + + bs = tf.contrib.distributions.bijectors + + reverse = bs.Reshape(event_shape_out=[1,2], + event_shape_in=[2,]) + + reverse.forward([1., 2.]) # shape [2,] + # ==> [[1., 2.]] # shape [1,2] + + reverse.forward([[1., 2.], [3., 4.]]) # shape [2, 2] + # ==> [[[1., 2.]], [[3., 4.]]] # shape [2, 1, 2] + + reverse.inverse([[1., 2.]]) # shape [1,2] + # ==> [1., 2.] # shape [2,] + + reverse.forward_log_det_jacobian(any_value) + # ==> 0. + + reverse.inverse_log_det_jacobian(any_value) + # ==> 0. + ``` + + """ + + def __init__(self, event_shape_out, event_shape_in, + validate_args=False, name=None): + """Creates a `Reshape` bijector. + + Args: + event_shape_out: An `int`-like vector-shaped `Tensor` + representing the fully specified (no -1's) event shape of the + transformed output. + event_shape_in: An `int`-like vector-shaped `Tensor` + representing the fully specified (no -1's) event shape of the + input. + validate_args: Python `bool` indicating whether arguments should + be checked for correctness. + name: Python `str`, name given to ops managed by this object. + + Raises: + TypeError: if either `event_shape_in` or `event_shape_out` has + non-vector shape (`rank > 1`), or non-integer `dtype`. + ValueError: if either `event_shape_in` or `event_shape_out` + contains non-positive entries, or if their sizes do not match + (`prod(event_shape_in)` != `prod(event_shape_out)`), or if + their dimensionality(s) cannot be statically inferred. + """ + with ops.name_scope(name, "reshape", + values=[event_shape_out, event_shape_in]): + + event_shape_out = ops.convert_to_tensor(event_shape_out, + name="event_shape_out", + preferred_dtype=dtypes.int32) + event_shape_in = ops.convert_to_tensor(event_shape_in, + name="event_shape_in", + preferred_dtype=dtypes.int32) + + # check that input shapes are positive integers + assertions = [] + assertions += self._maybe_check_valid_shape( + event_shape_out, "event_shape_out", + validate_args=validate_args) + assertions += self._maybe_check_valid_shape( + event_shape_in, "event_shape_in", validate_args=validate_args) + + # check that prod(event_shape_in) = prod(event_shape_out) + assertions += self._maybe_check_matching_sizes( + event_shape_in, event_shape_out, validate_args=validate_args) + + self._assertions = assertions + self._event_shape_in = event_shape_in + self._event_shape_out = event_shape_out + self._event_shape_in_static = tensor_util.constant_value_as_shape( + event_shape_in) + self._event_shape_out_static = tensor_util.constant_value_as_shape( + event_shape_out) + + super(Reshape, self).__init__(is_constant_jacobian=True, + validate_args=validate_args, + name=name or "reshape") + + def _maybe_check_valid_shape(self, shape_tensor, label, + validate_args=False): + """Check that a shape Tensor is int-type and positive.""" + + assertions = [] + + if not shape_tensor.dtype.is_integer: + raise TypeError("{} dtype ({}) should be `int`-like.".format( + label, shape_tensor.dtype.name)) + + shape_rank = tensor_util.constant_value(array_ops.rank(shape_tensor)) + if shape_rank is not None and shape_rank > 1: + raise ValueError("{} rank should be <= 1.".format(label)) + + s = tensor_util.constant_value(shape_tensor) + if s is not None: + if (s <= 0).any(): + raise ValueError("{} entries must be positive, but found {}".format( + label, s)) + elif validate_args: + assertions.append(check_ops.assert_positive( + shape_tensor, message="{} entries must be positive".format(label))) + + return assertions + + def _maybe_check_matching_sizes(self, event_shape_in, event_shape_out, + validate_args=False): + """Check that prod(event_shape_in)==prod(event_shape_out).""" + + def _get_size_from_shape(shape): + """Computes size from a shape `Tensor`, statically if possible.""" + s = tensor_util.constant_value(shape) + if s is not None: + return [np.int32(np.prod(s))]*2 + return None, math_ops.reduce_prod(shape, name="size") + + # Ensure `event_shape_in` is compatible with `event_shape_out`. + event_size_in_, event_size_in = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking + event_shape_in) + event_size_out_, event_size_out = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking + event_shape_out) + + assertions = [] + if event_size_in_ is not None and event_size_out_ is not None: + if event_size_in_ != event_size_out_: + raise ValueError( + "Input `event_size` ({}) does not match output `event_size` ({}).". + format(event_size_in, event_size_out_)) + elif validate_args: + assertions.append(check_ops.assert_equal( + event_size_in, event_size_out, + message="Input/output `event_size`s do not match.")) + + return assertions + + def _reshape_helper(self, x, event_shape_in, event_shape_out): + """Reshape only the event_shape of an input `Tensor`.""" + + def _get_rank_from_shape(shape): + """Computes rank from a shape `Tensor`, statically if possible.""" + # Uses fact that rank is "shape of shape". + ndims = shape.shape.with_rank_at_least(1)[0].value + if ndims is not None: + return ndims, ndims + return None, array_ops.shape(shape)[0] + + event_ndims_in_, event_ndims_in = _get_rank_from_shape(event_shape_in) + + assertions = [] + # Ensure x.event_shape is compatible with event_shape_in. + if x.shape.ndims is not None: + x_ndims_, x_ndims = [x.shape.ndims]*2 + else: + x_ndims_, x_ndims = None, array_ops.rank(x) + + if (event_ndims_in_ is not None + and x_ndims_ is not None + and x.shape.with_rank_at_least(event_ndims_in_)[ + x_ndims_-event_ndims_in_:].is_fully_defined()): + x_event_shape_, x_event_shape = [ # pylint: disable=unbalanced-tuple-unpacking + np.int32(x.shape[x_ndims_-event_ndims_in_:])]*2 + else: + x_event_shape_, x_event_shape = ( + None, array_ops.shape(x)[x_ndims-event_ndims_in:]) + + event_shape_in_ = tensor_util.constant_value(event_shape_in) + + if x_event_shape_ is not None and event_shape_in_ is not None: + if not np.equal(x_event_shape_, event_shape_in_).all(): + raise ValueError( + "Input `event_shape` ({}) does not match `event_shape_in` ({}).". + format(x_event_shape_, event_shape_in_)) + elif self.validate_args: + assertions.append(check_ops.assert_equal( + x_event_shape, event_shape_in, + message="Input `event_shape` does not match `event_shape_in`.")) + + if assertions: + x = control_flow_ops.with_dependencies(assertions, x) + + # get the parts of shape(x) that will not change + sample_and_batch_shape = array_ops.shape(x) + + ndims = (x.shape.ndims if x.shape.ndims is not None + else array_ops.rank(x)) + sample_and_batch_shape = sample_and_batch_shape[ + :(ndims - math_ops.abs(event_ndims_in))] + + new_shape = array_ops.concat( + [sample_and_batch_shape, event_shape_out], axis=0) + + return array_ops.reshape(x, new_shape) + + def _forward(self, x): + with ops.control_dependencies(self._assertions): + return self._reshape_helper(x, + self._event_shape_in, + self._event_shape_out) + + def _inverse(self, y): + with ops.control_dependencies(self._assertions): + return self._reshape_helper(y, + self._event_shape_out, + self._event_shape_in) + + def _inverse_log_det_jacobian(self, y): + with ops.control_dependencies(self._assertions): + return constant_op.constant(0., dtype=y.dtype) + + def _forward_log_det_jacobian(self, x): + with ops.control_dependencies(self._assertions): + return constant_op.constant(0., dtype=x.dtype) + + def _forward_event_shape(self, input_shape): + self._event_shape_in_static.assert_is_compatible_with(input_shape) + return self._event_shape_out_static + + def _inverse_event_shape(self, output_shape): + self._event_shape_out_static.assert_is_compatible_with(output_shape) + return self._event_shape_in_static + + def _forward_event_shape_tensor(self, input_shape): + input_assertions = self._maybe_check_valid_shape( + input_shape, "input event shape", validate_args=self.validate_args) + input_assertions += self._maybe_check_matching_sizes( + input_shape, self._event_shape_out, + validate_args=self.validate_args) + + return control_flow_ops.with_dependencies( + input_assertions + self._assertions, self._event_shape_out) + + def _inverse_event_shape_tensor(self, output_shape): + + output_assertions = self._maybe_check_valid_shape( + output_shape, "output event shape", validate_args=self.validate_args) + output_assertions += self._maybe_check_matching_sizes( + output_shape, self._event_shape_in, validate_args=self.validate_args) + + return control_flow_ops.with_dependencies( + output_assertions + self._assertions, self._event_shape_in) diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py index dac3d812eef28b6aed291db051726d0594f7316a..3a75e4ae9495793901b0da91a5aa3982aab35852 100644 --- a/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py +++ b/tensorflow/contrib/distributions/python/ops/bijectors/sinh_arcsinh_impl.py @@ -89,18 +89,18 @@ class SinhArcsinh(bijector.Bijector): """ def __init__(self, - skewness=0., - tailweight=1., + skewness=None, + tailweight=None, event_ndims=0, validate_args=False, - name="sinh_arcsinh"): + name="SinhArcsinh"): """Instantiates the `SinhArcsinh` bijector. Args: - skewness: Skewness parameter. Float-type `Tensor`. + skewness: Skewness parameter. Float-type `Tensor`. Default is `0` + of type `float32`. tailweight: Tailweight parameter. Positive `Tensor` of same `dtype` as - `skewness` - and broadcastable `shape`. + `skewness` and broadcastable `shape`. Default is `1` of type `float32`. event_ndims: Python scalar indicating the number of dimensions associated with a particular draw from the distribution. validate_args: Python `bool` indicating whether arguments should be @@ -111,8 +111,12 @@ class SinhArcsinh(bijector.Bijector): self._name = name self._validate_args = validate_args with self._name_scope("init", values=[skewness, tailweight]): - self._skewness = ops.convert_to_tensor(skewness, name="skewness") - self._tailweight = ops.convert_to_tensor(tailweight, name="tailweight") + tailweight = 1. if tailweight is None else tailweight + skewness = 0. if skewness is None else skewness + self._skewness = ops.convert_to_tensor( + skewness, name="skewness") + self._tailweight = ops.convert_to_tensor( + tailweight, name="tailweight", dtype=self._skewness.dtype) check_ops.assert_same_float_dtype([self._skewness, self._tailweight]) if validate_args: self._tailweight = control_flow_ops.with_dependencies([ diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py new file mode 100644 index 0000000000000000000000000000000000000000..a17bb091f69b651d21f70a25c5aab61b203e62de --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/cauchy.py @@ -0,0 +1,223 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Cauchy distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops.distributions import distribution + + +__all__ = [ + "Cauchy", +] + + +class Cauchy(distribution.Distribution): + """The Cauchy distribution with location `loc` and scale `scale`. + + #### Mathematical details + + The probability density function (pdf) is, + + ```none + pdf(x; loc, scale) = 1 / (pi * scale * (1 + ((x - loc) / scale)**2)) + ``` + where `loc` is the location, and `scale` is the scale. + + The Cauchy distribution is a member of the [location-scale family]( + https://en.wikipedia.org/wiki/Location-scale_family), i.e. + + ```none + X ~ Cauchy(loc=0, scale=1) + Y ~ Cauchy(loc=loc, scale=scale) + Y = loc + scale * X + ``` + + #### Examples + + Examples of initialization of one or a batch of distributions. + + ```python + # Define a single scalar Cauchy distribution. + dist = Cauchy(loc=0., scale=3.) + + # Evaluate the cdf at 1, returning a scalar. + dist.cdf(1.) + + # Define a batch of two scalar valued Cauchy distributions. + dist = Cauchy(loc=[1, 2.], scale=[11, 22.]) + + # Evaluate the pdf of the first distribution on 0, and the second on 1.5, + # returning a length two tensor. + dist.prob([0, 1.5]) + + # Get 3 samples, returning a 3 x 2 tensor. + dist.sample([3]) + ``` + + Arguments are broadcast when possible. + + ```python + # Define a batch of two scalar valued Cauchy distributions. + # Both have median 1, but different scales. + dist = tf.contrib.distributions.Cauchy(loc=1., scale=[11, 22.]) + # Evaluate the pdf of both distributions on the same point, 3.0, + # returning a length 2 tensor. + dist.prob(3.0) + ``` + """ + + def __init__(self, + loc, + scale, + validate_args=False, + allow_nan_stats=True, + name="Cauchy"): + """Construct Cauchy distributions with loc and and scale `loc` and `scale`. + + The parameters `loc` and `scale` must be shaped in a way that supports + broadcasting (e.g. `loc + scale` is a valid operation). + + Args: + loc: Floating point tensor; the modes of the distribution(s). + scale: Floating point tensor; the locations of the distribution(s). + Must contain only positive values. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + + Raises: + TypeError: if `loc` and `scale` have different `dtype`. + """ + parameters = locals() + with ops.name_scope(name, values=[loc, scale]): + with ops.control_dependencies([check_ops.assert_positive(scale)] if + validate_args else []): + self._loc = array_ops.identity(loc, name="loc") + self._scale = array_ops.identity(scale, name="scale") + check_ops.assert_same_float_dtype([self._loc, self._scale]) + super(Cauchy, self).__init__( + dtype=self._scale.dtype, + reparameterization_type=distribution.FULLY_REPARAMETERIZED, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=[self._loc, self._scale], + name=name) + + @staticmethod + def _param_shapes(sample_shape): + return dict( + zip(("loc", "scale"), ([ops.convert_to_tensor( + sample_shape, dtype=dtypes.int32)] * 2))) + + @property + def loc(self): + """Distribution parameter for the mean.""" + return self._loc + + @property + def scale(self): + """Distribution parameter for standard deviation.""" + return self._scale + + def _batch_shape_tensor(self): + return array_ops.broadcast_dynamic_shape( + array_ops.shape(self.loc), + array_ops.shape(self.scale)) + + def _batch_shape(self): + return array_ops.broadcast_static_shape( + self.loc.shape, + self.scale.shape) + + def _event_shape_tensor(self): + return constant_op.constant([], dtype=dtypes.int32) + + def _event_shape(self): + return tensor_shape.scalar() + + def _sample_n(self, n, seed=None): + shape = array_ops.concat([[n], self.batch_shape_tensor()], 0) + probs = random_ops.random_uniform( + shape=shape, minval=0., maxval=1., dtype=self.dtype, seed=seed) + return self._quantile(probs) + + def _log_prob(self, x): + return self._log_unnormalized_prob(x) - self._log_normalization() + + def _cdf(self, x): + return math_ops.atan(self._z(x)) / np.pi + 0.5 + + def _log_cdf(self, x): + return math_ops.log1p(2 / np.pi * math_ops.atan(self._z(x))) - np.log(2) + + def _log_unnormalized_prob(self, x): + return -math_ops.log1p(math_ops.square(self._z(x))) + + def _log_normalization(self): + return np.log(np.pi) + math_ops.log(self.scale) + + def _entropy(self): + h = np.log(4 * np.pi) + math_ops.log(self.scale) + return h * array_ops.ones_like(self.loc) + + def _quantile(self, p): + return self.loc + self.scale * math_ops.tan(np.pi * (p - 0.5)) + + def _mode(self): + return self.loc * array_ops.ones_like(self.scale) + + def _z(self, x): + """Standardize input `x`.""" + with ops.name_scope("standardize", values=[x]): + return (x - self.loc) / self.scale + + def _inv_z(self, z): + """Reconstruct input `x` from a its normalized version.""" + with ops.name_scope("reconstruct", values=[z]): + return z * self.scale + self.loc + + def _mean(self): + if self.allow_nan_stats: + return array_ops.fill(self.batch_shape_tensor(), + self.dtype.as_numpy_dtype(np.nan)) + else: + raise ValueError("`mean` is undefined for Cauchy distribution.") + + def _stddev(self): + if self.allow_nan_stats: + return array_ops.fill(self.batch_shape_tensor(), + self.dtype.as_numpy_dtype(np.nan)) + else: + raise ValueError("`stddev` is undefined for Cauchy distribution.") diff --git a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py index 2e1e68cf0587b69f055d8d747672d99383f75ed6..599c855cda434d9249187d5d154d50a8a8c49a6c 100644 --- a/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py +++ b/tensorflow/contrib/distributions/python/ops/conditional_transformed_distribution.py @@ -18,6 +18,9 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.distributions.python.ops import conditional_distribution +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import transformed_distribution from tensorflow.python.ops.distributions import util as distribution_util @@ -48,25 +51,72 @@ class ConditionalTransformedDistribution( @distribution_util.AppendDocstring(kwargs_dict=_condition_kwargs_dict) def _sample_n(self, n, seed=None, - bijector_kwargs=None, distribution_kwargs=None): - bijector_kwargs = bijector_kwargs or {} - distribution_kwargs = distribution_kwargs or {} + bijector_kwargs=None, + distribution_kwargs=None): sample_shape = _concat_vectors( distribution_util.pick_vector(self._needs_rotation, self._empty, [n]), self._override_batch_shape, self._override_event_shape, distribution_util.pick_vector(self._needs_rotation, [n], self._empty)) - x = self.distribution.sample(sample_shape=sample_shape, seed=seed, + distribution_kwargs = distribution_kwargs or {} + x = self.distribution.sample(sample_shape=sample_shape, + seed=seed, **distribution_kwargs) x = self._maybe_rotate_dims(x) - return self.bijector.forward(x, **bijector_kwargs) + # We'll apply the bijector in the `_call_sample_n` function. + return x + + def _call_sample_n(self, sample_shape, seed, name, + bijector_kwargs=None, + distribution_kwargs=None): + # We override `_call_sample_n` rather than `_sample_n` so we can ensure that + # the result of `self.bijector.forward` is not modified (and thus caching + # works). + with self._name_scope(name, values=[sample_shape]): + sample_shape = ops.convert_to_tensor( + sample_shape, dtype=dtypes.int32, name="sample_shape") + sample_shape, n = self._expand_sample_shape_to_vector( + sample_shape, "sample_shape") + + # First, generate samples. We will possibly generate extra samples in the + # event that we need to reinterpret the samples as part of the + # event_shape. + x = self._sample_n(n, seed, bijector_kwargs, distribution_kwargs) + + # Next, we reshape `x` into its final form. We do this prior to the call + # to the bijector to ensure that the bijector caching works. + batch_event_shape = array_ops.shape(x)[1:] + final_shape = array_ops.concat([sample_shape, batch_event_shape], 0) + x = array_ops.reshape(x, final_shape) + + # Finally, we apply the bijector's forward transformation. For caching to + # work, it is imperative that this is the last modification to the + # returned result. + bijector_kwargs = bijector_kwargs or {} + y = self.bijector.forward(x, **bijector_kwargs) + y = self._set_sample_static_shape(y, sample_shape) + + return y @distribution_util.AppendDocstring(kwargs_dict=_condition_kwargs_dict) def _log_prob(self, y, bijector_kwargs=None, distribution_kwargs=None): + # For caching to work, it is imperative that the bijector is the first to + # modify the input. bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) ildj = self.bijector.inverse_log_det_jacobian(y, **bijector_kwargs) + if self.bijector._is_injective: # pylint: disable=protected-access + return self._finish_log_prob_for_one_fiber(y, x, ildj, + distribution_kwargs) + + lp_on_fibers = [ + self._finish_log_prob_for_one_fiber(y, x_i, ildj_i, distribution_kwargs) + for x_i, ildj_i in zip(x, ildj)] + return math_ops.reduce_logsumexp(array_ops.stack(lp_on_fibers), axis=0) + + def _finish_log_prob_for_one_fiber(self, y, x, ildj, distribution_kwargs): + """Finish computation of log_prob on one element of the inverse image.""" x = self._maybe_rotate_dims(x, rotate_right=True) log_prob = self.distribution.log_prob(x, **distribution_kwargs) if self._is_maybe_event_override: @@ -79,6 +129,16 @@ class ConditionalTransformedDistribution( distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) ildj = self.bijector.inverse_log_det_jacobian(y, **bijector_kwargs) + if self.bijector._is_injective: # pylint: disable=protected-access + return self._finish_prob_for_one_fiber(y, x, ildj, distribution_kwargs) + + prob_on_fibers = [ + self._finish_prob_for_one_fiber(y, x_i, ildj_i, distribution_kwargs) + for x_i, ildj_i in zip(x, ildj)] + return sum(prob_on_fibers) + + def _finish_prob_for_one_fiber(self, y, x, ildj, distribution_kwargs): + """Finish computation of prob on one element of the inverse image.""" x = self._maybe_rotate_dims(x, rotate_right=True) prob = self.distribution.prob(x, **distribution_kwargs) if self._is_maybe_event_override: @@ -90,6 +150,9 @@ class ConditionalTransformedDistribution( if self._is_maybe_event_override: raise NotImplementedError("log_cdf is not implemented when overriding " "event_shape") + if not self.bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError("log_cdf is not implemented when " + "bijector is not injective.") bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) @@ -100,6 +163,9 @@ class ConditionalTransformedDistribution( if self._is_maybe_event_override: raise NotImplementedError("cdf is not implemented when overriding " "event_shape") + if not self.bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError("cdf is not implemented when " + "bijector is not injective.") bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) @@ -111,6 +177,9 @@ class ConditionalTransformedDistribution( if self._is_maybe_event_override: raise NotImplementedError("log_survival_function is not implemented when " "overriding event_shape") + if not self.bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError("log_survival_function is not implemented when " + "bijector is not injective.") bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) @@ -122,7 +191,26 @@ class ConditionalTransformedDistribution( if self._is_maybe_event_override: raise NotImplementedError("survival_function is not implemented when " "overriding event_shape") + if not self.bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError("survival_function is not implemented when " + "bijector is not injective.") bijector_kwargs = bijector_kwargs or {} distribution_kwargs = distribution_kwargs or {} x = self.bijector.inverse(y, **bijector_kwargs) return self.distribution.survival_function(x, **distribution_kwargs) + + @distribution_util.AppendDocstring(kwargs_dict=_condition_kwargs_dict) + def _quantile(self, value, bijector_kwargs=None, distribution_kwargs=None): + if self._is_maybe_event_override: + raise NotImplementedError("quantile is not implemented when overriding " + "event_shape") + if not self.bijector._is_injective: # pylint: disable=protected-access + raise NotImplementedError("quantile is not implemented when " + "bijector is not injective.") + bijector_kwargs = bijector_kwargs or {} + distribution_kwargs = distribution_kwargs or {} + # x_q is the "qth quantile" of X iff q = P[X <= x_q]. Now, since X = + # g^{-1}(Y), q = P[X <= x_q] = P[g^{-1}(Y) <= x_q] = P[Y <= g(x_q)], + # implies the qth quantile of Y is g(x_q). + inv_cdf = self.distribution.quantile(value, **distribution_kwargs) + return self.bijector.forward(inv_cdf, **bijector_kwargs) diff --git a/tensorflow/contrib/distributions/python/ops/distribution_util.py b/tensorflow/contrib/distributions/python/ops/distribution_util.py index cb74f2b3589313a1502e2905b7cddfe28e1aa4f6..869b5698e57d199755ce1686a74a1eafe3b73e7d 100644 --- a/tensorflow/contrib/distributions/python/ops/distribution_util.py +++ b/tensorflow/contrib/distributions/python/ops/distribution_util.py @@ -27,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import distribution as distribution_lib from tensorflow.python.ops.distributions.util import * # pylint: disable=wildcard-import @@ -159,7 +160,7 @@ def make_tril_scale( scale_tril = array_ops.matrix_set_diag(scale_tril, tril_diag) - return linalg.LinearOperatorTriL( + return linalg.LinearOperatorLowerTriangular( tril=_maybe_attach_assertion(scale_tril), is_non_singular=True, is_self_adjoint=False, @@ -377,6 +378,30 @@ def prefer_static_broadcast_shape( return array_ops.broadcast_dynamic_shape(shape1_, shape2_) +def get_broadcast_shape(*tensors): + """Get broadcast shape as a Python list of integers (preferred) or `Tensor`. + + Args: + *tensors: One or more `Tensor` objects (already converted!). + + Returns: + broadcast shape: Python list (if shapes determined statically), otherwise + an `int32` `Tensor`. + """ + # Try static. + s_shape = tensors[0].shape + for t in tensors[1:]: + s_shape = array_ops.broadcast_static_shape(s_shape, t.shape) + if s_shape.is_fully_defined(): + return s_shape.as_list() + + # Fallback on dynamic. + d_shape = array_ops.shape(tensors[0]) + for t in tensors[1:]: + d_shape = array_ops.broadcast_dynamic_shape(d_shape, array_ops.shape(t)) + return d_shape + + def is_diagonal_scale(scale): """Returns `True` if `scale` is a `LinearOperator` that is known to be diag. @@ -395,3 +420,70 @@ def is_diagonal_scale(scale): return (isinstance(scale, linalg.LinearOperatorIdentity) or isinstance(scale, linalg.LinearOperatorScaledIdentity) or isinstance(scale, linalg.LinearOperatorDiag)) + + +def maybe_check_scalar_distribution( + distribution, expected_base_dtype, validate_args): + """Helper which checks validity of a scalar `distribution` init arg. + + Valid here means: + + * `distribution` has scalar batch and event shapes. + * `distribution` is `FULLY_REPARAMETERIZED` + * `distribution` has expected dtype. + + Args: + distribution: `Distribution`-like object. + expected_base_dtype: `TensorFlow` `dtype`. + validate_args: Python `bool`. Whether to do additional checks: + (i) check that reparameterization_type is `FULLY_REPARAMETERIZED`. + (ii) add `tf.Assert` ops to the graph to enforce that distribution + is scalar in the event that this cannot be determined statically. + + Returns: + List of `tf.Assert` ops to run to enforce validity checks that could not + be statically determined. Empty if `not validate_args`. + + Raises: + ValueError: If validate_args and distribution is not FULLY_REPARAMETERIZED + ValueError: If distribution is statically determined to not have both + scalar batch and scalar event shapes. + """ + if distribution.dtype != expected_base_dtype: + raise TypeError("dtype mismatch; " + "distribution.dtype=\"{}\" is not \"{}\"".format( + distribution.dtype.name, expected_base_dtype.name)) + + # Although `reparameterization_type` is a static property, we guard it by + # `validate_args`. This allows users to use a `distribution` which is not + # reparameterized itself. However, we tacitly assume that although the + # distribution is not reparameterized, it only depends on non-trainable + # variables. + if validate_args and (distribution.reparameterization_type + != distribution_lib.FULLY_REPARAMETERIZED): + raise ValueError("Base distribution should be reparameterized or be " + "a function of non-trainable variables; " + "distribution.reparameterization_type = \"{}\" " + "!= \"FULLY_REPARAMETERIZED\".".format( + distribution.reparameterization_type)) + with ops.name_scope(name="check_distribution"): + assertions = [] + def check_is_scalar(is_scalar, name): + is_scalar_ = static_value(is_scalar) + if is_scalar_ is not None: + if not is_scalar_: + raise ValueError("distribution must be scalar; " + "distribution.{}=False is not True".format(name)) + elif validate_args: + assertions.append(check_ops.assert_equal( + is_scalar, True, + message=("distribution must be scalar; " + "distribution.{}=False is not True".format(name)))) + check_is_scalar(distribution.is_scalar_event(), "is_scalar_event") + check_is_scalar(distribution.is_scalar_batch(), "is_scalar_batch") + return assertions + + +def static_value(x): + """Returns the static value of a `Tensor` or `None`.""" + return tensor_util.constant_value(ops.convert_to_tensor(x)) diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py new file mode 100644 index 0000000000000000000000000000000000000000..6a74ca9a0ae1ad30081d21cc15a65be052a99e2a --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/independent.py @@ -0,0 +1,256 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The Independent distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import distribution as distribution_lib + + +class Independent(distribution_lib.Distribution): + """Independent distribution from batch of distributions. + + This distribution is useful for regarding a collection of independent, + non-identical distributions as a single random variable. For example, the + `Indpendent` distribution composed of a collection of `Bernoulli` + distributions might define a distribution over an image (where each + `Bernoulli` is a distribution over each pixel). + + More precisely, a collection of `B` (independent) `E`-variate random variables + (rv) `{X_1, ..., X_B}`, can be regarded as a `[B, E]`-variate random variable + `(X_1, ..., X_B)` with probability + `p(x_1, ..., x_B) = p_1(x_1) * ... * p_B(x_B)` where `p_b(X_b)` is the + probability of the `b`-th rv. More generally `B, E` can be arbitrary shapes. + + Similarly, the `Independent` distribution specifies a distribution over `[B, + E]`-shaped events. It operates by reinterpreting the rightmost batch dims as + part of the event dimensions. The `reinterpreted_batch_ndims` parameter + controls the number of batch dims which are absorbed as event dims; + `reinterpreted_batch_ndims < len(batch_shape)`. For example, the `log_prob` + function entails a `reduce_sum` over the rightmost `reinterpreted_batch_ndims` + after calling the base distribution's `log_prob`. In other words, since the + batch dimension(s) index independent distributions, the resultant multivariate + will have independent components. + + #### Mathematical Details + + The probability function is, + + ```none + prob(x; reinterpreted_batch_ndims) = tf.reduce_prod( + dist.prob(x), + axis=-1-range(reinterpreted_batch_ndims)) + ``` + + #### Examples + + ```python + ds = tf.contrib.distributions + + # Make independent distribution from a 2-batch Normal. + ind = ds.Independent( + distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]), + reinterpreted_batch_ndims=1) + + # All batch dims have been "absorbed" into event dims. + ind.batch_shape # ==> [] + ind.event_shape # ==> [2] + + # Make independent distribution from a 2-batch bivariate Normal. + ind = ds.Independent( + distribution=ds.MultivariateNormalDiag( + loc=[[-1., 1], [1, -1]], + scale_identity_multiplier=[1., 0.5]), + reinterpreted_batch_ndims=1) + + # All batch dims have been "absorbed" into event dims. + ind.batch_shape # ==> [] + ind.event_shape # ==> [2, 2] + ``` + + """ + + def __init__( + self, distribution, reinterpreted_batch_ndims=None, + validate_args=False, name=None): + """Construct a `Independent` distribution. + + Args: + distribution: The base distribution instance to transform. Typically an + instance of `Distribution`. + reinterpreted_batch_ndims: Scalar, integer number of rightmost batch dims + which will be regarded as event dims. When `None` all but the first + batch axis (batch axis 0) will be transferred to event dimensions + (analogous to `tf.layers.flatten`). + validate_args: Python `bool`. Whether to validate input with asserts. + If `validate_args` is `False`, and the inputs are invalid, + correct behavior is not guaranteed. + name: The name for ops managed by the distribution. + Default value: `Independent + distribution.name`. + + Raises: + ValueError: if `reinterpreted_batch_ndims` exceeds + `distribution.batch_ndims` + """ + parameters = locals() + name = name or "Independent" + distribution.name + self._distribution = distribution + with ops.name_scope(name): + if reinterpreted_batch_ndims is None: + reinterpreted_batch_ndims = self._get_default_reinterpreted_batch_ndims( + distribution) + reinterpreted_batch_ndims = ops.convert_to_tensor( + reinterpreted_batch_ndims, + dtype=dtypes.int32, + name="reinterpreted_batch_ndims") + self._reinterpreted_batch_ndims = reinterpreted_batch_ndims + self._static_reinterpreted_batch_ndims = tensor_util.constant_value( + reinterpreted_batch_ndims) + if self._static_reinterpreted_batch_ndims is not None: + self._reinterpreted_batch_ndims = self._static_reinterpreted_batch_ndims + super(Independent, self).__init__( + dtype=self._distribution.dtype, + reparameterization_type=self._distribution.reparameterization_type, + validate_args=validate_args, + allow_nan_stats=self._distribution.allow_nan_stats, + parameters=parameters, + graph_parents=( + [reinterpreted_batch_ndims] + + distribution._graph_parents), # pylint: disable=protected-access + name=name) + self._runtime_assertions = self._make_runtime_assertions( + distribution, reinterpreted_batch_ndims, validate_args) + + @property + def distribution(self): + return self._distribution + + @property + def reinterpreted_batch_ndims(self): + return self._reinterpreted_batch_ndims + + def _batch_shape_tensor(self): + with ops.control_dependencies(self._runtime_assertions): + batch_shape = self.distribution.batch_shape_tensor() + batch_ndims = (batch_shape.shape[0].value + if batch_shape.shape.with_rank_at_least(1)[0].value + else array_ops.shape(batch_shape)[0]) + return batch_shape[:batch_ndims - self.reinterpreted_batch_ndims] + + def _batch_shape(self): + batch_shape = self.distribution.batch_shape + if (self._static_reinterpreted_batch_ndims is None + or batch_shape.ndims is None): + return tensor_shape.TensorShape(None) + d = batch_shape.ndims - self._static_reinterpreted_batch_ndims + return batch_shape[:d] + + def _event_shape_tensor(self): + with ops.control_dependencies(self._runtime_assertions): + batch_shape = self.distribution.batch_shape_tensor() + batch_ndims = (batch_shape.shape[0].value + if batch_shape.shape.with_rank_at_least(1)[0].value + else array_ops.shape(batch_shape)[0]) + return array_ops.concat([ + batch_shape[batch_ndims - self.reinterpreted_batch_ndims:], + self.distribution.event_shape_tensor(), + ], axis=0) + + def _event_shape(self): + batch_shape = self.distribution.batch_shape + if (self._static_reinterpreted_batch_ndims is None + or batch_shape.ndims is None): + return tensor_shape.TensorShape(None) + d = batch_shape.ndims - self._static_reinterpreted_batch_ndims + return batch_shape[d:].concatenate(self.distribution.event_shape) + + def _sample_n(self, n, seed): + with ops.control_dependencies(self._runtime_assertions): + return self.distribution.sample(sample_shape=n, seed=seed) + + def _log_prob(self, x): + with ops.control_dependencies(self._runtime_assertions): + return self._reduce_sum(self.distribution.log_prob(x)) + + def _entropy(self): + with ops.control_dependencies(self._runtime_assertions): + return self._reduce_sum(self.distribution.entropy()) + + def _mean(self): + with ops.control_dependencies(self._runtime_assertions): + return self.distribution.mean() + + def _variance(self): + with ops.control_dependencies(self._runtime_assertions): + return self.distribution.variance() + + def _stddev(self): + with ops.control_dependencies(self._runtime_assertions): + return self.distribution.stddev() + + def _mode(self): + with ops.control_dependencies(self._runtime_assertions): + return self.distribution.mode() + + def _make_runtime_assertions( + self, distribution, reinterpreted_batch_ndims, validate_args): + assertions = [] + static_reinterpreted_batch_ndims = tensor_util.constant_value( + reinterpreted_batch_ndims) + batch_ndims = distribution.batch_shape.ndims + if batch_ndims is not None and static_reinterpreted_batch_ndims is not None: + if static_reinterpreted_batch_ndims > batch_ndims: + raise ValueError("reinterpreted_batch_ndims({}) cannot exceed " + "distribution.batch_ndims({})".format( + static_reinterpreted_batch_ndims, batch_ndims)) + elif validate_args: + batch_shape = distribution.batch_shape_tensor() + batch_ndims = ( + batch_shape.shape[0].value + if batch_shape.shape.with_rank_at_least(1)[0].value is not None + else array_ops.shape(batch_shape)[0]) + assertions.append(check_ops.assert_less_equal( + reinterpreted_batch_ndims, batch_ndims, + message=("reinterpreted_batch_ndims cannot exceed " + "distribution.batch_ndims"))) + return assertions + + def _reduce_sum(self, stat): + if self._static_reinterpreted_batch_ndims is None: + range_ = math_ops.range(self._reinterpreted_batch_ndims) + else: + range_ = np.arange(self._static_reinterpreted_batch_ndims) + return math_ops.reduce_sum(stat, axis=-1-range_) + + def _get_default_reinterpreted_batch_ndims(self, distribution): + """Computes the default value for reinterpreted_batch_ndim __init__ arg.""" + ndims = distribution.batch_shape.ndims + if ndims is None: + which_maximum = math_ops.maximum + ndims = array_ops.shape(distribution.batch_shape_tensor())[0] + else: + which_maximum = np.maximum + return which_maximum(0, ndims - 1) diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py index 5ba91693a99a34c3d7a455a838b7d4ea024513fb..e676931d9145e72907d990148ee2d180e0da0258 100644 --- a/tensorflow/contrib/distributions/python/ops/mixture.py +++ b/tensorflow/contrib/distributions/python/ops/mixture.py @@ -291,9 +291,6 @@ class Mixture(distribution.Distribution): mixture_log_cdf = math_ops.reduce_logsumexp(concatted_log_cdfs, [0]) return mixture_log_cdf - def _prob(self, x): - return math_ops.exp(self._log_prob(x)) - def _sample_n(self, n, seed=None): with ops.control_dependencies(self._assertions): n = ops.convert_to_tensor(n, name="n") diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py new file mode 100644 index 0000000000000000000000000000000000000000..5558ef0f255db684b229d129666634e50c625887 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py @@ -0,0 +1,339 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The same-family Mixture distribution class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops.distributions import distribution +from tensorflow.python.ops.distributions import util as distribution_util + + +class MixtureSameFamily(distribution.Distribution): + """Mixture (same-family) distribution. + + The `MixtureSameFamily` distribution implements a (batch of) mixture + distribution where all components are from different parameterizations of the + same distribution type. It is parameterized by a `Categorical` "selecting + distribution" (over `k` components) and a components distribution, i.e., a + `Distribution` with a rightmost batch shape (equal to `[k]`) which indexes + each (batch of) component. + + #### Examples + + ```python + import matplotlib.pyplot as plt + ds = tf.contrib.distributions + + ### Create a mixture of two scalar Gaussians: + + gm = ds.MixtureSameFamily( + mixture_distribution=ds.Categorical( + probs=[0.3, 0.7]), + components_distribution=ds.Normal( + loc=[-1., 1], # One for each component. + scale=[0.1, 0.5])) # And same here. + + gm.mean() + # ==> 0.4 + + gm.variance() + # ==> 1.018 + + # Plot PDF. + x = np.linspace(-2., 3., int(1e4), dtype=np.float32) + plt.plot(x, gm.prob(x).eval()); + + ### Create a mixture of two Bivariate Gaussians: + + gm = ds.MixtureSameFamily( + mixture_distribution=ds.Categorical( + probs=[0.3, 0.7]), + components_distribution=ds.MultivariateNormalDiag( + loc=[[-1., 1], # component 1 + [1, -1]], # component 2 + scale_identity_multiplier=[.3, .6])) + + gm.mean() + # ==> array([ 0.4, -0.4], dtype=float32) + + gm.covariance() + # ==> array([[ 1.119, -0.84], + # [-0.84, 1.119]], dtype=float32) + + # Plot PDF contours. + def meshgrid(x, y=x): + [gx, gy] = np.meshgrid(x, y, indexing='ij') + gx, gy = np.float32(gx), np.float32(gy) + grid = np.concatenate([gx.ravel()[None, :], gy.ravel()[None, :]], axis=0) + return grid.T.reshape(x.size, y.size, 2) + grid = meshgrid(np.linspace(-2, 2, 100, dtype=np.float32)) + plt.contour(grid[..., 0], grid[..., 1], gm.prob(grid).eval()); + + ``` + + """ + + def __init__(self, + mixture_distribution, + components_distribution, + validate_args=False, + allow_nan_stats=True, + name="MixtureSameFamily"): + """Construct a `MixtureSameFamily` distribution. + + Args: + mixture_distribution: `tf.distributions.Categorical`-like instance. + Manages the probability of selecting components. The number of + categories must match the rightmost batch dimension of the + `components_distribution`. Must have either scalar `batch_shape` or + `batch_shape` matching `components_distribution.batch_shape[:-1]`. + components_distribution: `tf.distributions.Distribution`-like instance. + Right-most batch dimension indexes components. + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, statistics + (e.g., mean, mode, variance) use the value "`NaN`" to indicate the + result is undefined. When `False`, an exception is raised if one or + more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + + Raises: + ValueError: `if not mixture_distribution.dtype.is_integer`. + ValueError: if mixture_distribution does not have scalar `event_shape`. + ValueError: if `mixture_distribution.batch_shape` and + `components_distribution.batch_shape[:-1]` are both fully defined and + the former is neither scalar nor equal to the latter. + ValueError: if `mixture_distribution` categories does not equal + `components_distribution` rightmost batch shape. + """ + parameters = locals() + with ops.name_scope(name): + self._mixture_distribution = mixture_distribution + self._components_distribution = components_distribution + self._runtime_assertions = [] + + s = components_distribution.event_shape_tensor() + self._event_ndims = (s.shape[0].value + if s.shape.with_rank_at_least(1)[0].value is not None + else array_ops.shape(s)[0]) + + if not mixture_distribution.dtype.is_integer: + raise ValueError( + "`mixture_distribution.dtype` ({}) is not over integers".format( + mixture_distribution.dtype.name)) + + if (mixture_distribution.event_shape.ndims is not None + and mixture_distribution.event_shape.ndims != 0): + raise ValueError("`mixture_distribution` must have scalar `event_dim`s") + elif validate_args: + self._runtime_assertions += [ + control_flow_ops.assert_has_rank( + mixture_distribution.event_shape_tensor(), 0, + message="`mixture_distribution` must have scalar `event_dim`s"), + ] + + mdbs = mixture_distribution.batch_shape + cdbs = components_distribution.batch_shape.with_rank_at_least(1)[:-1] + if mdbs.is_fully_defined() and cdbs.is_fully_defined(): + if mdbs.ndims != 0 and mdbs != cdbs: + raise ValueError( + "`mixture_distribution.batch_shape` (`{}`) is not " + "compatible with `components_distribution.batch_shape` " + "(`{}`)".format(mdbs.as_list(), cdbs.as_list())) + elif validate_args: + mdbs = mixture_distribution.batch_shape_tensor() + cdbs = components_distribution.batch_shape_tensor()[:-1] + self._runtime_assertions += [ + control_flow_ops.assert_equal( + distribution_util.pick_vector( + mixture_distribution.is_scalar_batch(), cdbs, mdbs), + cdbs, + message=( + "`mixture_distribution.batch_shape` is not " + "compatible with `components_distribution.batch_shape`"))] + + km = mixture_distribution.logits.shape.with_rank_at_least(1)[-1].value + kc = components_distribution.batch_shape.with_rank_at_least(1)[-1].value + if km is not None and kc is not None and km != kc: + raise ValueError("`mixture_distribution components` ({}) does not " + "equal `components_distribution.batch_shape[-1]` " + "({})".format(km, kc)) + elif validate_args: + km = array_ops.shape(mixture_distribution.logits)[-1] + kc = components_distribution.batch_shape_tensor()[-1] + self._runtime_assertions += [ + control_flow_ops.assert_equal( + km, kc, + message=("`mixture_distribution components` does not equal " + "`components_distribution.batch_shape[-1:]`")), + ] + elif km is None: + km = array_ops.shape(mixture_distribution.logits)[-1] + + self._num_components = km + + super(MixtureSameFamily, self).__init__( + dtype=self._components_distribution.dtype, + reparameterization_type=distribution.NOT_REPARAMETERIZED, + validate_args=validate_args, + allow_nan_stats=allow_nan_stats, + parameters=parameters, + graph_parents=( + self._mixture_distribution._graph_parents # pylint: disable=protected-access + + self._components_distribution._graph_parents), # pylint: disable=protected-access + name=name) + + @property + def mixture_distribution(self): + return self._mixture_distribution + + @property + def components_distribution(self): + return self._components_distribution + + def _batch_shape_tensor(self): + with ops.control_dependencies(self._runtime_assertions): + return self.components_distribution.batch_shape_tensor()[:-1] + + def _batch_shape(self): + return self.components_distribution.batch_shape.with_rank_at_least(1)[:-1] + + def _event_shape_tensor(self): + with ops.control_dependencies(self._runtime_assertions): + return self.components_distribution.event_shape_tensor() + + def _event_shape(self): + return self.components_distribution.event_shape + + def _sample_n(self, n, seed): + with ops.control_dependencies(self._runtime_assertions): + x = self.components_distribution.sample(n) # [n, B, k, E] + # TODO(jvdillon): Consider using tf.gather (by way of index unrolling). + npdt = x.dtype.as_numpy_dtype + mask = array_ops.one_hot( + indices=self.mixture_distribution.sample(n), # [n, B] + depth=self._num_components, # == k + on_value=np.ones([], dtype=npdt), + off_value=np.zeros([], dtype=npdt)) # [n, B, k] + mask = self._pad_mix_dims(mask) # [n, B, k, [1]*e] + return math_ops.reduce_sum( + x * mask, axis=-1 - self._event_ndims) # [n, B, E] + + def _log_prob(self, x): + with ops.control_dependencies(self._runtime_assertions): + x = self._pad_sample_dims(x) + log_prob_x = self.components_distribution.log_prob(x) # [S, B, k] + log_mix_prob = nn_ops.log_softmax( + self.mixture_distribution.logits, dim=-1) # [B, k] + return math_ops.reduce_logsumexp( + log_prob_x + log_mix_prob, axis=-1) # [S, B] + + def _mean(self): + with ops.control_dependencies(self._runtime_assertions): + probs = self._pad_mix_dims( + self.mixture_distribution.probs) # [B, k, [1]*e] + return math_ops.reduce_sum( + probs * self.components_distribution.mean(), + axis=-1 - self._event_ndims) # [B, E] + + def _log_cdf(self, x): + x = self._pad_sample_dims(x) + log_cdf_x = self.components_distribution.log_cdf(x) # [S, B, k] + log_mix_prob = nn_ops.log_softmax( + self.mixture_distribution.logits, dim=-1) # [B, k] + return math_ops.reduce_logsumexp( + log_cdf_x + log_mix_prob, axis=-1) # [S, B] + + def _variance(self): + with ops.control_dependencies(self._runtime_assertions): + # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X]) + probs = self._pad_mix_dims( + self.mixture_distribution.probs) # [B, k, [1]*e] + mean_cond_var = math_ops.reduce_sum( + probs * self.components_distribution.variance(), + axis=-1 - self._event_ndims) # [B, E] + var_cond_mean = math_ops.reduce_sum( + probs * math_ops.squared_difference( + self.components_distribution.mean(), + self._pad_sample_dims(self._mean())), + axis=-1 - self._event_ndims) # [B, E] + return mean_cond_var + var_cond_mean # [B, E] + + def _covariance(self): + static_event_ndims = self.event_shape.ndims + if static_event_ndims != 1: + # Covariance is defined only for vector distributions. + raise NotImplementedError("covariance is not implemented") + + with ops.control_dependencies(self._runtime_assertions): + # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X]) + probs = self._pad_mix_dims(self._pad_mix_dims( + self.mixture_distribution.probs)) # [B, k, 1, 1] + mean_cond_var = math_ops.reduce_sum( + probs * self.components_distribution.covariance(), + axis=-3) # [B, e, e] + var_cond_mean = math_ops.reduce_sum( + probs * _outer_squared_difference( + self.components_distribution.mean(), + self._pad_sample_dims(self._mean())), + axis=-3) # [B, e, e] + return mean_cond_var + var_cond_mean # [B, e, e] + + def _pad_sample_dims(self, x): + with ops.name_scope("pad_sample_dims", values=[x]): + ndims = x.shape.ndims if x.shape.ndims is not None else array_ops.rank(x) + shape = array_ops.shape(x) + d = ndims - self._event_ndims + x = array_ops.reshape(x, shape=array_ops.concat([ + shape[:d], [1], shape[d:]], axis=0)) + return x + + def _pad_mix_dims(self, x): + with ops.name_scope("pad_mix_dims", values=[x]): + def _get_ndims(d): + if d.batch_shape.ndims is not None: + return d.batch_shape.ndims + return array_ops.shape(d.batch_shape_tensor())[0] + dist_batch_ndims = _get_ndims(self) + cat_batch_ndims = _get_ndims(self.mixture_distribution) + bnd = distribution_util.pick_vector( + self.mixture_distribution.is_scalar_batch(), + [dist_batch_ndims], [cat_batch_ndims])[0] + s = array_ops.shape(x) + x = array_ops.reshape(x, shape=array_ops.concat([ + s[:-1], + array_ops.ones([bnd], dtype=dtypes.int32), + s[-1:], + array_ops.ones([self._event_ndims], dtype=dtypes.int32), + ], axis=0)) + return x + + +def _outer_squared_difference(x, y): + """Convenience function analogous to tf.squared_difference.""" + z = x - y + return z[..., array_ops.newaxis, :] * z[..., array_ops.newaxis] diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py index ee3e02e0203a3338b7e6a40b7e3ff30c0a0940f0..040bc230722194316b8a74627344e315a2578281 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py @@ -237,7 +237,7 @@ class MultivariateNormalDiagPlusLowRank( scale_perturb_diag, name="scale_perturb_diag") if has_low_rank: - scale = linalg.LinearOperatorUDVHUpdate( + scale = linalg.LinearOperatorLowRankUpdate( scale, u=scale_perturb_factor, diag_update=scale_perturb_diag, diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index 221eed547bacd59d3c0d065f386fe45970f9bae9..f9952b2069d6dfd2593e6bd71ede0badf44cdf98 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -174,8 +174,8 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): covariance_matrix = control_flow_ops.with_dependencies( [assert_symmetric], covariance_matrix) # No need to validate that covariance_matrix is non-singular. - # LinearOperatorTriL has an assert_non_singular method that is called - # by the Bijector. + # LinearOperatorLowerTriangular has an assert_non_singular method that + # is called by the Bijector. # However, cholesky() ignores the upper triangular part, so we do need # to separately assert symmetric. scale_tril = linalg_ops.cholesky(covariance_matrix) diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py index 50c7ba418be5b66127a3fde9f02a39b8f52ff841..300bdd5f6064a1cc9c336689ac4fae04338edb30 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py @@ -18,16 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops.bijectors import AffineLinearOperator from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import kullback_leibler from tensorflow.python.ops.distributions import normal from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.ops.linalg import linalg __all__ = [ @@ -92,7 +91,7 @@ class MultivariateNormalLinearOperator( ```python ds = tf.contrib.distributions - la = tf.contrib.linalg + la = tf.linalg # Initialize a single 3-variate Gaussian. mu = [1., 2, 3] @@ -106,7 +105,7 @@ class MultivariateNormalLinearOperator( mvn = ds.MultivariateNormalLinearOperator( loc=mu, - scale=la.LinearOperatorTriL(scale)) + scale=la.LinearOperatorLowerTriangular(scale)) # Covariance agrees with cholesky(cov) parameterization. mvn.covariance().eval() @@ -243,8 +242,8 @@ class MultivariateNormalLinearOperator( def _variance(self): if distribution_util.is_diagonal_scale(self.scale): return math_ops.square(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) - and self.scale.is_self_adjoint): + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and + self.scale.is_self_adjoint): return array_ops.matrix_diag_part( self.scale.matmul(self.scale.to_dense())) else: @@ -254,8 +253,8 @@ class MultivariateNormalLinearOperator( def _stddev(self): if distribution_util.is_diagonal_scale(self.scale): return math_ops.abs(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) - and self.scale.is_self_adjoint): + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and + self.scale.is_self_adjoint): return math_ops.sqrt(array_ops.matrix_diag_part( self.scale.matmul(self.scale.to_dense()))) else: @@ -299,7 +298,10 @@ def _kl_brute_force(a, b, name=None): def squared_frobenius_norm(x): """Helper to make KL calculation slightly more readable.""" # http://mathworld.wolfram.com/FrobeniusNorm.html - return math_ops.square(linalg_ops.norm(x, ord="fro", axis=[-2, -1])) + # The gradient of KL[p,q] is not defined when p==q. The culprit is + # linalg_ops.norm, i.e., we cannot use the commented out code. + # return math_ops.square(linalg_ops.norm(x, ord="fro", axis=[-2, -1])) + return math_ops.reduce_sum(math_ops.square(x), axis=[-2, -1]) # TODO(b/35041439): See also b/35040945. Remove this function once LinOp # supports something like: diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py index 48c4dddc8133d408e1beb7a8aef2abd676895fe3..260dcc18f513d5440d3d39368539274c03faa72a 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py @@ -121,6 +121,14 @@ class MultivariateNormalTriL( [-10, 0, 9]] # shape: [2, 3] mvn.prob(x).eval() # shape: [2] + # Instantiate a "learnable" MVN. + dims = 4 + with tf.variable_scope("model"): + mvn = ds.MultivariateNormalTriL( + loc=tf.get_variable(shape=[dims], dtype=tf.float32, name="mu"), + scale_tril=ds.fill_triangular( + tf.get_variable(shape=[dims * (dims + 1) / 2], + dtype=tf.float32, name="chol_Sigma"))) ``` """ @@ -188,9 +196,9 @@ class MultivariateNormalTriL( assert_proper_shapes=validate_args) else: # No need to validate that scale_tril is non-singular. - # LinearOperatorTriL has an assert_non_singular method that is called - # by the Bijector. - scale = linalg.LinearOperatorTriL( + # LinearOperatorLowerTriangular has an assert_non_singular + # method that is called by the Bijector. + scale = linalg.LinearOperatorLowerTriangular( scale_tril, is_non_singular=True, is_self_adjoint=False, diff --git a/tensorflow/contrib/distributions/python/ops/negative_binomial.py b/tensorflow/contrib/distributions/python/ops/negative_binomial.py index c8c396f6f80cf7f3228a75d279fff91ae15813ad..3a58df80da6c02b056f5e5a63bf41de5fc6d44a4 100644 --- a/tensorflow/contrib/distributions/python/ops/negative_binomial.py +++ b/tensorflow/contrib/distributions/python/ops/negative_binomial.py @@ -167,8 +167,8 @@ class NegativeBinomial(distribution.Distribution): def _log_unnormalized_prob(self, x): if self.validate_args: x = distribution_util.embed_check_nonnegative_integer_form(x) - return (self.total_count * math_ops.log1p(-self.probs) - + x * math_ops.log(self.probs)) + return (self.total_count * math_ops.log_sigmoid(-self.logits) + + x * math_ops.log_sigmoid(self.logits)) def _log_normalization(self, x): if self.validate_args: diff --git a/tensorflow/contrib/distributions/python/ops/poisson.py b/tensorflow/contrib/distributions/python/ops/poisson.py index 59a98e5682d5b3c053a18a19a1da0d2f320f21a6..e967dcc90d0712ffc346fb61ee67c44a6d9207cb 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson.py +++ b/tensorflow/contrib/distributions/python/ops/poisson.py @@ -60,15 +60,18 @@ class Poisson(distribution.Distribution): """ def __init__(self, - rate, + rate=None, + log_rate=None, validate_args=False, allow_nan_stats=True, name="Poisson"): """Initialize a batch of Poisson distributions. Args: - rate: Floating point tensor, the rate parameter of the - distribution(s). `rate` must be positive. + rate: Floating point tensor, the rate parameter. `rate` must be positive. + Must specify exactly one of `rate` and `log_rate`. + log_rate: Floating point tensor, the log of the rate parameter. + Must specify exactly one of `rate` and `log_rate`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -78,12 +81,32 @@ class Poisson(distribution.Distribution): result is undefined. When `False`, an exception is raised if one or more of the statistic's batch members are undefined. name: Python `str` name prefixed to Ops created by this class. + + Raises: + ValueError: if none or both of `rate`, `log_rate` are specified. + TypeError: if `rate` is not a float-type. + TypeError: if `log_rate` is not a float-type. """ parameters = locals() with ops.name_scope(name, values=[rate]): - with ops.control_dependencies([check_ops.assert_positive(rate)] if - validate_args else []): - self._rate = array_ops.identity(rate, name="rate") + if (rate is None) == (log_rate is None): + raise ValueError("Must specify exactly one of `rate` and `log_rate`.") + elif log_rate is None: + rate = ops.convert_to_tensor(rate, name="rate") + if not rate.dtype.is_floating: + raise TypeError("rate.dtype ({}) is a not a float-type.".format( + rate.dtype.name)) + with ops.control_dependencies([check_ops.assert_positive(rate)] if + validate_args else []): + self._rate = array_ops.identity(rate, name="rate") + self._log_rate = math_ops.log(rate, name="log_rate") + else: + log_rate = ops.convert_to_tensor(log_rate, name="log_rate") + if not log_rate.dtype.is_floating: + raise TypeError("log_rate.dtype ({}) is a not a float-type.".format( + log_rate.dtype.name)) + self._rate = math_ops.exp(log_rate, name="rate") + self._log_rate = ops.convert_to_tensor(log_rate, name="log_rate") super(Poisson, self).__init__( dtype=self._rate.dtype, reparameterization_type=distribution.NOT_REPARAMETERIZED, @@ -98,11 +121,16 @@ class Poisson(distribution.Distribution): """Rate parameter.""" return self._rate + @property + def log_rate(self): + """Log rate parameter.""" + return self._log_rate + def _batch_shape_tensor(self): return array_ops.shape(self.rate) def _batch_shape(self): - return self.rate.get_shape() + return self.rate.shape def _event_shape_tensor(self): return constant_op.constant([], dtype=dtypes.int32) @@ -137,7 +165,7 @@ class Poisson(distribution.Distribution): else: # For consistency with cdf, we take the floor. x = math_ops.floor(x) - return x * math_ops.log(self.rate) - math_ops.lgamma(1. + x) + return x * self.log_rate - math_ops.lgamma(1. + x) def _mean(self): return array_ops.identity(self.rate) diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py index 1c2046c7f03cb36b30d0fb1ffae42885dca42a81..8a95038a3c8eccf8a75fea79d0a62f9883b4f13a 100644 --- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py +++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops import poisson as poisson_lib from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -29,7 +30,6 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib -from tensorflow.python.ops.distributions import util as distribution_util __all__ = [ @@ -55,8 +55,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): ``` where `lambda(z) = exp(sqrt(2) scale z + loc)` and the `prob,grid` terms - are from [Gauss--Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). Note that + are from [numerical quadrature]( + https://en.wikipedia.org/wiki/Numerical_integration) (default: + [Gauss--Hermite quadrature]( + https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). Note that the second line made the substitution: `z(l) = (log(l) - loc) / (sqrt(2) scale)` which implies `lambda(z)` [above] and `dl = sqrt(2) scale lambda(z) dz` @@ -65,8 +67,11 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): Poisson rate parameter. Unfortunately, the non-approximate distribution lacks an analytical probability density function (pdf). Therefore the `PoissonLogNormalQuadratureCompound` class implements an approximation based - on [Gauss-Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). + on [numerical quadrature]( + https://en.wikipedia.org/wiki/Numerical_integration) (default: + [Gauss--Hermite quadrature]( + https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). + Note: although the `PoissonLogNormalQuadratureCompound` is approximately the Poisson-LogNormal compound distribution, it is itself a valid distribution. Viz., it possesses a `sample`, `log_prob`, `mean`, `variance`, etc. which are @@ -76,9 +81,11 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): The `PoissonLogNormalQuadratureCompound` approximates a Poisson-LogNormal [compound distribution]( - https://en.wikipedia.org/wiki/Compound_probability_distribution). - Using variable-substitution and [Gauss-Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) we can + https://en.wikipedia.org/wiki/Compound_probability_distribution). Using + variable-substitution and [numerical quadrature]( + https://en.wikipedia.org/wiki/Numerical_integration) (default: + [Gauss--Hermite quadrature]( + https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can redefine the distribution to be a parameter-less convex combination of `deg` different Poisson samples. @@ -93,7 +100,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): : d=0, ..., deg-1 } ``` - where, [`grid, w = numpy.polynomial.hermite.hermgauss(deg)`]( + where, [e.g., `grid, w = numpy.polynomial.hermite.hermgauss(deg)`]( https://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.polynomial.hermite.hermgauss.html) and `prob = w / sqrt(pi)`. @@ -106,14 +113,15 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): pln = ds.PoissonLogNormalQuadratureCompound( loc=[0., -0.5], scale=1., - quadrature_polynomial_degree=10, + quadrature_grid_and_probs=( + np.polynomial.hermite.hermgauss(deg=10)), validate_args=True) """ def __init__(self, loc, scale, - quadrature_polynomial_degree=8, + quadrature_grid_and_probs=None, validate_args=False, allow_nan_stats=True, name="PoissonLogNormalQuadratureCompound"): @@ -124,8 +132,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): the LogNormal prior. scale: `float`-like (batch of) scalar `Tensor`; the scale parameter of the LogNormal prior. - quadrature_polynomial_degree: Python `int`-like scalar. - Default value: 8. + quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s + representing the sample points and the corresponding (possibly + normalized) weight. When `None`, defaults to: + `np.polynomial.hermite.hermgauss(deg=8)`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -153,18 +163,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): "loc.dtype(\"{}\") does not match scale.dtype(\"{}\")".format( loc.dtype.name, scale.dtype.name)) - self._degree = quadrature_polynomial_degree - - grid, prob = np.polynomial.hermite.hermgauss( - deg=quadrature_polynomial_degree) - - # It should be that `sum(prob) == sqrt(pi)`, but self-normalization is - # more numerically stable. - prob = prob.astype(dtype.as_numpy_dtype) - prob /= np.linalg.norm(prob, ord=1) + grid, probs = distribution_util.process_quadrature_grid_and_probs( + quadrature_grid_and_probs, dtype, validate_args) + self._quadrature_grid = grid + self._quadrature_probs = probs + self._quadrature_size = distribution_util.dimension_size(probs, axis=0) self._mixture_distribution = categorical_lib.Categorical( - logits=np.log(prob), + logits=math_ops.log(self._quadrature_probs), validate_args=validate_args, allow_nan_stats=allow_nan_stats) @@ -176,7 +182,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): + np.sqrt(2.) * scale[..., array_ops.newaxis] * grid) self._distribution = poisson_lib.Poisson( - rate=math_ops.exp(self._log_rate, name="rate"), + log_rate=self._log_rate, validate_args=validate_args, allow_nan_stats=allow_nan_stats) @@ -210,9 +216,14 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): return self._scale @property - def quadrature_polynomial_degree(self): - """Polynomial largest exponent used for Gauss-Hermite quadrature.""" - return self._degree + def quadrature_grid(self): + """Quadrature grid points.""" + return self._quadrature_grid + + @property + def quadrature_probs(self): + """Quadrature normalized weights.""" + return self._quadrature_probs def _batch_shape_tensor(self): return array_ops.broadcast_dynamic_shape( @@ -242,10 +253,10 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution): [batch_size])), seed=distribution_util.gen_new_seed( seed, "poisson_lognormal_quadrature_compound")) - # Stride `quadrature_polynomial_degree` for `batch_size` number of times. + # Stride `quadrature_size` for `batch_size` number of times. offset = math_ops.range(start=0, - limit=batch_size * self._degree, - delta=self._degree, + limit=batch_size * self._quadrature_size, + delta=self._quadrature_size, dtype=ids.dtype) ids += offset rate = array_ops.gather( diff --git a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py index 699cf45a73883a49d116fa70c81a4f9ecb36e598..b6becfa9fc93f189a1a7bf7b2a7af8dc1f2e9720 100644 --- a/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py +++ b/tensorflow/contrib/distributions/python/ops/relaxed_onehot_categorical.py @@ -130,7 +130,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): temperature, logits=None, probs=None, - dtype=dtypes.float32, + dtype=None, validate_args=False, allow_nan_stats=True, name="ExpRelaxedOneHotCategorical"): @@ -150,7 +150,8 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): `N - 1` dimensions index into a batch of independent distributions and the last dimension represents a vector of probabilities for each class. Only one of `logits` or `probs` should be passed in. - dtype: The type of the event samples (default: float32). + dtype: The type of the event samples (default: inferred from + logits/probs). validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -163,14 +164,21 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): """ parameters = locals() with ops.name_scope(name, values=[logits, probs, temperature]): + + self._logits, self._probs = distribution_util.get_logits_and_probs( + name=name, logits=logits, probs=probs, validate_args=validate_args, + multidimensional=True) + + if dtype is None: + dtype = self._logits.dtype + if not validate_args: + temperature = math_ops.cast(temperature, dtype) + with ops.control_dependencies([check_ops.assert_positive(temperature)] if validate_args else []): self._temperature = array_ops.identity(temperature, name="temperature") self._temperature_2d = array_ops.reshape(temperature, [-1, 1], name="temperature_2d") - self._logits, self._probs = distribution_util.get_logits_and_probs( - name=name, logits=logits, probs=probs, validate_args=validate_args, - multidimensional=True) logits_shape_static = self._logits.get_shape().with_rank_at_least(1) if logits_shape_static.ndims is not None: @@ -230,7 +238,7 @@ class ExpRelaxedOneHotCategorical(distribution.Distribution): def _sample_n(self, n, seed=None): sample_shape = array_ops.concat([[n], array_ops.shape(self.logits)], 0) - logits = self.logits * array_ops.ones(sample_shape) + logits = self.logits * array_ops.ones(sample_shape, dtype=self.dtype) logits_2d = array_ops.reshape(logits, [-1, self.event_size]) # Uniform variates must be sampled from the open-interval `(0, 1)` rather # than `[0, 1)`. To do so, we use `np.finfo(self.dtype.as_numpy_dtype).tiny` @@ -368,7 +376,7 @@ class RelaxedOneHotCategorical( temperature, logits=None, probs=None, - dtype=dtypes.float32, + dtype=None, validate_args=False, allow_nan_stats=True, name="RelaxedOneHotCategorical"): @@ -388,7 +396,8 @@ class RelaxedOneHotCategorical( dimensions index into a batch of independent distributions and the last dimension represents a vector of probabilities for each class. Only one of `logits` or `probs` should be passed in. - dtype: The type of the event samples (default: float32). + dtype: The type of the event samples (default: inferred from + logits/probs). validate_args: Unused in this distribution. allow_nan_stats: Python `bool`, default `True`. If `False`, raise an exception if a statistic (e.g. mean/mode/etc...) is undefined for any diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py new file mode 100644 index 0000000000000000000000000000000000000000..b05f15771a3a94779ffddea8f16ad2fa4ea2fdd1 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py @@ -0,0 +1,217 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SinhArcsinh transformation of a distribution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops.distributions import normal +from tensorflow.python.ops.distributions import transformed_distribution + +__all__ = [ + "SinhArcsinh", +] + + +class SinhArcsinh(transformed_distribution.TransformedDistribution): + """The SinhArcsinh transformation of a distribution on `(-inf, inf)`. + + This distribution models a random variable, making use of + a `SinhArcsinh` transformation (which has adjustable tailweight and skew), + a rescaling, and a shift. + + The `SinhArcsinh` transformation of the Normal is described in great depth in + [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865). + Here we use a slightly different parameterization, in terms of `tailweight` + and `skewness`. Additionally we allow for distributions other than Normal, + and control over `scale` as well as a "shift" parameter `loc`. + + #### Mathematical Details + + Given random variable `Z`, we define the SinhArcsinh + transformation of `Z`, `Y`, parameterized by + `(loc, scale, skewness, tailweight)`, via the relation: + + ``` + Y := loc + scale * F(Z) * (2 / F_0(2)) + F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) + F_0(Z) := Sinh( Arcsinh(Z) * tailweight ) + ``` + + This distribution is similar to the location-scale transformation + `L(Z) := loc + scale * Z` in the following ways: + + * If `skewness = 0` and `tailweight = 1` (the defaults), `F(Z) = Z`, and then + `Y = L(Z)` exactly. + * `loc` is used in both to shift the result by a constant factor. + * The multiplication of `scale` by `2 / F_0(2)` ensures that if `skewness = 0` + `P[Y - loc <= 2 * scale] = P[L(Z) - loc <= 2 * scale]`. + Thus it can be said that the weights in the tails of `Y` and `L(Z)` beyond + `loc + 2 * scale` are the same. + + This distribution is different than `loc + scale * Z` due to the + reshaping done by `F`: + + * Positive (negative) `skewness` leads to positive (negative) skew. + * positive skew means, the mode of `F(Z)` is "tilted" to the right. + * positive skew means positive values of `F(Z)` become more likely, and + negative values become less likely. + * Larger (smaller) `tailweight` leads to fatter (thinner) tails. + * Fatter tails mean larger values of `|F(Z)|` become more likely. + * `tailweight < 1` leads to a distribution that is "flat" around `Y = loc`, + and a very steep drop-off in the tails. + * `tailweight > 1` leads to a distribution more peaked at the mode with + heavier tails. + + To see the argument about the tails, note that for `|Z| >> 1` and + `|Z| >> (|skewness| * tailweight)**tailweight`, we have + `Y approx 0.5 Z**tailweight e**(sign(Z) skewness * tailweight)`. + + To see the argument regarding multiplying `scale` by `2 / F_0(2)`, + + ``` + P[(Y - loc) / scale <= 2] = P[F(Z) * (2 / F_0(2)) <= 2] + = P[F(Z) <= F_0(2)] + = P[Z <= 2] (if F = F_0). + ``` + """ + + def __init__(self, + loc, + scale, + skewness=None, + tailweight=None, + distribution=None, + validate_args=False, + allow_nan_stats=True, + name="SinhArcsinh"): + """Construct SinhArcsinh distribution on `(-inf, inf)`. + + Arguments `(loc, scale, skewness, tailweight)` must have broadcastable shape + (indexing batch dimensions). They must all have the same `dtype`. + + Args: + loc: Floating-point `Tensor`. + scale: `Tensor` of same `dtype` as `loc`. + skewness: Skewness parameter. Default is `0.0` (no skew). + tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight) + distribution: `tf.Distribution`-like instance. Distribution that is + transformed to produce this distribution. + Default is `ds.Normal(0., 1.)`. + Must be a scalar-batch, scalar-event distribution. Typically + `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is + a function of non-trainable parameters. WARNING: If you backprop through + a `SinhArcsinh` sample and `distribution` is not + `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then + the gradient will be incorrect! + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + """ + parameters = locals() + + with ops.name_scope(name, values=[loc, scale, skewness, tailweight]): + loc = ops.convert_to_tensor(loc, name="loc") + dtype = loc.dtype + scale = ops.convert_to_tensor(scale, name="scale", dtype=dtype) + tailweight = 1. if tailweight is None else tailweight + has_default_skewness = skewness is None + skewness = 0. if skewness is None else skewness + tailweight = ops.convert_to_tensor( + tailweight, name="tailweight", dtype=dtype) + skewness = ops.convert_to_tensor(skewness, name="skewness", dtype=dtype) + + batch_shape = distribution_util.get_broadcast_shape( + loc, scale, tailweight, skewness) + + # Recall, with Z a random variable, + # Y := loc + C * F(Z), + # F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) + # F_0(Z) := Sinh( Arcsinh(Z) * tailweight ) + # C := 2 * scale / F_0(2) + if distribution is None: + distribution = normal.Normal( + loc=array_ops.zeros([], dtype=dtype), + scale=array_ops.ones([], dtype=dtype), + allow_nan_stats=allow_nan_stats) + else: + asserts = distribution_util.maybe_check_scalar_distribution( + distribution, dtype, validate_args) + if asserts: + loc = control_flow_ops.with_dependencies(asserts, loc) + + # Make the SAS bijector, 'F'. + f = bijectors.SinhArcsinh( + skewness=skewness, tailweight=tailweight, event_ndims=0) + if has_default_skewness: + f_noskew = f + else: + f_noskew = bijectors.SinhArcsinh( + skewness=skewness.dtype.as_numpy_dtype(0.), + tailweight=tailweight, event_ndims=0) + + # Make the Affine bijector, Z --> loc + scale * Z (2 / F_0(2)) + c = 2 * scale / f_noskew.forward(ops.convert_to_tensor(2, dtype=dtype)) + affine = bijectors.Affine( + shift=loc, + scale_identity_multiplier=c, + validate_args=validate_args, + event_ndims=0) + + bijector = bijectors.Chain([affine, f]) + + super(SinhArcsinh, self).__init__( + distribution=distribution, + bijector=bijector, + batch_shape=batch_shape, + validate_args=validate_args, + name=name) + self._parameters = parameters + self._loc = loc + self._scale = scale + self._tailweight = tailweight + self._skewness = skewness + + @property + def loc(self): + """The `loc` in `Y := loc + scale @ F(Z) * (2 / F(2)).""" + return self._loc + + @property + def scale(self): + """The `LinearOperator` `scale` in `Y := loc + scale @ F(Z) * (2 / F(2)).""" + return self._scale + + @property + def tailweight(self): + """Controls the tail decay. `tailweight > 1` means faster than Normal.""" + return self._tailweight + + @property + def skewness(self): + """Controls the skewness. `Skewness > 0` means right skew.""" + return self._skewness diff --git a/tensorflow/contrib/distributions/python/ops/test_util.py b/tensorflow/contrib/distributions/python/ops/test_util.py index da7d3907acb6ac1c6c01ff739aa19fcb95fbb53d..77f2a39273dc365a4ac202d846dd2bc364655c86 100644 --- a/tensorflow/contrib/distributions/python/ops/test_util.py +++ b/tensorflow/contrib/distributions/python/ops/test_util.py @@ -25,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import histogram_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables as variables_ops __all__ = [ @@ -37,7 +38,7 @@ class DiscreteScalarDistributionTestHelpers(object): """DiscreteScalarDistributionTestHelpers.""" def run_test_sample_consistent_log_prob( - self, sess, dist, + self, sess_run_fn, dist, num_samples=int(1e5), num_threshold=int(1e3), seed=42, rtol=1e-2, atol=0.): """Tests that sample/log_prob are consistent with each other. @@ -50,7 +51,9 @@ class DiscreteScalarDistributionTestHelpers(object): are consistent. Args: - sess: Tensorflow session. + sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and + returning a list of results after running one "step" of TensorFlow + computation, typically set to `sess.run`. dist: Distribution instance or object which implements `sample`, `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. num_samples: Python `int` scalar indicating the number of Monte-Carlo @@ -86,7 +89,7 @@ class DiscreteScalarDistributionTestHelpers(object): probs = math_ops.exp(dist.log_prob(edges)) probs = array_ops.reshape(probs, shape=[-1, batch_size])[:, b] - [counts_, probs_] = sess.run([counts, probs]) + [counts_, probs_] = sess_run_fn([counts, probs]) valid = counts_ > num_threshold probs_ = probs_[valid] counts_ = counts_[valid] @@ -94,7 +97,7 @@ class DiscreteScalarDistributionTestHelpers(object): rtol=rtol, atol=atol) def run_test_sample_consistent_mean_variance( - self, sess, dist, + self, sess_run_fn, dist, num_samples=int(1e5), seed=24, rtol=1e-2, atol=0.): """Tests that sample/mean/variance are consistent with each other. @@ -103,7 +106,9 @@ class DiscreteScalarDistributionTestHelpers(object): to the same distribution. Args: - sess: Tensorflow session. + sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and + returning a list of results after running one "step" of TensorFlow + computation, typically set to `sess.run`. dist: Distribution instance or object which implements `sample`, `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. num_samples: Python `int` scalar indicating the number of Monte-Carlo @@ -129,7 +134,7 @@ class DiscreteScalarDistributionTestHelpers(object): mean_, variance_, stddev_ - ] = sess.run([ + ] = sess_run_fn([ sample_mean, sample_variance, sample_stddev, @@ -186,7 +191,7 @@ class VectorDistributionTestHelpers(object): def run_test_sample_consistent_log_prob( self, - sess, + sess_run_fn, dist, num_samples=int(1e5), radius=1., @@ -239,7 +244,9 @@ class VectorDistributionTestHelpers(object): https://en.wikipedia.org/wiki/Importance_sampling. Args: - sess: Tensorflow session. + sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and + returning a list of results after running one "step" of TensorFlow + computation, typically set to `sess.run`. dist: Distribution instance or object which implements `sample`, `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. The distribution must have non-zero probability of sampling every point @@ -279,33 +286,39 @@ class VectorDistributionTestHelpers(object): def monte_carlo_hypersphere_volume(dist, num_samples, radius, center): # https://en.wikipedia.org/wiki/Importance_sampling x = dist.sample(num_samples, seed=seed) + x = array_ops.identity(x) # Invalidate bijector cacheing. return math_ops.reduce_mean( math_ops.exp(-dist.log_prob(x)) * is_in_ball(x, radius, center), axis=0) - [ - batch_shape_, - actual_volume_, - sample_volume_, - ] = sess.run([ - dist.batch_shape_tensor(), - actual_hypersphere_volume( - dims=dist.event_shape_tensor()[0], - radius=radius), - monte_carlo_hypersphere_volume( - dist, - num_samples=num_samples, - radius=radius, - center=center), - ]) - + # Build graph. + with ops.name_scope( + "run_test_sample_consistent_log_prob", + values=[num_samples, radius, center] + dist._graph_parents): # pylint: disable=protected-access + batch_shape = dist.batch_shape_tensor() + actual_volume = actual_hypersphere_volume( + dims=dist.event_shape_tensor()[0], + radius=radius) + sample_volume = monte_carlo_hypersphere_volume( + dist, + num_samples=num_samples, + radius=radius, + center=center) + init_op = variables_ops.global_variables_initializer() + + # Execute graph. + sess_run_fn(init_op) + [batch_shape_, actual_volume_, sample_volume_] = sess_run_fn([ + batch_shape, actual_volume, sample_volume]) + + # Check results. self.assertAllClose(np.tile(actual_volume_, reps=batch_shape_), sample_volume_, rtol=rtol, atol=atol) def run_test_sample_consistent_mean_covariance( self, - sess, + sess_run_fn, dist, num_samples=int(1e5), seed=24, @@ -319,7 +332,9 @@ class VectorDistributionTestHelpers(object): to the same distribution. Args: - sess: Tensorflow session. + sess_run_fn: Python `callable` taking `list`-like of `Tensor`s and + returning a list of results after running one "step" of TensorFlow + computation, typically set to `sess.run`. dist: Distribution instance or object which implements `sample`, `log_prob`, `event_shape_tensor` and `batch_shape_tensor`. num_samples: Python `int` scalar indicating the number of Monte-Carlo @@ -353,7 +368,7 @@ class VectorDistributionTestHelpers(object): covariance_, variance_, stddev_ - ] = sess.run([ + ] = sess_run_fn([ sample_mean, sample_covariance, sample_variance, diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py index 448d881a0e1cb5018b38a392e4bd6be9e7198bea..92043d6a08833888c36009261addca0d14949ea8 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py +++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py @@ -23,21 +23,22 @@ import numpy as np from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.contrib.distributions.python.ops.bijectors.affine_linear_operator import AffineLinearOperator from tensorflow.contrib.linalg.python.ops import linear_operator_addition as linop_add_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_diag as linop_diag_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_full_matrix as linop_full_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_identity as linop_identity_lib -from tensorflow.contrib.linalg.python.ops import linear_operator_tril as linop_tril_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape -from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import categorical as categorical_lib from tensorflow.python.ops.distributions import distribution as distribution_lib +from tensorflow.python.ops.linalg import linear_operator_diag as linop_diag_lib +from tensorflow.python.ops.linalg import linear_operator_full_matrix as linop_full_lib +from tensorflow.python.ops.linalg import linear_operator_identity as linop_identity_lib +from tensorflow.python.ops.linalg import linear_operator_lower_triangular as linop_tril_lib + +static_value = distribution_util.static_value __all__ = [ @@ -72,8 +73,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): denotes matrix multiplication. However, the non-approximate distribution does not have an analytical probability density function (pdf). Therefore the `VectorDiffeomixture` class implements an approximation based on - [Gauss-Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature). I.e., in + [numerical quadrature]( + https://en.wikipedia.org/wiki/Numerical_integration) (default: + [Gauss--Hermite quadrature]( + https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)). I.e., in Note: although the `VectorDiffeomixture` is approximately the `SoftmaxNormal-Distribution` compound distribution, it is itself a valid distribution. It possesses a `sample`, `log_prob`, `mean`, `covariance` which @@ -108,8 +111,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): The `VectorDiffeomixture` approximates a SoftmaxNormal-mixed ("prior") [compound distribution]( https://en.wikipedia.org/wiki/Compound_probability_distribution). - Using variable-substitution and [Gauss-Hermite quadrature]( - https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature) we can + Using variable-substitution and [numerical quadrature]( + https://en.wikipedia.org/wiki/Numerical_integration) (default: + [Gauss--Hermite quadrature]( + https://en.wikipedia.org/wiki/Gauss%E2%80%93Hermite_quadrature)) we can redefine the distribution to be a parameter-less convex combination of `K` different affine combinations of a `d` iid samples from `distribution`. @@ -140,7 +145,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): and, ```none - grid, weight = np.polynomial.hermite.hermgauss(quadrature_polynomial_degree) + grid, weight = np.polynomial.hermite.hermgauss(quadrature_size) prob[k] = weight[k] / sqrt(pi) lambda[k; i] = sigmoid(mix_loc[k] + sqrt(2) mix_scale[k] grid[i]) ``` @@ -184,7 +189,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): ```python ds = tf.contrib.distributions - la = tf.contrib.linalg + la = tf.linalg # Create two batches of VectorDiffeomixtures, one with mix_loc=[0.] and # another with mix_loc=[1]. In both cases, `K=2` and the affine @@ -218,7 +223,7 @@ class VectorDiffeomixture(distribution_lib.Distribution): distribution, loc=None, scale=None, - quadrature_polynomial_degree=8, + quadrature_grid_and_probs=None, validate_args=False, allow_nan_stats=True, name="VectorDiffeomixture"): @@ -247,7 +252,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): `k`-th element represents the `scale` used for the `k`-th affine transformation. `LinearOperator`s must have shape `[B1, ..., Bb, d, d]`, `b >= 0`, i.e., characterizes `b`-batches of `d x d` matrices - quadrature_polynomial_degree: Python `int`-like scalar. + quadrature_grid_and_probs: Python pair of `float`-like `Tensor`s + representing the sample points and the corresponding (possibly + normalized) weight. When `None`, defaults to: + `np.polynomial.hermite.hermgauss(deg=8)`. validate_args: Python `bool`, default `False`. When `True` distribution parameters are checked for validity despite possibly degrading runtime performance. When `False` invalid inputs may silently render incorrect @@ -261,7 +269,8 @@ class VectorDiffeomixture(distribution_lib.Distribution): Raises: ValueError: if `not scale or len(scale) < 2`. ValueError: if `len(loc) != len(scale)` - ValueError: if `quadrature_polynomial_degree < 1`. + ValueError: if `quadrature_grid_and_probs is not None` and + `len(quadrature_grid_and_probs[0]) != len(quadrature_grid_and_probs[1])` ValueError: if `validate_args` and any not scale.is_positive_definite. TypeError: if any scale.dtype != scale[0].dtype. TypeError: if any loc.dtype != scale[0].dtype. @@ -306,12 +315,6 @@ class VectorDiffeomixture(distribution_lib.Distribution): name="endpoint_affine_{}".format(k)) for k, (loc_, scale_) in enumerate(zip(loc, scale))] - if quadrature_polynomial_degree < 1: - raise ValueError("quadrature_polynomial_degree={} " - "is not at least 1".format( - quadrature_polynomial_degree)) - self._degree = quadrature_polynomial_degree - # TODO(jvdillon): Remove once we support k-mixtures. # We make this assertion here because otherwise `grid` would need to be a # vector not a scalar. @@ -319,17 +322,17 @@ class VectorDiffeomixture(distribution_lib.Distribution): raise NotImplementedError("Currently only bimixtures are supported; " "len(scale)={} is not 2.".format(len(scale))) - grid, prob = np.polynomial.hermite.hermgauss( - deg=quadrature_polynomial_degree) - grid = grid.astype(dtype.as_numpy_dtype) - prob = prob.astype(dtype.as_numpy_dtype) - prob /= np.linalg.norm(prob, ord=1) + grid, probs = distribution_util.process_quadrature_grid_and_probs( + quadrature_grid_and_probs, dtype, validate_args) + self._quadrature_grid = grid + self._quadrature_probs = probs + self._quadrature_size = distribution_util.dimension_size(probs, axis=0) # Note: by creating the logits as `log(prob)` we ensure that # `self.mixture_distribution.logits` is equivalent to # `math_ops.log(self.mixture_distribution.probs)`. self._mixture_distribution = categorical_lib.Categorical( - logits=np.log(prob), + logits=math_ops.log(probs), validate_args=validate_args, allow_nan_stats=allow_nan_stats) @@ -338,11 +341,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): mix_scale = maybe_check_mix_param( mix_scale, "mix_scale", dtype, validate_args) - distribution_assertions = maybe_check_distribution( + asserts = distribution_util.maybe_check_scalar_distribution( distribution, dtype, validate_args) - if distribution_assertions: - mix_loc = control_flow_ops.with_dependencies( - distribution_assertions, mix_loc) + if asserts: + mix_loc = control_flow_ops.with_dependencies(asserts, mix_loc) self._distribution = distribution # shape: [B, deg] @@ -357,10 +359,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): validate_args=validate_args, name="interpolated_affine_{}".format(k)) for k, (loc_, scale_) in enumerate(zip( - interpolate_loc(quadrature_polynomial_degree, + interpolate_loc(self._quadrature_size, self._interpolate_weight, loc), - interpolate_scale(quadrature_polynomial_degree, + interpolate_scale(self._quadrature_size, self._interpolate_weight, scale)))] @@ -416,9 +418,14 @@ class VectorDiffeomixture(distribution_lib.Distribution): return self._interpolated_affine @property - def quadrature_polynomial_degree(self): - """Polynomial largest exponent used for Gauss-Hermite quadrature.""" - return self._degree + def quadrature_grid(self): + """Quadrature grid points.""" + return self._quadrature_grid + + @property + def quadrature_probs(self): + """Quadrature normalized weights.""" + return self._quadrature_probs def _batch_shape_tensor(self): return self._batch_shape_ @@ -454,10 +461,10 @@ class VectorDiffeomixture(distribution_lib.Distribution): seed=distribution_util.gen_new_seed( seed, "vector_diffeomixture")) - # Stride `self._degree` for `batch_size` number of times. + # Stride `quadrature_size` for `batch_size` number of times. offset = math_ops.range(start=0, - limit=batch_size * self._degree, - delta=self._degree, + limit=batch_size * self._quadrature_size, + delta=self._quadrature_size, dtype=ids.dtype) weight = array_ops.gather( @@ -672,43 +679,6 @@ def maybe_check_mix_param(param, name, expected_base_dtype, validate_args): return param -def maybe_check_distribution(distribution, expected_base_dtype, validate_args): - """Helper which checks validity of `distribution` init arg.""" - if distribution.dtype != expected_base_dtype: - raise TypeError("dtype mismatch; " - "distribution.dtype=\"{}\" is not \"{}\"".format( - distribution.dtype.name, expected_base_dtype.name)) - - # Although `reparameterization_type` is a static property, we guard it by - # `validate_args`. This allows users to use a `distribution` which is not - # reparameterized itself. However, we tacitly assume that although the - # distribution is not reparameterized, it only depends on non-trainable - # variables. - if validate_args and (distribution.reparameterization_type - != distribution_lib.FULLY_REPARAMETERIZED): - raise ValueError("Base distribution should be reparameterized or be " - "a function of non-trainable variables; " - "distribution.reparameterization_type = \"{}\" " - "!= \"FULLY_REPARAMETERIZED\".".format( - distribution.reparameterization_type)) - with ops.name_scope(name="check_distribution"): - assertions = [] - def check_is_scalar(is_scalar, name): - is_scalar_ = static_value(is_scalar) - if is_scalar_ is not None: - if not is_scalar_: - raise ValueError("distribution must be scalar; " - "distribution.{}=False is not True".format(name)) - elif validate_args: - assertions.append(check_ops.assert_equal( - is_scalar, True, - message=("distribution must be scalar; " - "distribution.{}=False is not True".format(name)))) - check_is_scalar(distribution.is_scalar_event(), "is_scalar_event") - check_is_scalar(distribution.is_scalar_batch(), "is_scalar_batch") - return assertions - - def determine_batch_event_shapes(mix_loc, mix_scale, endpoint_affine): """Helper to infer batch_shape and event_shape.""" with ops.name_scope(name="determine_batch_event_shapes"): @@ -809,8 +779,8 @@ def linop_scale(w, op): is_non_singular=op.is_non_singular, is_self_adjoint=op.is_self_adjoint, is_positive_definite=op.is_positive_definite) - if isinstance(op, linop_tril_lib.LinearOperatorTriL): - return linop_tril_lib.LinearOperatorTriL( + if isinstance(op, linop_tril_lib.LinearOperatorLowerTriangular): + return linop_tril_lib.LinearOperatorLowerTriangular( tril=w[..., array_ops.newaxis, array_ops.newaxis] * op.to_dense(), is_non_singular=op.is_non_singular, is_self_adjoint=op.is_self_adjoint, @@ -819,11 +789,6 @@ def linop_scale(w, op): "Unsupported Linop type ({})".format(type(op).__name__)) -def static_value(x): - """Returns the static value of a `Tensor` or `None`.""" - return tensor_util.constant_value(ops.convert_to_tensor(x)) - - def concat_vectors(*args): """Concatenates input vectors, statically if possible.""" args_ = [static_value(x) for x in args] diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py index c88572e17fa43ac11778bdddc02484d284b6eb36..356d78b67a8107750f68f7f84d73d1231f5b2b03 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py @@ -90,7 +90,7 @@ class VectorExponentialDiag( ```python ds = tf.contrib.distributions - la = tf.contrib.linalg + la = tf.linalg # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py index 7123165417ea010fa9da5263e429734d34df3dbd..b313a851b381e5b3a057fd17e6c2ef4eb0fc34f1 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py @@ -18,7 +18,6 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.python.framework import ops @@ -26,6 +25,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import exponential from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.ops.linalg import linalg __all__ = ["VectorExponentialLinearOperator"] @@ -108,7 +108,7 @@ class VectorExponentialLinearOperator( ```python ds = tf.contrib.distributions - la = tf.contrib.linalg + la = tf.linalg # Initialize a single 2-variate VectorExponential, supported on # {(x, y) in R^2 : x > 0, y > 0}. @@ -247,7 +247,7 @@ class VectorExponentialLinearOperator( def _variance(self): if distribution_util.is_diagonal_scale(self.scale): return math_ops.square(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) and + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and self.scale.is_self_adjoint): return array_ops.matrix_diag_part( self.scale.matmul(self.scale.to_dense())) @@ -258,7 +258,7 @@ class VectorExponentialLinearOperator( def _stddev(self): if distribution_util.is_diagonal_scale(self.scale): return math_ops.abs(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) and + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and self.scale.is_self_adjoint): return math_ops.sqrt( array_ops.matrix_diag_part(self.scale.matmul(self.scale.to_dense()))) diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py index fdee57695e4e598929396ee4c9fe9f8014ea0f8b..c7abdbb4caf9bee4cbd5991eb5d652f20dd0f8d1 100644 --- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py +++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py @@ -20,7 +20,6 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib import linalg from tensorflow.contrib.distributions.python.ops import bijectors from tensorflow.contrib.distributions.python.ops import distribution_util from tensorflow.python.framework import ops @@ -28,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.distributions import laplace from tensorflow.python.ops.distributions import transformed_distribution +from tensorflow.python.ops.linalg import linalg __all__ = [ @@ -110,7 +110,7 @@ class VectorLaplaceLinearOperator( ```python ds = tf.contrib.distributions - la = tf.contrib.linalg + la = tf.linalg # Initialize a single 3-variate VectorLaplace with some desired covariance. mu = [1., 2, 3] @@ -126,7 +126,7 @@ class VectorLaplaceLinearOperator( # Divide scale by sqrt(2) so that the final covariance will be what we want. vla = ds.VectorLaplaceLinearOperator( loc=mu, - scale=la.LinearOperatorTriL(scale / tf.sqrt(2))) + scale=la.LinearOperatorLowerTriangular(scale / tf.sqrt(2))) # Covariance agrees with cholesky(cov) parameterization. vla.covariance().eval() @@ -271,8 +271,8 @@ class VectorLaplaceLinearOperator( def _variance(self): if distribution_util.is_diagonal_scale(self.scale): return 2. * math_ops.square(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) - and self.scale.is_self_adjoint): + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and + self.scale.is_self_adjoint): return array_ops.matrix_diag_part( 2. * self.scale.matmul(self.scale.to_dense())) else: @@ -282,8 +282,8 @@ class VectorLaplaceLinearOperator( def _stddev(self): if distribution_util.is_diagonal_scale(self.scale): return np.sqrt(2) * math_ops.abs(self.scale.diag_part()) - elif (isinstance(self.scale, linalg.LinearOperatorUDVHUpdate) - and self.scale.is_self_adjoint): + elif (isinstance(self.scale, linalg.LinearOperatorLowRankUpdate) and + self.scale.is_self_adjoint): return np.sqrt(2) * math_ops.sqrt(array_ops.matrix_diag_part( self.scale.matmul(self.scale.to_dense()))) else: diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py new file mode 100644 index 0000000000000000000000000000000000000000..544a8710709a0afb56c6ae6f36d35de892e8e420 --- /dev/null +++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py @@ -0,0 +1,265 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Multi-dimensional (Vector) SinhArcsinh transformation of a distribution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.distributions.python.ops import bijectors +from tensorflow.contrib.distributions.python.ops import distribution_util +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops.distributions import normal +from tensorflow.python.ops.distributions import transformed_distribution + +__all__ = [ + "VectorSinhArcsinhDiag", +] + + +class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution): + """The (diagonal) SinhArcsinh transformation of a distribution on `R^k`. + + This distribution models a random vector `Y = (Y1,...,Yk)`, making use of + a `SinhArcsinh` transformation (which has adjustable tailweight and skew), + a rescaling, and a shift. + + The `SinhArcsinh` transformation of the Normal is described in great depth in + [Sinh-arcsinh distributions](https://www.jstor.org/stable/27798865). + Here we use a slightly different parameterization, in terms of `tailweight` + and `skewness`. Additionally we allow for distributions other than Normal, + and control over `scale` as well as a "shift" parameter `loc`. + + #### Mathematical Details + + Given iid random vector `Z = (Z1,...,Zk)`, we define the VectorSinhArcsinhDiag + transformation of `Z`, `Y`, parameterized by + `(loc, scale, skewness, tailweight)`, via the relation (with `@` denoting + matrix multiplication): + + ``` + Y := loc + scale @ F(Z) * (2 / F_0(2)) + F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) + F_0(Z) := Sinh( Arcsinh(Z) * tailweight ) + ``` + + This distribution is similar to the location-scale transformation + `L(Z) := loc + scale @ Z` in the following ways: + + * If `skewness = 0` and `tailweight = 1` (the defaults), `F(Z) = Z`, and then + `Y = L(Z)` exactly. + * `loc` is used in both to shift the result by a constant factor. + * The multiplication of `scale` by `2 / F_0(2)` ensures that if `skewness = 0` + `P[Y - loc <= 2 * scale] = P[L(Z) - loc <= 2 * scale]`. + Thus it can be said that the weights in the tails of `Y` and `L(Z)` beyond + `loc + 2 * scale` are the same. + + This distribution is different than `loc + scale @ Z` due to the + reshaping done by `F`: + + * Positive (negative) `skewness` leads to positive (negative) skew. + * positive skew means, the mode of `F(Z)` is "tilted" to the right. + * positive skew means positive values of `F(Z)` become more likely, and + negative values become less likely. + * Larger (smaller) `tailweight` leads to fatter (thinner) tails. + * Fatter tails mean larger values of `|F(Z)|` become more likely. + * `tailweight < 1` leads to a distribution that is "flat" around `Y = loc`, + and a very steep drop-off in the tails. + * `tailweight > 1` leads to a distribution more peaked at the mode with + heavier tails. + + To see the argument about the tails, note that for `|Z| >> 1` and + `|Z| >> (|skewness| * tailweight)**tailweight`, we have + `Y approx 0.5 Z**tailweight e**(sign(Z) skewness * tailweight)`. + + To see the argument regarding multiplying `scale` by `2 / F_0(2)`, + + ``` + P[(Y - loc) / scale <= 2] = P[F(Z) * (2 / F_0(2)) <= 2] + = P[F(Z) <= F_0(2)] + = P[Z <= 2] (if F = F_0). + ``` + """ + + def __init__(self, + loc=None, + scale_diag=None, + scale_identity_multiplier=None, + skewness=None, + tailweight=None, + distribution=None, + validate_args=False, + allow_nan_stats=True, + name="MultivariateNormalLinearOperator"): + """Construct VectorSinhArcsinhDiag distribution on `R^k`. + + The arguments `scale_diag` and `scale_identity_multiplier` combine to + define the diagonal `scale` referred to in this class docstring: + + ```none + scale = diag(scale_diag + scale_identity_multiplier * ones(k)) + ``` + + The `batch_shape` is the broadcast shape between `loc` and `scale` + arguments. + + The `event_shape` is given by last dimension of the matrix implied by + `scale`. The last dimension of `loc` (if provided) must broadcast with this + + Additional leading dimensions (if any) will index batches. + + Args: + loc: Floating-point `Tensor`. If this is set to `None`, `loc` is + implicitly `0`. When specified, may have shape `[B1, ..., Bb, k]` where + `b >= 0` and `k` is the event size. + scale_diag: Non-zero, floating-point `Tensor` representing a diagonal + matrix added to `scale`. May have shape `[B1, ..., Bb, k]`, `b >= 0`, + and characterizes `b`-batches of `k x k` diagonal matrices added to + `scale`. When both `scale_identity_multiplier` and `scale_diag` are + `None` then `scale` is the `Identity`. + scale_identity_multiplier: Non-zero, floating-point `Tensor` representing + a scale-identity-matrix added to `scale`. May have shape + `[B1, ..., Bb]`, `b >= 0`, and characterizes `b`-batches of scale + `k x k` identity matrices added to `scale`. When both + `scale_identity_multiplier` and `scale_diag` are `None` then `scale` + is the `Identity`. + skewness: Skewness parameter. floating-point `Tensor` with shape + broadcastable with `event_shape`. + tailweight: Tailweight parameter. floating-point `Tensor` with shape + broadcastable with `event_shape`. + distribution: `tf.Distribution`-like instance. Distribution from which `k` + iid samples are used as input to transformation `F`. Default is + `ds.Normal(0., 1.)`. + Must be a scalar-batch, scalar-event distribution. Typically + `distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is + a function of non-trainable parameters. WARNING: If you backprop through + a VectorSinhArcsinhDiag sample and `distribution` is not + `FULLY_REPARAMETERIZED` yet is a function of trainable variables, then + the gradient will be incorrect! + validate_args: Python `bool`, default `False`. When `True` distribution + parameters are checked for validity despite possibly degrading runtime + performance. When `False` invalid inputs may silently render incorrect + outputs. + allow_nan_stats: Python `bool`, default `True`. When `True`, + statistics (e.g., mean, mode, variance) use the value "`NaN`" to + indicate the result is undefined. When `False`, an exception is raised + if one or more of the statistic's batch members are undefined. + name: Python `str` name prefixed to Ops created by this class. + + Raises: + ValueError: if at most `scale_identity_multiplier` is specified. + """ + parameters = locals() + + with ops.name_scope( + name, + values=[ + loc, scale_diag, scale_identity_multiplier, skewness, tailweight + ]): + loc = ops.convert_to_tensor(loc, name="loc") if loc is not None else loc + tailweight = 1. if tailweight is None else tailweight + has_default_skewness = skewness is None + skewness = 0. if skewness is None else skewness + + # Recall, with Z a random variable, + # Y := loc + C * F(Z), + # F(Z) := Sinh( (Arcsinh(Z) + skewness) * tailweight ) + # F_0(Z) := Sinh( Arcsinh(Z) * tailweight ) + # C := 2 * scale / F_0(2) + + # Construct shapes and 'scale' out of the scale_* and loc kwargs. + # scale_linop is only an intermediary to: + # 1. get shapes from looking at loc and the two scale args. + # 2. combine scale_diag with scale_identity_multiplier, which gives us + # 'scale', which in turn gives us 'C'. + scale_linop = distribution_util.make_diag_scale( + loc=loc, + scale_diag=scale_diag, + scale_identity_multiplier=scale_identity_multiplier, + validate_args=False, + assert_positive=False) + batch_shape, event_shape = distribution_util.shapes_from_loc_and_scale( + loc, scale_linop) + # scale_linop.diag_part() is efficient since it is a diag type linop. + scale_diag_part = scale_linop.diag_part() + dtype = scale_diag_part.dtype + + if distribution is None: + distribution = normal.Normal( + loc=array_ops.zeros([], dtype=dtype), + scale=array_ops.ones([], dtype=dtype), + allow_nan_stats=allow_nan_stats) + else: + asserts = distribution_util.maybe_check_scalar_distribution( + distribution, dtype, validate_args) + if asserts: + scale_diag_part = control_flow_ops.with_dependencies( + asserts, scale_diag_part) + + # Make the SAS bijector, 'F'. + skewness = ops.convert_to_tensor(skewness, dtype=dtype, name="skewness") + tailweight = ops.convert_to_tensor( + tailweight, dtype=dtype, name="tailweight") + f = bijectors.SinhArcsinh( + skewness=skewness, tailweight=tailweight, event_ndims=1) + if has_default_skewness: + f_noskew = f + else: + f_noskew = bijectors.SinhArcsinh( + skewness=skewness.dtype.as_numpy_dtype(0.), + tailweight=tailweight, event_ndims=0) + + # Make the Affine bijector, Z --> loc + C * Z. + c = 2 * scale_diag_part / f_noskew.forward( + ops.convert_to_tensor(2, dtype=dtype)) + affine = bijectors.Affine( + shift=loc, scale_diag=c, validate_args=validate_args, event_ndims=1) + + bijector = bijectors.Chain([affine, f]) + + super(VectorSinhArcsinhDiag, self).__init__( + distribution=distribution, + bijector=bijector, + batch_shape=batch_shape, + event_shape=event_shape, + validate_args=validate_args, + name=name) + self._parameters = parameters + self._loc = loc + self._scale = scale_linop + self._tailweight = tailweight + self._skewness = skewness + + @property + def loc(self): + """The `loc` in `Y := loc + scale @ F(Z) * (2 / F(2)).""" + return self._loc + + @property + def scale(self): + """The `LinearOperator` `scale` in `Y := loc + scale @ F(Z) * (2 / F(2)).""" + return self._scale + + @property + def tailweight(self): + """Controls the tail decay. `tailweight > 1` means faster than Normal.""" + return self._tailweight + + @property + def skewness(self): + """Controls the skewness. `Skewness > 0` means right skew.""" + return self._skewness diff --git a/tensorflow/contrib/distributions/python/ops/wishart.py b/tensorflow/contrib/distributions/python/ops/wishart.py index 9d30ce67197ebdeefc69d9b9979fdad4797bb183..e4ac65012b9c7e3ed5ada3ed75020f3905740156 100644 --- a/tensorflow/contrib/distributions/python/ops/wishart.py +++ b/tensorflow/contrib/distributions/python/ops/wishart.py @@ -251,8 +251,8 @@ class _WishartLinearOperator(distribution.Distribution): # Complexity: O(nbM) where M is the complexity of the operator solving a # vector system. E.g., for LinearOperatorDiag, each matmul is O(k**2), so - # this complexity is O(nbk**2). For LinearOperatorTriL, each matmul is - # O(k^3) so this step has complexity O(nbk^3). + # this complexity is O(nbk**2). For LinearOperatorLowerTriangular, + # each matmul is O(k^3) so this step has complexity O(nbk^3). x = self.scale_operator.matmul(x) # Undo make batch-op ready. @@ -307,8 +307,8 @@ class _WishartLinearOperator(distribution.Distribution): # Complexity: O(nbM*k) where M is the complexity of the operator solving # a vector system. E.g., for LinearOperatorDiag, each solve is O(k), so - # this complexity is O(nbk**2). For LinearOperatorTriL, each solve is - # O(k**2) so this step has complexity O(nbk^3). + # this complexity is O(nbk**2). For LinearOperatorLowerTriangular, + # each solve is O(k**2) so this step has complexity O(nbk^3). scale_sqrt_inv_x_sqrt = self.scale_operator.solve( scale_sqrt_inv_x_sqrt) @@ -544,7 +544,7 @@ class WishartCholesky(_WishartLinearOperator): super(WishartCholesky, self).__init__( df=df, - scale_operator=linalg.LinearOperatorTriL( + scale_operator=linalg.LinearOperatorLowerTriangular( tril=scale, is_non_singular=True, is_positive_definite=True, @@ -655,7 +655,7 @@ class WishartFull(_WishartLinearOperator): ] if validate_args else [], chol) super(WishartFull, self).__init__( df=df, - scale_operator=linalg.LinearOperatorTriL( + scale_operator=linalg.LinearOperatorLowerTriangular( tril=chol, is_non_singular=True, is_positive_definite=True, diff --git a/tensorflow/contrib/eager/README.md b/tensorflow/contrib/eager/README.md index a4a3af08cf27d20147539cd0dde1f5e3a9d46918..ae4b07799f5c123b68529443a1765fbfbac05492 100644 --- a/tensorflow/contrib/eager/README.md +++ b/tensorflow/contrib/eager/README.md @@ -1,15 +1,78 @@ -TensorFlow has many kernels for doing (deep) learning and data manipulation. -There are typically assembled into computational graphs which can run -efficiently in a variety of environments. +# TensorFlow Eager Execution -We are exploring an alternative interaction, where kernels are invoked -immediately and call this "eager execution". We are hoping to retain the -benefits of graphs while improving usability with benefits like: +> *WARNING*: This is a preview/pre-alpha version. The API and performance +> characteristics are subject to change. -- Immediate error messages and easier debugging -- Flexibility to use Python datastructures and control flow -- Reduced boilerplate +Eager execution is an experimental interface to TensorFlow that provides an +imperative programming style (à la [NumPy](http://www.numpy.org)). When you +enable eager execution, TensorFlow operations execute immediately; you do not +execute a pre-constructed graph with +[`Session.run()`](https://www.tensorflow.org/api_docs/python/tf/Session). -Eager execution is under active development. -There are not many developer-facing materials yet, but stay tuned for updates -in this directory. +For example, consider a simple computation in TensorFlow: + +```python +x = tf.placeholder(tf.float32, shape=[1, 1]) +m = tf.matmul(x, x) + +with tf.Session() as sess: + print(sess.run(m, feed_dict={x: [[2.]]})) + +# Will print [[4.]] +``` + +Eager execution makes this much simpler: + +```python +x = [[2.]] +m = tf.matmul(x, x) + +print(m) +``` + +## Caveats + +This feature is in early stages and work remains to be done in terms of smooth +support for distributed and multi-GPU training and CPU performance. + +- [Known issues](https://github.com/tensorflow/tensorflow/issues?q=is%3Aissue%20is%3Aopen%20label%3Acomp%3Aeager) +- Feedback is welcome, please consider + [filing an issue](https://github.com/tensorflow/tensorflow/issues/new) to provide it. + +## Installation + +Since eager execution is not yet part of a TensorFlow release, using it requires +either [building from source](https://www.tensorflow.org/install/install_sources) +or the latest nightly builds. The nightly builds are available as: + +- [`pip` packages](https://github.com/tensorflow/tensorflow/blob/master/README.md#installation) and + +- [docker](https://hub.docker.com/r/tensorflow/tensorflow/) images. + +For example, to run the latest nightly docker image: + +```sh +# If you have a GPU, use https://github.com/NVIDIA/nvidia-docker +nvidia-docker pull tensorflow/tensorflow:nightly-gpu +nvidia-docker run -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu + +# If you do not have a GPU, use the CPU-only image +docker pull tensorflow/tensorflow:nightly +docker run -it -p 8888:8888 tensorflow/tensorflow:nightly +``` + +And then visit http://localhost:8888 in your browser for a Jupyter notebook +environment. Try out the notebooks below. + +## Documentation + +For an introduction to eager execution in TensorFlow, see: + +- [User Guide](python/g3doc/guide.md) +- Notebook: [Basic Usage](python/examples/notebooks/1_basics.ipynb) +- Notebook: [Gradients](python/examples/notebooks/2_gradients.ipynb) +- Notebook: [Importing Data](python/examples/notebooks/3_datasets.ipynb) + +## Changelog + +- 2017/10/31: Initial preview release. diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD index 10c276826df4f5388aaf92b3f892d74563fcc90c..2b84bc2e9b7453fac99ea2becc328ca854cf555d 100644 --- a/tensorflow/contrib/eager/python/BUILD +++ b/tensorflow/contrib/eager/python/BUILD @@ -2,16 +2,24 @@ licenses(["notice"]) # Apache 2.0 package(default_visibility = ["//tensorflow:internal"]) -load("//tensorflow:tensorflow.bzl", "py_test", "cuda_py_test") +load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") py_library( name = "tfe", srcs = ["tfe.py"], srcs_version = "PY2AND3", + visibility = ["//visibility:public"], deps = [ ":datasets", + ":evaluator", + ":metrics", + ":network", ":saver", + ":summary_writer", "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:numerics", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:util", "//tensorflow/python/eager:backprop", @@ -29,9 +37,11 @@ cuda_py_test( additional_deps = [ ":tfe", "//tensorflow/python:array_ops", + "//tensorflow/python:metrics", "//tensorflow/python:math_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:platform_test", + "//tensorflow/python:summary", ], ) @@ -41,8 +51,10 @@ py_library( srcs_version = "PY2AND3", visibility = ["//tensorflow:internal"], deps = [ + "//tensorflow/python:array_ops", "//tensorflow/python:dataset_ops_gen", "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", "//tensorflow/python:resource_variable_ops", "//tensorflow/python/data/util:nest", "//tensorflow/python/eager:context", @@ -55,10 +67,12 @@ py_test( srcs_version = "PY2AND3", deps = [ ":datasets", - "//tensorflow/contrib/data", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python/data", "//tensorflow/python/eager:test", - "//third_party/py/numpy", ], ) @@ -67,7 +81,11 @@ py_library( srcs = ["saver.py"], srcs_version = "PY2AND3", deps = [ + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:training", + "//tensorflow/python/eager:context", ], ) @@ -79,11 +97,151 @@ cuda_py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client", "//tensorflow/python:client_testlib", - "//tensorflow/python:platform_test", + "//tensorflow/python/eager:graph_callable", + "//tensorflow/python/eager:test", "//tensorflow/python:variables", ], ) +py_library( + name = "summary_writer", + srcs = ["summary_writer.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/summary:gen_summary_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:summary_op_util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + ], +) + +cuda_py_test( + name = "summary_writer_test", + srcs = ["summary_writer_test.py"], + additional_deps = [ + ":summary_writer", + "//third_party/py/numpy", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:constant_op", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "metrics", + srcs = [ + "metrics.py", + "metrics_impl.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/contrib/summary:summary_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:function", + ], +) + +py_test( + name = "metrics_test", + srcs = ["metrics_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":metrics", + "//tensorflow/contrib/summary:summary_ops", + "//tensorflow/contrib/summary:summary_test_util", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:lib", + "//tensorflow/python:platform", + "//tensorflow/python:training", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "evaluator", + srcs = [ + "evaluator.py", + ], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + ":datasets", + ":metrics", + "//tensorflow/contrib/summary:summary_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:function", + "@six_archive//:six", + ], +) + +py_test( + name = "evaluator_test", + srcs = ["evaluator_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":evaluator", + ":metrics", + "//tensorflow/contrib/summary:summary_test_util", + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:test", + ], +) + +py_library( + name = "network", + srcs = ["network.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/python:layers_base", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:util", + ], +) + +py_test( + name = "network_test", + srcs = ["network_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":network", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:layers", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:variable_scope", + "//tensorflow/python/eager:test", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py index 7e353eb3f44e98264682597e8c115f91e2b412f9..98e6983658aed77277d87915ff26a8c676224503 100644 --- a/tensorflow/contrib/eager/python/datasets.py +++ b/tensorflow/contrib/eager/python/datasets.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Support for tf.contrib.data when eager execution is enabled.""" +"""Iteration over tf.data.Datasets when eager execution is enabled.""" from __future__ import absolute_import from __future__ import division @@ -23,6 +23,8 @@ import threading from tensorflow.python.data.util import nest from tensorflow.python.eager import context from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_dataset_ops from tensorflow.python.ops import resource_variable_ops @@ -39,20 +41,23 @@ def _iterator_shared_name(): class Iterator(object): - """An iterator producing tf.Tensor objects from a tf.contrib.data.Dataset.""" + """An iterator producing tf.Tensor objects from a tf.data.Dataset.""" def __init__(self, dataset): """Creates a new iterator over the given dataset. For example: ```python - dataset = tf.contrib.data.Dataset.range(4) + dataset = tf.data.Dataset.range(4) for x in Iterator(dataset): print(x) ``` + Tensors produced will be placed on the device on which this iterator object + was created. + Args: - dataset: A `tf.contrib.data.Dataset` object. + dataset: A `tf.data.Dataset` object. Raises: RuntimeError: When invoked without eager execution enabled. @@ -60,23 +65,25 @@ class Iterator(object): if not context.in_eager_mode(): raise RuntimeError( - "{} objects only make sense when eager execution is enabled".format( - type(self))) - ds_variant = dataset.make_dataset_resource() - self._output_types = dataset.output_types - self._flat_output_types = nest.flatten(dataset.output_types) - self._flat_output_shapes = nest.flatten(dataset.output_shapes) - self._resource = gen_dataset_ops.iterator( - container="", - shared_name=_iterator_shared_name(), - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - gen_dataset_ops.make_iterator(ds_variant, self._resource) - - def __del__(self): - if self._resource is not None: - resource_variable_ops.destroy_resource_op(self._resource) - self._resource = None + "{} objects can only be used when eager execution is enabled, use " + "tf.data.Dataset.make_iterator or " + "tf.data.Dataset.make_one_shot_iterator for graph construction". + format(type(self))) + with ops.device("/device:CPU:0"): + ds_variant = dataset._as_variant_tensor() # pylint: disable=protected-access + self._output_types = dataset.output_types + self._flat_output_types = nest.flatten(dataset.output_types) + self._flat_output_shapes = nest.flatten(dataset.output_shapes) + self._resource = gen_dataset_ops.iterator( + container="", + shared_name=_iterator_shared_name(), + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) + gen_dataset_ops.make_iterator(ds_variant, self._resource) + # Delete the resource when this object is deleted + self._resource_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._resource, handle_device="/device:CPU:0") + self._device = context.context().device_name def __iter__(self): return self @@ -87,10 +94,19 @@ class Iterator(object): def next(self): """Return the next tf.Tensor from the dataset.""" try: - ret = gen_dataset_ops.iterator_get_next( - self._resource, - output_types=self._flat_output_types, - output_shapes=self._flat_output_shapes) - return nest.pack_sequence_as(self._output_types, ret) + # TODO(ashankar): Consider removing this ops.device() contextmanager + # and instead mimic ops placement in graphs: Operations on resource + # handles execute on the same device as where the resource is placed. + with ops.device("/device:CPU:0"): + ret = gen_dataset_ops.iterator_get_next( + self._resource, + output_types=self._flat_output_types, + output_shapes=self._flat_output_shapes) except errors.OutOfRangeError: raise StopIteration + # Copies tensors from CPU to the current device if necessary. + # TODO(rohanj): This should be replaced by the mechanism to have the + # runtime's threads copy tensors to the destination device. + with ops.device(self._device): + ret = [array_ops.identity(x) for x in ret] + return nest.pack_sequence_as(self._output_types, ret) diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py index a2da6b28c6bdbfa0f6a4ed5d303aa4a36b81fc19..c924d81c9d85e638e4f35f260664c0ee7d03257e 100644 --- a/tensorflow/contrib/eager/python/datasets_test.py +++ b/tensorflow/contrib/eager/python/datasets_test.py @@ -16,10 +16,13 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.data import Dataset from tensorflow.contrib.eager.python import datasets +from tensorflow.python.data import Dataset from tensorflow.python.eager import test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import script_ops class IteratorTest(test.TestCase): @@ -69,6 +72,23 @@ class IteratorTest(test.TestCase): got2 = [x.numpy() for x in datasets.Iterator(ds)] self.assertAllEqual(got1, got2) + def testPyFunc(self): + + def my_map(inp): + return [[x + 1 for x in inp]] + + ds = Dataset.range(4).map( + lambda x: script_ops.py_func(my_map, [[x]], dtypes.int64)) + got = [x.numpy() for x in datasets.Iterator(ds)] + self.assertAllEqual([[1], [2], [3], [4]], got) + + def testTensorsPlacedOnDevice(self): + ds = Dataset.from_tensors([0., 1.]) + with ops.device(test.gpu_device_name()): + x = datasets.Iterator(ds).next() + x = math_ops.add(x, x) + self.assertAllEqual([0., 2.], x.numpy()) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/eager/python/evaluator.py b/tensorflow/contrib/eager/python/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..bd0ab02ecf7ae6025e08dde1c3ddc634db9255c1 --- /dev/null +++ b/tensorflow/contrib/eager/python/evaluator.py @@ -0,0 +1,371 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Class Evaluator holds Metrics for the duration of an evaluation run.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + +from tensorflow.contrib.eager.python import datasets +from tensorflow.contrib.eager.python import metrics +from tensorflow.contrib.summary import summary_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import function +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops + + +class Evaluator(object): + """This holds and updates Metrics for the duration of a single eval run. + + Usage: + evaluator = my_model.evaluator() # or MyEvaluator(my_model) + for example_batch in ...: + evaluator(example_batch) + results = evaluator.all_metric_results(optional_summary_logdir) + + Or, if you are getting your examples from a tf.data.Dataset, you can use + the evaluate_on_dataset() method. + + Implementers of Evaluators should + (a) Call `track_metric()` and/or `track_evaluator()` in __init__(). + (b) Override the `call()` method. It will be passed the output of the + model's `eval_data()` method, and should call its contained metrics + (treating them as callables) and any child Evaluators (using their + call() method to avoid calling eval_data() again). + + Args: + model: A `Model` object with an `eval_data()` method. + """ + + def __init__(self, model): + self._model = model + self._metrics = {} + self._evaluators = {} + if context.in_graph_mode(): + self.call = function.defun(self.call) + + # ---- API for users ---- + def __call__(self, *args, **kwargs): + """Update metrics with a minibatch of input examples. + + Args: + *args: + **kwargs: Arguments representing an input mini-batch of examples to + pass to self.model.eval_data(). + + Returns: + The op to execute or None if executing eagerly. + """ + return self.call(self._model.eval_data(*args, **kwargs)) + + def init_variables(self): + """Return an op for initializing all contained uninitialized variables. + + Only for graph execution. Should be called after variables are created + in the first execution of __call__(). + + Returns: + An op. + + Raises: + RuntimeError: if eager execution is enabled. + + @compatibility(eager) + Only for graph execution. + @end_compatibility + """ + if context.in_eager_mode(): + raise RuntimeError("Evaluator.init_variables() not needed when " + "eager execution is enabled.") + return control_flow_ops.group([m.init_variables() for _, m in self.metrics]) + + def all_metric_results(self, summary_logdir=None): + """Computes results for all contained metrics. + + Args: + summary_logdir: An optional string. If specified, metric results + will be written as summaries to this directory. + + Returns: + A `dict` mapping string names to tensors. + """ + if summary_logdir is None: + with summary_ops.never_record_summaries(): + return self._all_metric_results() + else: + def f(): + with summary_ops.create_summary_file_writer( + summary_logdir).as_default(), summary_ops.always_record_summaries(): + return self._all_metric_results() + if context.in_eager_mode(): + return f() + else: + return function.defun(f)() + + def _all_metric_results(self): + """Implementation of `all_metric_results` in the summary context.""" + results = {} + for name, metric in six.iteritems(self._metrics): + results[name] = metric.result() + for prefix, evaluator in six.iteritems(self._evaluators): + for name, metric in six.iteritems(evaluator._metrics): # pylint: disable=protected-access + results[prefix + "/" + name] = metric.result() + return results + + def evaluate_on_dataset(self, dataset, *args, **kwargs): + """Convenience method for performing an eval on a Dataset. + + Args: + dataset: Dataset object with the input data to evaluate on. + *args: + **kwargs: Optional additional arguments to __call__(), except + `summary_logdir`: if specified, metrics will be written as summaries + to this directory. + + Returns: + @compatibility(eager) + When eager execution is enabled, this returns the result of performing + an evaluation as a dictionary. With graph execution, this returns a tuple + (init_op, call_op, results_op) which may be executed using this code: + ```python + sess.run(init_op) + try: + while True: + sess.run(call_op) + except tf.errors.OutOfRangeError: + pass + return sess.run(results_op) # A dictionary + + # equivalently: + return evaluator.run_evaluation(init_op, call_op, results_op, sess=sess) + ``` + @end_compatibility + """ + summary_logdir = kwargs.pop("summary_logdir", None) + if context.in_graph_mode(): + call_op = self.__call__(dataset.make_one_shot_iterator().get_next(), + *args, **kwargs) + init_op = self.init_variables() + results_op = self.all_metric_results(summary_logdir) + return (init_op, call_op, results_op) + # Eager case + for example in datasets.Iterator(dataset): + self.__call__(example, *args, **kwargs) + return self.all_metric_results(summary_logdir) + + @staticmethod + def run_evaluation(init_op, call_op, results_op, sess=None): + """Convenience method for running the ops returned by evaluate_on_dataset. + + Args: + init_op: An op that initializes/resets evaluation state. + call_op: An op that updates evaluation state on a mini-batch of examples. + Must generate an tf.errors.OutOfRangeError when done. + results_op: A dictionary of tensors that compute the final evaluation + results from the evaulation state. + sess: The Session to run the evaluation in. Defaults to the default + Session. + + Returns: + A dictionary of values, parallel to results_op. + + Raises: + RuntimeError: if eager execution is enabled. + + @compatibility(eager) + Only for graph execution. + @end_compatibility + """ + if context.in_eager_mode(): + raise RuntimeError("Evaluator.run_evaluation() not supported when " + "eager execution is enabled.") + sess = sess or ops.get_default_session() + sess.run(init_op) + try: + while True: + sess.run(call_op) + except errors_impl.OutOfRangeError: + pass + return sess.run(results_op) + + # ---- To be implemented by descendants --- + def call(self, eval_data): + """Update metrics using the output of self.model. + + Note: This function is executed as a graph function in graph mode. + This means: + a) Operations on the same resource are executed in textual order. + This should make it easier to do things like add the updated + value of a variable to another, for example. + b) You don't need to worry about collecting the update ops to execute. + All update ops added to the graph by this function will be executed. + As a result, code should generally work the same way with graph or + eager execution. + + Args: + eval_data: The output of self.model.eval_data() on a mini-batch of + examples. + """ + raise NotImplementedError("Evaluators must define a call member function.") + + # ---- For use by descendants --- + @property + def model(self): + return self._model + + def track_metric(self, metric): + """Add a Metric to be tracked. + + Metrics can only be tracked by one `Evaluator`. Metrics must be + tracked or they will not appear in `all_metric_results()`. + + Args: + metric: A `Metric` object. + + Returns: + The `metric` passed into this function. + + Raises: + RuntimeError: If called before __init__. + TypeError: If `metric` is not of the correct type. + ValueError: If there is a name collision between Metrics or `metric` + has already been added to another `Evaluator`. + """ + if not hasattr(self, "_metrics"): + raise RuntimeError( + "Need to call Evaluator.__init__ before adding metrics") + if not isinstance(metric, metrics.Metric): + raise TypeError( + "Evaluator.track_metric() passed type %s, not a tfe.metrics.Metric" % + (type(metric),)) + if metric.name in self._metrics: + if metric is self._metrics[metric.name]: + return metric + raise ValueError( + "Attempt to add two Metrics with the name '%s' to the same Evaluator " + "'%s'" % (metric.name, self.name)) + # pylint: disable=protected-access + if hasattr(metric, "_added_to_an_evaluator"): + raise ValueError("Metric %s already added to Evaluator %s" % + (metric.name, metric._added_to_an_evaluator)) + metric._added_to_an_evaluator = self.__class__.__name__ + # pylint: enable=protected-access + self._metrics[metric.name] = metric + return metric + + def track_evaluator(self, prefix, evaluator): + """Add a contained `Evaluator`. + + This is for delegating to another `Evaluator`, e.g. for when you have a + model with multiple heads. Users should manually invoke the child + `Evaluator`'s `call` method from their `call` method. + + Args: + prefix: A string. Metrics from `evaluator` are exported with this + prefix and a '/'. + evaluator: An `Evaluator` object. + + Returns: + The value of `evaluator` passed into this function. + + Raises: + RuntimeError: If called before __init__. + TypeError: If `evaluator` is not of the correct type. + ValueError: If an `Evaluator` has already been added with that `prefix`. + """ + if not hasattr(self, "_evaluators"): + raise RuntimeError( + "Need to call Evaluator.__init__ before adding evaluators") + if not isinstance(evaluator, Evaluator): + raise TypeError( + "Evaluator.track_evaluator() passed type %s, not a tfe.Evaluator." % + (type(evaluator),)) + if prefix in self._evaluators: + if evaluator is self._evaluators[prefix]: + return evaluator + raise RuntimeError( + "Attempt to add two Evaluators with the same prefix '%s'." % prefix) + self._evaluators[prefix] = evaluator + return evaluator + + @property + def metric_variables(self): + v = [] + for metric in six.itervalues(self._metrics): + v += metric.variables + for evaluator in six.itervalues(self._evaluators): + v += evaluator.metric_variables + return v + + @property + def metrics(self): + """Returns a list of (prefix, metric) pairs.""" + m = [] + for metric in six.itervalues(self._metrics): + m.append(("", metric)) + for prefix, evaluator in six.iteritems(self._evaluators): + m += [(prefix + "/" + p, m) for p, m in evaluator.metrics] + return m + + +class SparseSoftmaxEvaluator(Evaluator): + """Evaluator for a sparse softmax model. + + Computes a standard set of metrics for single-label, multi-class + models. + + Args: + model: A `SparseSoftmaxModel` object or a `Model` whose `eval_data()` + method produces a `dict` containing values for the loss, true + label, predicted class, and optional weights. + loss_key: Optional key for looking up the value of the loss in the + `eval_data()` dict. Defaults to "loss". + label_key: Optional key for looking up the value of the label in the + `eval_data()` dict. Defaults to "label". + predicted_class_key: Optional key for looking up the value of the + predicted class in the `eval_data()` dict. Defaults to "predicted_class". + weights_key: Optional key for looking up the value of the weights + in the `eval_data()` dict. Defaults to "weights". Note that weights + are optional, and default to 1 if not present in `eval_data`. + """ + + def __init__(self, model, loss_key="loss", label_key="label", + predicted_class_key="predicted_class", weights_key="weights"): + super(SparseSoftmaxEvaluator, self).__init__(model) + # TODO(josh11b): Expand this to include everything from the standard + # SparseSoftmax Head. + self.avg_loss = self.track_metric(metrics.Mean("Avg Loss")) + self.accuracy = self.track_metric(metrics.Accuracy()) + self.loss_key = loss_key + self.label_key = label_key + self.predicted_class_key = predicted_class_key + self.weights_key = weights_key + + def call(self, eval_data): + """Update metrics for `eval_data` dict (described above).""" + weights = eval_data.get(self.weights_key, None) + if weights is None: + self.avg_loss(eval_data[self.loss_key]) + self.accuracy(eval_data[self.label_key], + eval_data[self.predicted_class_key]) + else: + self.avg_loss(eval_data[self.loss_key], weights=weights) + self.accuracy(eval_data[self.label_key], + eval_data[self.predicted_class_key], + weights=weights) diff --git a/tensorflow/contrib/eager/python/evaluator_test.py b/tensorflow/contrib/eager/python/evaluator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..02f82cb216983accc7bc2dfa20cbb1ee0b8d8d26 --- /dev/null +++ b/tensorflow/contrib/eager/python/evaluator_test.py @@ -0,0 +1,195 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for class Evaluator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile + +from tensorflow.contrib.eager.python import evaluator + +from tensorflow.contrib.eager.python import metrics +from tensorflow.contrib.summary import summary_test_util +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables +from tensorflow.python.training import training_util + + +class IdentityModel(object): + + def eval_data(self, d): + return d + + +class PrefixLModel(object): + + def eval_data(self, d): + return {"l_" + key: d[key] for key in d} + + +class SimpleEvaluator(evaluator.Evaluator): + + def __init__(self, model): + super(SimpleEvaluator, self).__init__(model) + self.mean = self.track_metric(metrics.Mean("mean")) + + def call(self, eval_data): + self.mean(eval_data) + + +class DelegatingEvaluator(evaluator.Evaluator): + + def __init__(self, model): + super(DelegatingEvaluator, self).__init__(model) + self.sub = self.track_evaluator("inner", SimpleEvaluator(model)) + self.mean = self.track_metric(metrics.Mean("outer-mean")) + + def call(self, eval_data): + # Keys here come from PrefixLModel, which adds "l_". + self.mean(eval_data["l_outer"]) + self.sub.call(eval_data["l_inner"]) + + +# pylint: disable=not-callable +class EvaluatorTest(test.TestCase): + + def testSimple(self): + e = SimpleEvaluator(IdentityModel()) + e(3.0) + e([5.0, 7.0, 9.0]) + results = e.all_metric_results() + self.assertEqual(set(["mean"]), set(results.keys())) + self.assertEqual(6.0, results["mean"].numpy()) + + def testWriteSummaries(self): + e = SimpleEvaluator(IdentityModel()) + e(3.0) + e([5.0, 7.0, 9.0]) + training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + + e.all_metric_results(logdir) + + events = summary_test_util.events_from_file(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].simple_value, 6.0) + + def testComposition(self): + e = DelegatingEvaluator(PrefixLModel()) + e({"inner": 2.0, "outer": 100.0}) + e({"inner": 4.0, "outer": 1000.0}) + results = e.all_metric_results() + self.assertEqual(set(["inner/mean", "outer-mean"]), set(results.keys())) + self.assertEqual(3.0, results["inner/mean"].numpy()) + self.assertEqual(550.0, results["outer-mean"].numpy()) + + def testMetricVariables(self): + e = DelegatingEvaluator(PrefixLModel()) + e({"inner": 2.0, "outer": 100.0}) + prefix_count = {} + for v in e.metric_variables: + p = v.name.split("/")[0] + prefix_count[p] = prefix_count.get(p, 0) + 1 + self.assertEqual({"outer_mean": 2, "mean": 2}, prefix_count) + + def testDatasetEager(self): + e = SimpleEvaluator(IdentityModel()) + ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0]) + results = e.evaluate_on_dataset(ds) + self.assertEqual(set(["mean"]), set(results.keys())) + self.assertEqual(6.0, results["mean"].numpy()) + + def testDatasetGraph(self): + with context.graph_mode(), ops.Graph().as_default(), self.test_session(): + e = SimpleEvaluator(IdentityModel()) + ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0]) + init_op, call_op, results_op = e.evaluate_on_dataset(ds) + results = e.run_evaluation(init_op, call_op, results_op) + self.assertEqual(set(["mean"]), set(results.keys())) + self.assertEqual(6.0, results["mean"]) + + def testWriteSummariesGraph(self): + with context.graph_mode(), ops.Graph().as_default(), self.test_session(): + e = SimpleEvaluator(IdentityModel()) + ds = dataset_ops.Dataset.from_tensor_slices([3.0, 5.0, 7.0, 9.0]) + training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + init_op, call_op, results_op = e.evaluate_on_dataset( + ds, summary_logdir=logdir) + variables.global_variables_initializer().run() + e.run_evaluation(init_op, call_op, results_op) + + events = summary_test_util.events_from_file(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].simple_value, 6.0) + + def testModelProperty(self): + m = IdentityModel() + e = SimpleEvaluator(m) + self.assertIs(m, e.model) + + def testMetricsProperty(self): + e = DelegatingEvaluator(PrefixLModel()) + names = set([(p, m.name) for p, m in e.metrics]) + self.assertEqual(set([("", "outer-mean"), ("inner/", "mean")]), names) + + def testSharedMetric(self): + + class MetricArgEvaluator(evaluator.Evaluator): + + def __init__(self, model, m): + super(MetricArgEvaluator, self).__init__(model) + self.m = self.track_metric(m) + + metric = metrics.Mean("mean") + model = IdentityModel() + e = MetricArgEvaluator(model, metric) + with self.assertRaisesRegexp(ValueError, "already added"): + MetricArgEvaluator(model, metric) + del e + + def testMetricTrackedTwice(self): + + class MetricTwiceEvaluator(evaluator.Evaluator): + + def __init__(self, model): + super(MetricTwiceEvaluator, self).__init__(model) + self.m = self.track_metric(metrics.Mean("mean")) + self.track_metric(self.m) # okay to track same metric again + + MetricTwiceEvaluator(IdentityModel()) + + +class SparseSoftmaxEvaluatorTest(test.TestCase): + + def testSimple(self): + e = evaluator.SparseSoftmaxEvaluator(IdentityModel()) + e({e.loss_key: 1.0, e.label_key: 5, e.predicted_class_key: 5}) + e({e.loss_key: [0.0, 3.0, 4.0], + e.label_key: [1, 2, 3], + e.predicted_class_key: [1, 1, 3]}) + results = e.all_metric_results() + self.assertEqual(set(["Avg Loss", "Accuracy"]), set(results.keys())) + self.assertEqual(2.0, results["Avg Loss"].numpy()) + self.assertEqual(0.75, results["Accuracy"].numpy()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/examples/BUILD b/tensorflow/contrib/eager/python/examples/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..aa21a6ab994acf929890ecebc07a86cf7ebf97db --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/BUILD @@ -0,0 +1,15 @@ +# TensorFlow code for training gradient boosted trees. +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +py_library( + name = "examples_pip", + deps = [ + "//tensorflow/contrib/eager/python/examples/linear_regression", + "//tensorflow/contrib/eager/python/examples/mnist", + "//tensorflow/contrib/eager/python/examples/resnet50", + "//tensorflow/contrib/eager/python/examples/rnn_colorbot", + "//tensorflow/contrib/eager/python/examples/rnn_ptb", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/BUILD b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..bab7ad0c701b2110fda9a8d27792fd361a5fc1c0 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/linear_regression/BUILD @@ -0,0 +1,25 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +py_binary( + name = "linear_regression", + srcs = ["linear_regression.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + ], +) + +cuda_py_test( + name = "linear_regression_test", + size = "small", + srcs = ["linear_regression_test.py"], + additional_deps = [ + ":linear_regression", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py new file mode 100644 index 0000000000000000000000000000000000000000..d0130ebd118dbaff4f0161c8b2528764c6103e02 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression.py @@ -0,0 +1,157 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""TensorFlow Eager Execution Example: Linear Regression. + +This example shows how to use TensorFlow Eager Execution to fit a simple linear +regression model using some synthesized data. Specifically, it illustrates how +to define the forward path of the linear model and the loss function, as well +as how to obtain the gradients of the loss function with respect to the +variables and update the variables with the gradients. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +import tensorflow as tf + +import tensorflow.contrib.eager as tfe + + +class LinearModel(tfe.Network): + """A TensorFlow linear regression model. + + Uses TensorFlow's eager execution. + + For those familiar with TensorFlow graphs, notice the absence of + `tf.Session`. The `forward()` method here immediately executes and + returns output values. The `loss()` method immediately compares the + output of `forward()` with the target adn returns the MSE loss value. + The `fit()` performs gradient-descent training on the model's weights + and bias. + """ + + def __init__(self): + """Constructs a LinearModel object.""" + super(LinearModel, self).__init__() + self._hidden_layer = self.track_layer(tf.layers.Dense(1)) + + def call(self, xs): + """Invoke the linear model. + + Args: + xs: input features, as a tensor of size [batch_size, ndims]. + + Returns: + ys: the predictions of the linear mode, as a tensor of size [batch_size] + """ + return self._hidden_layer(xs) + + +def fit(model, dataset, optimizer, verbose=False, logdir=None): + """Fit the linear-regression model. + + Args: + model: The LinearModel to fit. + dataset: The tf.data.Dataset to use for training data. + optimizer: The TensorFlow Optimizer object to be used. + verbose: If true, will print out loss values at every iteration. + logdir: The directory in which summaries will be written for TensorBoard + (optional). + """ + + # The loss function to optimize. + def mean_square_loss(xs, ys): + return tf.reduce_mean(tf.square(model(xs) - ys)) + + loss_and_grads = tfe.implicit_value_and_gradients(mean_square_loss) + + tf.train.get_or_create_global_step() + if logdir: + # Support for TensorBoard summaries. Once training has started, use: + # tensorboard --logdir= + summary_writer = tf.contrib.summary.create_summary_file_writer(logdir) + + # Training loop. + for i, (xs, ys) in enumerate(tfe.Iterator(dataset)): + loss, grads = loss_and_grads(xs, ys) + if verbose: + print("Iteration %d: loss = %s" % (i, loss.numpy())) + + optimizer.apply_gradients(grads, global_step=tf.train.get_global_step()) + + if logdir: + with summary_writer.as_default(): + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("loss", loss) + + +def synthetic_dataset(w, b, noise_level, batch_size, num_batches): + """tf.data.Dataset that yields synthetic data for linear regression.""" + + # w is a matrix with shape [N, M] + # b is a vector with shape [M] + # So: + # - Generate x's as vectors with shape [batch_size N] + # - y = tf.matmul(x, W) + b + noise + def batch(_): + x = tf.random_normal([batch_size, tf.shape(w)[0]]) + y = tf.matmul(x, w) + b + noise_level * tf.random_normal([]) + return x, y + + with tf.device("/device:CPU:0"): + return tf.data.Dataset.range(num_batches).map(batch) + + +def main(_): + tfe.enable_eager_execution() + # Ground-truth constants. + true_w = [[-2.0], [4.0], [1.0]] + true_b = [0.5] + noise_level = 0.01 + + # Training constants. + batch_size = 64 + learning_rate = 0.1 + + print("True w: %s" % true_w) + print("True b: %s\n" % true_b) + + model = LinearModel() + dataset = synthetic_dataset(true_w, true_b, noise_level, batch_size, 20) + + device = "gpu:0" if tfe.num_gpus() else "cpu:0" + print("Using device: %s" % device) + with tf.device(device): + optimizer = tf.train.GradientDescentOptimizer(learning_rate) + fit(model, dataset, optimizer, verbose=True, logdir=FLAGS.logdir) + + print("\nAfter training: w = %s" % model.variables[0].numpy()) + print("\nAfter training: b = %s" % model.variables[1].numpy()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--logdir", + type=str, + default=None, + help="logdir in which TensorBoard summaries will be written (optional).") + FLAGS, unparsed = parser.parse_known_args() + + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py new file mode 100644 index 0000000000000000000000000000000000000000..39e7aabd7be04ba36a786a4c08d0df6c2ce916d0 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/linear_regression/linear_regression_test.py @@ -0,0 +1,119 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for linear regression example under TensorFlow eager execution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import glob +import os +import shutil +import tempfile +import time + +import tensorflow as tf + +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.linear_regression import linear_regression + + +def device(): + return "/device:GPU:0" if tfe.num_gpus() > 0 else "/device:CPU:0" + + +class LinearRegressionTest(tf.test.TestCase): + + def setUp(self): + super(LinearRegressionTest, self).setUp() + self._tmp_logdir = tempfile.mkdtemp() + + def tearDown(self): + shutil.rmtree(self._tmp_logdir) + super(LinearRegressionTest, self).tearDown() + + def testSyntheticDataset(self): + true_w = tf.random_uniform([3, 1]) + true_b = [1.0] + batch_size = 10 + num_batches = 2 + noise_level = 0. + dataset = linear_regression.synthetic_dataset(true_w, true_b, noise_level, + batch_size, num_batches) + + it = tfe.Iterator(dataset) + for _ in range(2): + (xs, ys) = it.next() + self.assertEqual((batch_size, 3), xs.shape) + self.assertEqual((batch_size, 1), ys.shape) + self.assertEqual(tf.float32, xs.dtype) + self.assertEqual(tf.float32, ys.dtype) + with self.assertRaises(StopIteration): + it.next() + + def testLinearRegression(self): + true_w = [[1.0], [-0.5], [2.0]] + true_b = [1.0] + + model = linear_regression.LinearModel() + dataset = linear_regression.synthetic_dataset( + true_w, true_b, noise_level=0., batch_size=64, num_batches=40) + + with tf.device(device()): + optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1) + linear_regression.fit(model, dataset, optimizer, logdir=self._tmp_logdir) + + self.assertAllClose(true_w, model.variables[0].numpy(), rtol=1e-2) + self.assertAllClose(true_b, model.variables[1].numpy(), rtol=1e-2) + self.assertTrue(glob.glob(os.path.join(self._tmp_logdir, "events.out.*"))) + + +class EagerLinearRegressionBenchmark(tf.test.Benchmark): + + def benchmarkEagerLinearRegression(self): + num_batches = 200 + batch_size = 64 + dataset = linear_regression.synthetic_dataset( + w=tf.random_uniform([3, 1]), + b=tf.random_uniform([1]), + noise_level=0.01, + batch_size=batch_size, + num_batches=num_batches) + burn_in_dataset = dataset.take(10) + + model = linear_regression.LinearModel() + + with tf.device(device()): + optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1) + + # Perform burn-in. + linear_regression.fit(model, burn_in_dataset, optimizer) + + start_time = time.time() + linear_regression.fit(model, dataset, optimizer) + wall_time = time.time() - start_time + + examples_per_sec = num_batches * batch_size / wall_time + self.report_benchmark( + name="eager_train_%s" % + ("gpu" if tfe.num_gpus() > 0 else "cpu"), + iters=num_batches, + extras={"examples_per_sec": examples_per_sec}, + wall_time=wall_time) + + +if __name__ == "__main__": + tfe.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/mnist/BUILD b/tensorflow/contrib/eager/python/examples/mnist/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..c61ec2dbae60a782c0e6589701554b045dcb92ae --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/mnist/BUILD @@ -0,0 +1,36 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +py_binary( + name = "mnist", + srcs = ["mnist.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + "//tensorflow/examples/tutorials/mnist:input_data", + ], +) + +cuda_py_test( + name = "mnist_test", + srcs = ["mnist_test.py"], + additional_deps = [ + ":mnist", + "//tensorflow/contrib/eager/python:tfe", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "mnist_graph_test", + srcs = ["mnist_graph_test.py"], + additional_deps = [ + ":mnist", + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/mnist/README.md b/tensorflow/contrib/eager/python/examples/mnist/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e987996b88ccf54a322749aadec4f9840760a90f --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/mnist/README.md @@ -0,0 +1,10 @@ +Classification model for the MNIST dataset using eager execution. + +To run: + +``` +python mnist.py +``` + +`mnist_graph_test.py` demonstrates that the same code that is executed eagerly +in `mnist.py` is used to construct a TensorFlow graph. diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist.py b/tensorflow/contrib/eager/python/examples/mnist/mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb7d5a9002787f6544d383de58150661ac2bde3 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/mnist/mnist.py @@ -0,0 +1,270 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A deep MNIST classifier using convolutional layers. + +Sample usage: + python mnist.py --help +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import functools +import os +import sys +import time + +import tensorflow as tf + +import tensorflow.contrib.eager as tfe +from tensorflow.examples.tutorials.mnist import input_data + +FLAGS = None + + +class MNISTModel(tfe.Network): + """MNIST Network. + + Network structure is equivalent to: + https://github.com/tensorflow/tensorflow/blob/r1.4/tensorflow/examples/tutorials/mnist/mnist_deep.py + and + https://github.com/tensorflow/models/blob/master/tutorials/image/mnist/convolutional.py + + But written using the tf.layers API. + """ + + def __init__(self, data_format): + """Creates a model for classifying a hand-written digit. + + Args: + data_format: Either 'channels_first' or 'channels_last'. + 'channels_first' is typically faster on GPUs while 'channels_last' is + typically faster on CPUs. See + https://www.tensorflow.org/performance/performance_guide#data_formats + """ + super(MNISTModel, self).__init__(name='') + if data_format == 'channels_first': + self._input_shape = [-1, 1, 28, 28] + else: + assert data_format == 'channels_last' + self._input_shape = [-1, 28, 28, 1] + self.conv1 = self.track_layer( + tf.layers.Conv2D(32, 5, data_format=data_format, activation=tf.nn.relu)) + self.conv2 = self.track_layer( + tf.layers.Conv2D(64, 5, data_format=data_format, activation=tf.nn.relu)) + self.fc1 = self.track_layer(tf.layers.Dense(1024, activation=tf.nn.relu)) + self.fc2 = self.track_layer(tf.layers.Dense(10)) + self.dropout = self.track_layer(tf.layers.Dropout(0.5)) + self.max_pool2d = self.track_layer( + tf.layers.MaxPooling2D( + (2, 2), (2, 2), padding='SAME', data_format=data_format)) + + def call(self, inputs, training): + """Computes labels from inputs. + + Users should invoke __call__ to run the network, which delegates to this + method (and not call this method directly). + + Args: + inputs: A batch of images as a Tensor with shape [batch_size, 784]. + training: True if invoked in the context of training (causing dropout to + be applied). False otherwise. + + Returns: + A Tensor with shape [batch_size, 10] containing the predicted logits + for each image in the batch, for each of the 10 classes. + """ + + x = tf.reshape(inputs, self._input_shape) + x = self.conv1(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.max_pool2d(x) + x = tf.layers.flatten(x) + x = self.fc1(x) + if training: + x = self.dropout(x) + x = self.fc2(x) + return x + + +def loss(predictions, labels): + return tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits( + logits=predictions, labels=labels)) + + +def compute_accuracy(predictions, labels): + return tf.reduce_sum( + tf.cast( + tf.equal( + tf.argmax(predictions, axis=1, + output_type=tf.int64), + tf.argmax(labels, axis=1, + output_type=tf.int64)), + dtype=tf.float32)) / float(predictions.shape[0].value) + + +def train_one_epoch(model, optimizer, dataset, log_interval=None): + """Trains model on `dataset` using `optimizer`.""" + + tf.train.get_or_create_global_step() + + def model_loss(labels, images): + prediction = model(images, training=True) + loss_value = loss(prediction, labels) + tf.contrib.summary.scalar('loss', loss_value) + tf.contrib.summary.scalar('accuracy', + compute_accuracy(prediction, labels)) + return loss_value + + for (batch, (images, labels)) in enumerate(tfe.Iterator(dataset)): + with tf.contrib.summary.record_summaries_every_n_global_steps(10): + batch_model_loss = functools.partial(model_loss, labels, images) + optimizer.minimize( + batch_model_loss, global_step=tf.train.get_global_step()) + if log_interval and batch % log_interval == 0: + print('Batch #%d\tLoss: %.6f' % (batch, batch_model_loss())) + + +def test(model, dataset): + """Perform an evaluation of `model` on the examples from `dataset`.""" + avg_loss = tfe.metrics.Mean('loss') + accuracy = tfe.metrics.Accuracy('accuracy') + + for (images, labels) in tfe.Iterator(dataset): + predictions = model(images, training=False) + avg_loss(loss(predictions, labels)) + accuracy(tf.argmax(predictions, axis=1, output_type=tf.int64), + tf.argmax(labels, axis=1, output_type=tf.int64)) + print('Test set: Average loss: %.4f, Accuracy: %4f%%\n' % + (avg_loss.result(), 100 * accuracy.result())) + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar('loss', avg_loss.result()) + tf.contrib.summary.scalar('accuracy', accuracy.result()) + + +def load_data(data_dir): + """Returns training and test tf.data.Dataset objects.""" + data = input_data.read_data_sets(data_dir, one_hot=True) + train_ds = tf.data.Dataset.from_tensor_slices((data.train.images, + data.train.labels)) + test_ds = tf.data.Dataset.from_tensors((data.test.images, data.test.labels)) + return (train_ds, test_ds) + + +def main(_): + tfe.enable_eager_execution() + + (device, data_format) = ('/gpu:0', 'channels_first') + if FLAGS.no_gpu or tfe.num_gpus() <= 0: + (device, data_format) = ('/cpu:0', 'channels_last') + print('Using device %s, and data format %s.' % (device, data_format)) + + # Load the datasets + (train_ds, test_ds) = load_data(FLAGS.data_dir) + train_ds = train_ds.shuffle(60000).batch(FLAGS.batch_size) + + # Create the model and optimizer + model = MNISTModel(data_format) + optimizer = tf.train.MomentumOptimizer(FLAGS.lr, FLAGS.momentum) + + if FLAGS.output_dir: + train_dir = os.path.join(FLAGS.output_dir, 'train') + test_dir = os.path.join(FLAGS.output_dir, 'eval') + tf.gfile.MakeDirs(FLAGS.output_dir) + else: + train_dir = None + test_dir = None + summary_writer = tf.contrib.summary.create_summary_file_writer( + train_dir, flush_millis=10000) + test_summary_writer = tf.contrib.summary.create_summary_file_writer( + test_dir, flush_millis=10000, name='test') + checkpoint_prefix = os.path.join(FLAGS.checkpoint_dir, 'ckpt') + + with tf.device(device): + for epoch in range(1, 11): + with tfe.restore_variables_on_create( + tf.train.latest_checkpoint(FLAGS.checkpoint_dir)): + global_step = tf.train.get_or_create_global_step() + start = time.time() + with summary_writer.as_default(): + train_one_epoch(model, optimizer, train_ds, FLAGS.log_interval) + end = time.time() + print('\nTrain time for epoch #%d (global step %d): %f' % ( + epoch, global_step.numpy(), end - start)) + with test_summary_writer.as_default(): + test(model, test_ds) + all_variables = ( + model.variables + + optimizer.variables() + + [global_step]) + tfe.Saver(all_variables).save( + checkpoint_prefix, global_step=global_step) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--data-dir', + type=str, + default='/tmp/tensorflow/mnist/input_data', + help='Directory for storing input data') + parser.add_argument( + '--batch-size', + type=int, + default=64, + metavar='N', + help='input batch size for training (default: 64)') + parser.add_argument( + '--log-interval', + type=int, + default=10, + metavar='N', + help='how many batches to wait before logging training status') + parser.add_argument( + '--output_dir', + type=str, + default=None, + metavar='N', + help='Directory to write TensorBoard summaries') + parser.add_argument( + '--checkpoint_dir', + type=str, + default='/tmp/tensorflow/mnist/checkpoints/', + metavar='N', + help='Directory to save checkpoints in (once per epoch)') + parser.add_argument( + '--lr', + type=float, + default=0.01, + metavar='LR', + help='learning rate (default: 0.01)') + parser.add_argument( + '--momentum', + type=float, + default=0.5, + metavar='M', + help='SGD momentum (default: 0.5)') + parser.add_argument( + '--no-gpu', + action='store_true', + default=False, + help='disables GPU usage even if a GPU is available') + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py b/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1af26553120b34d4682b17b1c29c81dc65e421d4 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py @@ -0,0 +1,65 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf +from tensorflow.contrib.eager.python.examples.mnist import mnist + + +def data_format(): + return "channels_first" if tf.test.is_gpu_available() else "channels_last" + + +class MNISTGraphTest(tf.test.TestCase): + + def testTrainGraph(self): + # The MNISTModel class can be executed eagerly (as in mnist.py and + # mnist_test.py) and also be used to construct a TensorFlow graph, which is + # then trained in a session. + with tf.Graph().as_default(): + # Generate some random data. + batch_size = 64 + images = np.random.randn(batch_size, 784).astype(np.float32) + digits = np.random.randint(low=0, high=10, size=batch_size) + labels = np.zeros((batch_size, 10)) + labels[np.arange(batch_size), digits] = 1. + + # Create a model, optimizer, and dataset as would be done + # for eager execution as well. + model = mnist.MNISTModel(data_format()) + optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) + dataset = tf.data.Dataset.from_tensors((images, labels)) + + # Define the loss tensor (as opposed to a loss function when + # using eager execution). + (images, labels) = dataset.make_one_shot_iterator().get_next() + predictions = model(images, training=True) + loss = mnist.loss(predictions, labels) + + train_op = optimizer.minimize(loss) + init = tf.global_variables_initializer() + with tf.Session() as sess: + # Variables have to be initialized in the session. + sess.run(init) + # Train using the optimizer. + sess.run(train_op) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py new file mode 100644 index 0000000000000000000000000000000000000000..205709fe2edd3c260c30a84b624e322e120edf8e --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/mnist/mnist_test.py @@ -0,0 +1,62 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.mnist import mnist + + +def device(): + return "/device:GPU:0" if tfe.num_gpus() else "/device:CPU:0" + + +def data_format(): + return "channels_first" if tfe.num_gpus() else "channels_last" + + +def random_dataset(): + batch_size = 64 + images = tf.random_normal([batch_size, 784]) + digits = tf.random_uniform([batch_size], minval=0, maxval=10, dtype=tf.int32) + labels = tf.one_hot(digits, 10) + return tf.data.Dataset.from_tensors((images, labels)) + + +class MNISTTest(tf.test.TestCase): + + def testTrainOneEpoch(self): + model = mnist.MNISTModel(data_format()) + optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) + dataset = random_dataset() + with tf.device(device()): + tf.train.get_or_create_global_step() + mnist.train_one_epoch(model, optimizer, dataset) + + def testTest(self): + model = mnist.MNISTModel(data_format()) + dataset = random_dataset() + with tf.device(device()): + tf.train.get_or_create_global_step() + mnist.test(model, dataset) + + +if __name__ == "__main__": + tfe.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..01616f2e7dbab8084153e6554ce0e64c13f5d710 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/1_basics.ipynb @@ -0,0 +1,529 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "U9i2Dsh-ziXr" + }, + "source": [ + "# Eager Execution Tutorial: Basics\n", + "\n", + "This notebook introduces the basics of using TensorFlow's eager execution capabilities. It covers concepts such as:\n", + "\n", + "* Importing required packages\n", + "* Enabling eager execution\n", + "* Creating and using TensorFlow Tensors and Variables\n", + "* Using TensorFlow interactively\n", + "* Using GPUs with eager execution enabled\n", + "\n", + "This notebook does *not* cover modeling topics, such as gradients." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "z1JcS5iBXMRO" + }, + "source": [ + "# Step 1: Import Eager\n", + "\n", + "The key imports for eager execution are the following:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "RlIWhyeLoYnG" + }, + "outputs": [], + "source": [ + "# Import TensorFlow.\n", + "import tensorflow as tf\n", + "\n", + "# Import TensorFlow eager execution support (subject to future changes).\n", + "import tensorflow.contrib.eager as tfe" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "H9UySOPLXdaw" + }, + "source": [ + "# Step 2: Enable eager execution\n", + "\n", + "All future TensorFlow calls will execute the\n", + "underlying TensorFlow ops immediately:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "WPTUfGq6kJ5w" + }, + "outputs": [], + "source": [ + "tfe.enable_eager_execution()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "twBfWd5xyu_d" + }, + "source": [ + "# Step 3: Interactively Use TensorFlow!\n", + "\n", + "Now you can call TensorFlow functions and get results, immediately! No more `tf.Sessions`!\n", + "\n", + "TensorFlow will automatically wrap native Python types for you with operator overloading for TensorFlow Tensors." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "ngUe237Wt48W" + }, + "outputs": [], + "source": [ + "print(tf.add(1, 2))\n", + "print(tf.add([1, 2], [3, 4]))\n", + "print(tf.square(5))\n", + "print(tf.reduce_sum([1, 2, 3]))\n", + "print(tf.encode_base64(\"hello world\"))\n", + "print(\"\")\n", + "\n", + "x = tf.constant(2)\n", + "y = tf.constant(3)\n", + "print(x * y + 1)\n", + "\n", + "# Most TensorFlow ops are directly usable with eager execution, giving\n", + "# results immediately.\n", + "print(tf.contrib.signal.hamming_window(x * y + 1))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "IDY4WsYRhP81" + }, + "source": [ + "Numpy arrays are supported, too:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "lCUWzso6mbqR" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "ones = np.ones([3, 3])\n", + "\n", + "print(\"numpy 3x3 matrix of 1s:\")\n", + "print(ones)\n", + "print(\"\")\n", + "\n", + "print(\"Multiplied by 42:\")\n", + "print(tf.multiply(ones, 42))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "PBNP8yTRfu_X" + }, + "source": [ + "# Step 4: Define and Print TensorFlow Variables\n", + "\n", + "To define TensorFlow variables, use the `get_variable()` function as follows:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "3Twf_Rw-gQFM" + }, + "outputs": [], + "source": [ + "x = tf.get_variable(name=\"x\", shape=[], dtype=tf.float32, initializer=tf.zeros_initializer)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "45G7094TxsMb" + }, + "source": [ + "## Printing TensorFlow Variables" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "UJBJeZ5XxuwA" + }, + "outputs": [], + "source": [ + "# This does NOT print the Variable's actual value:\n", + "print(\"Printing a TensorFlow Variable:\")\n", + "print(x)\n", + "print(\"\")\n", + "\n", + "# A TensorFlow variable represents a reference to a tensor.\n", + "# The `read_value()` method provides access to the current value of the\n", + "# variable. Tensorflow Variables are automatically initialized according to the\n", + "# semantics defined in tf.get_variable().\n", + "print(\"Printing a TensorFlow Variable's value using .read_value():\")\n", + "print(x.read_value())\n", + "print(\"\")\n", + "\n", + "print(\"Printing a TensorFlow Variable's value using .read_value().numpy():\")\n", + "print(x.read_value().numpy())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "2njjWHcTpBEn" + }, + "source": [ + "## Changing a TensorFlow Variable's value\n", + "\n", + "To change a TensorFlow Variable's value, use its `.assign()` or `.assign_add()` method:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "v3wr6Erbo_hB" + }, + "outputs": [], + "source": [ + "x.assign(42)\n", + "print(x.read_value())\n", + "\n", + "x.assign_add(3)\n", + "print(x.read_value())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "uhtynjHVpTB5" + }, + "source": [ + "## Use a Variable just like any other Tensor" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "7PbktdnHoehR" + }, + "outputs": [], + "source": [ + "print(x + 3)\n", + "\n", + "# This code will broadcast the value across the list of numbers:\n", + "print(x * [1, 2, 4])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "GVChqwlwy1SI" + }, + "source": [ + "# Step 5: Debug Errors with Instant Feedback\n", + "\n", + "TensorFlow's eager execution helps you identify and debug runtime issues through interactive exploration of code snippets.\n", + "\n", + "Below, we'll define a length-4 vector, and attempt two `tf.slice()` operations,\n", + "one being legal and the other being illegal, leading to a runtime error that is\n", + "raised immediately." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "23ap04N0v4k0" + }, + "outputs": [], + "source": [ + "vector = tf.constant([10.0, 20.0, 30.0, 40.0])" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "FCUMsIYxxRRa" + }, + "outputs": [], + "source": [ + "# Works, because the values of `begin` and `size` (the 2nd and 3rd input\n", + "# arguments) are within the bound of `vector`.\n", + "print(tf.slice(vector, [1], [3]))" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "T8me2oCNxpFp" + }, + "outputs": [], + "source": [ + "# The following does NOT work, because the value of `size` (the 3rd\n", + "# argument) causes the indices to go out of the bounds of `vector`. The\n", + "# error is raised immediately.\n", + "try:\n", + " print(tf.slice(vector, [1], [4]))\n", + "except tf.OpError as e:\n", + " print(\"Caught error: %s\" % e)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "irxJhAgar84v" + }, + "source": [ + "# Step 6: Using the GPU\n", + "\n", + "You can place Tensors on the GPU by calling a Tensor's `.gpu()` method.\n", + "\n", + "The first operation executing on the GPU may be slow as TensorFlow initializes. Subsequent uses will be much faster." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "7J4N9baqaKCL" + }, + "outputs": [], + "source": [ + "# The example code from here on will work only if your notebook\n", + "# is running on a machine with a functional CUDA GPU. The following\n", + "# line checks that.\n", + "is_gpu_available = tfe.num_gpus() \u003e 0\n", + "\n", + "# Create some Tensors\n", + "SIZE = 1000\n", + "cpu_tensor = tf.random_normal([SIZE, SIZE])\n", + "\n", + "if is_gpu_available:\n", + " gpu_tensor = cpu_tensor.gpu()" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "4E-2n7VbzY1n" + }, + "outputs": [], + "source": [ + "# Time a CPU-based matrix multiplication\n", + "\n", + "print(\"Time to conduct matmul on CPU:\")\n", + "%time tf.matmul(cpu_tensor, cpu_tensor)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "vbSFW-T5zhZF" + }, + "outputs": [], + "source": [ + "# Time GPU-based matrix multiplications.\n", + "\n", + "if is_gpu_available:\n", + " # First use of the GPU will be slow:\n", + " print(\"Time to conduct first matmul on GPU:\")\n", + " %time tf.matmul(gpu_tensor, gpu_tensor)\n", + " print()\n", + "\n", + " # Subsequent uses are much faster:\n", + " print(\"Time to conduct second matmul on GPU:\")\n", + " %time tf.matmul(gpu_tensor, gpu_tensor)" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "E5pIOe3Rz7iW" + }, + "outputs": [], + "source": [ + "# Second timing demo for GPUs, after it has been used once:\n", + "\n", + "cpu_tensor = tf.random_normal([SIZE, SIZE])\n", + "print(\"Time to conduct CPU matmul:\")\n", + "%time tf.matmul(cpu_tensor, cpu_tensor)\n", + "print()\n", + "\n", + "if is_gpu_available:\n", + " gpu_tensor = cpu_tensor.gpu()\n", + " print(\"Time to conduct GPU matmul:\")\n", + " %time tf.matmul(gpu_tensor, gpu_tensor)" + ] + } + ], + "metadata": { + "colab": { + "default_view": {}, + "name": "Eager Execution Tutorial: Basics", + "provenance": [ + { + "file_id": "0B0kLcpwLFwKEVm9XNkFueGk4bTg", + "timestamp": 1504118841551 + } + ], + "version": "0.3.2", + "views": {} + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..3b7e2cd435e7f34cb950545a9fe5ee6eafefde7e --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/2_gradients.ipynb @@ -0,0 +1,864 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "vDJ4XzMqodTy" + }, + "source": [ + "# Eager Execution: Working with Gradients\n", + "\n", + "This notebook demonstrates:\n", + "\n", + "* How to get gradients using TensorFlow's eager execution capabilities\n", + "* How to apply the gradients so you can update your variables" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "GQJysDM__Qb0" + }, + "source": [ + "# Setup: Import eager and enable eager execution.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "OiMPZStlibBv" + }, + "outputs": [], + "source": [ + "# Import TensorFlow.\n", + "import tensorflow as tf\n", + "\n", + "# Import TensorFlow eager execution support (subject to future changes).\n", + "import tensorflow.contrib.eager as tfe\n", + "\n", + "# Enable eager execution.\n", + "tfe.enable_eager_execution()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1CLWJl0QliB0" + }, + "source": [ + "# Fitting a Simple Linear Model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "-39gouo7mtgu" + }, + "source": [ + "## Step 1: Synthesize some data\n", + "\n", + "To demonstrate fitting a model with TensorFlow's eager execution, we'll fit a linear model to some synthesized data (which includes some noise).\n", + "\n", + "In the code, we use the variable names `w` and `b` to represent the single weight and bias we'll use to fit our model." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "rQsdCg9PfIL-" + }, + "outputs": [], + "source": [ + "# The constants we'll try to fit our variables to:\n", + "true_w = 3\n", + "true_b = 2\n", + "\n", + "NUM_EXAMPLES = 1000\n", + "\n", + "# Our inputs:\n", + "inputs = tf.random_normal(shape=[NUM_EXAMPLES, 1])\n", + "\n", + "# Our labels, with noise:\n", + "noise = tf.random_normal(shape=[NUM_EXAMPLES, 1])\n", + "labels = inputs * true_w + true_b + noise" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 360, + "output_extras": [ + { + "item_id": 1 + } + ] + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 127, + "status": "ok", + "timestamp": 1505502830690, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "O4lsC4ckAcar", + "outputId": "2f760690-cafb-4777-b970-91d839f99faf" + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAesAAAFXCAYAAACC+2avAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzsnXt8VPWd99+TK7kykxtJQIebqZfaqogtrhKNa1ooEKl9\nCrpVn9ZNW6x9VWsbCi7aVUt01NZ9tq21KVZlFey2YkQNohhj3QWK2liCF5RIBCc3yEwmIZnMTOY8\nf/zmzJwzSSBAYibh+369eIU5c87vXLh8zvdu0TRNQxAEQRCEmCVurC9AEARBEISjI2ItCIIgCDGO\niLUgCIIgxDgi1oIgCIIQ44hYC4IgCEKMI2ItCIIgCDHOiIj16tWrufjii1m8eHF4269//Wvmz5/P\n0qVLWbp0Ka+//vpInEoQBEEQTjksI1Fn/eabb5KWlkZFRQWbN28GlFinpaXx7W9/+6QvUhAEQRBO\nZUbEsr7wwgvJzMwcsF36rQiCIAjCyTOqMesnn3ySsrIybr/9drq6ukbzVIIgCIIwYRk1sb722mt5\n5ZVXqK6uJicnh8rKytE6lSAIgiBMaEZNrLOysrBYLAB885vfZPfu3cc8RtzmgiAIgjCQhJFaKFpo\n29vbyc3NBeDll1+mqKjomGtYLBba2yeuuzw3N0Pubxwzke9vIt8byP2Nd06F+zsWIyLWt912Gzt3\n7sTtdnPZZZfxwx/+kJ07d/Lee+8RFxfH1KlTueuuu0biVIIgCIJwyjEiYv3ggw8O2Hb11VePxNKC\nIAiCcMojHcwEQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfE\nWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYR\nsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGCdhrC9AEARBOHXo6HCzcmUt\nTU2Z2O2dOBwl2GzWsb6smEfEWhAEQfjMWLmylurq6wAL9fUasJ6qqqVjfVkxj7jBBUEQhM+MpqZM\nwBL6ZAl9Fo6FiLUgCILwmWG3dwJa6JOG3e4Zy8sZN4gbXBAEQfjMcDhKgPWhmLUHh+Pysb6kcYGI\ntSAIgvCZYbNZJUZ9AogbXBAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFr\nQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfE\nWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYR\nsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhxRKwFQRAEIcYRsRYEQRCEGEfEWhAEQRBiHBFrQRAEQYhx\nRKwFQRAEIcYRsRYEQRCEGCdhrC9AEARBODE6OtysXFmL02mjsLADh6MEm806rGOamjKx2zuHdYww\n9oyIWK9evZrXXnuN7OxsNm/eDEBnZye33norn376KdOmTeOhhx4iIyNjJE4nCIIgACtX1lJdfR1g\nATRgPVVVS037RIuzz9dDTc33AQv19YMfI8QeI+IG//rXv866detM237/+98zb948XnrpJb70pS/x\nyCOPjMSpBEEQhBBNTZkooQawhD6b0QW9vv4qqquvZ/v27mMeI8QeIyLWF154IZmZ5j/wbdu2sXSp\neltbunQpr7zyykicShAEQQhht3eiLGoADbvdM2CfaEGH7GMeI8Qeoxaz7ujoICcnB4Dc3FxcLtdo\nnUoQBOGUxOEoAdaHYtYuHI7LAbPru61tD1AM2AAXkyY5sVr/CBxi3rwMHI5FY3cDwrCJuQSz3NyJ\nHdeW+xvfTOT7m8j3BhPz/uLi+klOTgQgOTmBnJwMsrIyuPnm5w2x7DKmTbuPgoJzaG7ew8GDt6PH\nuDMyNpKdncFNNz3Pxx+nM2NGFw8/vJCsrNhLOJuIf37Hw6iJdXZ2NocOHSInJ4f29naysrKGdVx7\ne9doXdKYk5ubIfc3jpnI9zeR7w0m7v2Vlz8XFuVduzT6+lSy2N69KRhd3zk5Z/LCC5dRWtrPwYOR\n7Xv3pnDjjYOvEUtM1D8/neG8iIxYnbWmaabPJSUlPPPMMwBs2rSJK664YqROJQiCIDB0gtlQsezB\ntg8nSU0Ye0bEsr7tttvYuXMnbrebyy67jB/+8Id897vf5Uc/+hF/+ctfKCws5D/+4z9G4lSCIAhC\nCLu9M1R+pdzauijrsWxVruUJx7JXrZrDrl2VuFzTsNkOsnr1EtaufWvQNYTYYkTE+sEHHxx0+2OP\nPTYSywuCIAiDMFSCmc1mHdSVXVn5Nk7nKsBCb6/G2rXrhxR2IbaIuQQzQRAEYXjoojxYTHewTmWD\nubyHEnYhthCxFgRBmIAYu5vpncrsdk1c3uMUEWtBEIQYYai+3SfSz3swK/rpp+cgLu/xiYi1IAhC\njDCYNVxVtXTI7UdjsOQzcXmPX0SsBUEQYoShyqhOpLxKEscmFiLWgiAIMcJQpVjm7S7a2t6ltJSw\nS3ywphojYUXLOM3YQcRaEAQhRhjKGjZub2t7F6dzFU6nconX1T1AaelU7r770mEL6XBF+ETc78Lo\nIGItCIIQIwxlDRu3l5aC0xlxibvdZ/KnPy06rjahwxVh6W4WO4xYu1FBEARh9IluGQpqPvXxCOlw\nRXg4IziFzwaxrAVBEMYRuku8ttaPx5MCLAQ0CgoODXuNoWLjQ51LktTGHhFrQRCEcYTuEr/hhv+i\npiYBeBY4xNtvu3G53ANiz4PFp4crwlLqFTuIWAuCIIxDmpsLgF5gOWChtVWjomJg7Hmo+LSI8PhC\nxFoQBGEcEG0hFxT4qK+fwrFiz5IkNjGQBDNBEITPiI4ON+Xlmygt3UZ5+TO4XO5hf69byPX1V1Fd\nfT0QoLBwN8dKAJMksYmBWNaCIAifEdEu6V27KqmtvS4cZz5aSVW0hdzcXEBt7SIqKgaOyDQiSWIT\nAxFrQRCEz4howXU6P09FRe2Qgmx0WRcUNFNf/xSQAXgoKPAcdUSmznCTxKRbWWwjbnBBEIRBOJbL\n+kQwu6RdwLts3Up4/aO7rBOBa4DFwLWhzyNHtJu9oqJ2RNcXTg6xrAVBEAZhNFptOhwl7NpVidP5\neeBdYCW9vRaqq9X6DkcJfX3r2LEjDjiMz5cWLsdqbs7B7AbPOalriUYS0WIbsawFQRAG4XjFaziW\nuM1mpbb2OsrK3KSkFA5Y32azkpychNv9bdzun1JTsyJs4UZb3QUFLeHzLVv21Elb/pKIFtuIZS0I\ngjAIw+3ypROxxDupr3+RurqXKS6OHxD71WPI5eXPhCxq8/pDvSREJ4r5fAkmy/94eoMb0WPV+/Yl\nUFhYSXZ2ETNn9kgiWowhYi0IgjAIx5tFHRHZGmABbvcWqqvT2LXrCWprrx+QrOVwlODzPcL27V1A\nNj5ffzhuPdhLQnSiWGnpNoZr+R8teczo7geNuXNlslYsImItCIIwCMfbajMisunAFvTOYk7n4kE7\ni9lsVpKSUnG7vwdYqKnRSEpaf9SXBKPotrXtAcoYjuV/PCVhEquOTUSsBUEQBmEoa3So7brI1tW1\n4HafyXAEcDChPNpLgtkKLqawsJK8vLMpKurl7ruHtvyPJsjH6+4XxgYRa0EQhEEYyhodarsusi6X\nm8svfwKnczFDCaAu+Pv3t6CSuoYnlGbRtZGXdzZbt15Bbm4GH3xwgPLyTYO6uo8myNI0ZXwgYi0I\nQswylo06IsLoBmrC9dD79iUQbaV2dLi55ZaXQiVXh5gzJ5kvfnEdzc052O0eVq26wCSkPl8PNTXf\nBzqBDVitXoqLE44plEcT3aO5uo8myDJZa3wgYi0IQswyGrXOwyXSMcwJ3Bauhy4srCTaGl65spYt\nW24Mb9u2bQNlZQG2br0CgPLyTab7sFofCO1rBa5l+vRnqaq6Ilz+NdTLiS66+/bF09HRRGNjEeXl\nz/Doo2VHdXWLII9/RKwFQYhZxjb5Se8Y9rzpGrKzi5g7V1mpBQUt+HwJvPZaErABWIgS4AyamvrD\nK0XfB2SjOphtAdJoa9uDyzXnmC8nuuhef/3TNDSswum0sHu3xo03PoHdjsSeJzAi1oIgxCxjmfwU\n6RjWhdGSnjmzJyygRotZ7fMgUAD00NbWTmkphnGWkTXmzQvyzjsP43SuwpgxPtyXE+Vuj+xXV6ex\nY8cVSOx54iJiLQhCzDKWyU+RF4WFDBVXHmgxfw5YRHLy7TidP8XptFFfr7Fgwe8oKzPex1dYtuwt\nnM7IsVu3gs023HKsQxhfIOCQuLonOCLWgiDELCMtQEdLWIv+bvXqOUReFAI4HFcOSG6LtvyhO/T7\n01Eu7nSgiwMHMnn11SVHPba3N5He3psoLKwkK6uIjo697Ntnp7z8mQGx63nz0qmp2YCawNXF/Pky\nHWuiI2ItCMIpw2Ax4fvuu5yVK2upqwvgdicDl1FfP5nBktmGEvTaWj8eTwrKCteAg4BqdgIaHR2V\nA65F9xps3Qq9vX7Uf8dv0NOTwFlnfUJDw3SczgwaGjz4fM/z+OPfCl8DJGG1eoGDzJuXwaOPXkN/\n/4BTCBMIEWtBEE4ZBosJR7fbhI3ANYPGi4dKALvhhv+ipkYD/gDk4PMlocqyAGpwuQoHWMjmHuEp\nqGQ2C273Il5//U7g1vA1bd/+AKCEuqRkfTjWDarrWVaWdch51sLEQMRaEIRThsES1gbGndOJjhfr\nFvXWrRj27WTz5oMUFf03gcCh0PbbAQuapqGywy3ActMYzGhr3eEooa7uZdzuyDX090+PuqZsQL0s\nqPGa0h70VEPEWhCEU4bBEtYqKl41CbjV+j7FxS5TIlnEot5AJLHrRYLBVSGR1YDHMQusDzWF+OjC\narNZ+fKX+9myJXINOTkHaWszZ4+D7hno5ni6ngkTAxFrQRBOGWw2azhG3dSUSUXFq1GJZB4cjuUD\nEsn27YtHucctwL1YLFPQNLMQQzvmDG0L8Klp2/vvv0lJyRFmzQqYXOIWSwD1IqASxs49N430dHP2\nOOiegSWha0mjsLABh+O6UXteQuwgYi0IQkwQnby1atUcKivfHvFWoyfSFa2jowmIxImTk9fg9Z6F\nWZyt6CIKO1BW9W3AfcDZwBG83ttoaNhCQ8P1pvM2NxcAV4XPd/jws2zYcMWA61Cegc2hZ+LG4bgO\nTYNlyzawd2/KZ96SVfjsELEWBCEmiBbRXbsqw4lUx9NqdLDyrNzcjPD3J9IVbfLkqTidG9FLsU47\nrZDZsz1s3/4A3d2ZBAKTQmumAW8CPwXeAGzAOcBiw2rpA8473OYvg5WyRbcy/SxbsgqfHSLWgiDE\nBNEi6nJN40QSqQaznJ999vrw98cSxsHEvrPzU4yW9ZEjlfzqV9excmUtjY2pHD78AR5PBt3de1Ad\nzGxEOp8ZO6C5gJ1AO3v2OLnhBicPPbT4pJq/yDzqUwMRa0EQYoJoEbXZDtLbG/mcn38oPOSioKAZ\nSAxNtTK7fo8lXqtWzWHnzntoa8sjPv4Q3d3puFzu8PGDiX12dpGp21h2dtGAkq/k5DXARcAelCir\nzmeZmR34fHfg9Z4P7AXuBiz4/VqosckLJCWlnrC7X+ZRnxqIWAuCMGocz4jLaOty9eolrF0b+ezz\n+amuVpOt1DSsaxisucn+/QHgSeBrwOQB4lVZ+TYtLbOAawgGLWzbplFREXEdDyb2M2d2snu3uT94\n9H59fRcBS1Au7/tISSmktBQcjjKWLXuL+vqrgM2mYyCD7ds/xe3+HsNxYw/2PB2OEpKTN4Zi1tIT\nfKIiYi0IwqhxPMlcg8Vjq6rs4d+Xlm4jInQZGEVv61bYtesJnM6bUC5oNYayuHjKAPFSmd1O1DSt\nLmAhdXUBGhubqKx8m/37W4gujVq1ag67dlXick3DZjvA6tVl3HnndswJZkfC1wNTuOyyI1RVqa5j\nEeu3K+qYLlQN9fDc2EM9z6efvkaaokxwRKwFQRg1oq3PfftSB8xr1jSGZX2b3b0ejKKn+mqvRu8+\nBhamTz+DqqqBGdXRmd2wAbd7Epde+if8/n9HdR4zD+6oqKgNJ7v19mosXVpJd7debuUDmoHvh86g\nAclApP+ncQ71oUP30NMzlbi4w8yblw4khLqfHduNLfHpUxcRa0EQRo3oeGpHx14aGswZ3sCg1uJQ\nfbhVL20f8ERo3URgAZFsbDia6EXHn5XYXoXfHwh9tqLizVU0NZ1BRcWrNDamYRTJSBexxSjX9lVA\nDSrT+wPgX2lufi18zqMNJHG53CQlDS+5TOLTpy4i1oIgjBrRceh9++wGoeykrq6Vvr4pKAt1IWAN\nW4tDuXxVL20Vu1ax6eXo4lVY2EBeXvCoojdz5pFQ/LkTeDG09QXgI4zdydzun1Bfr85dWLiWgS5v\njYgrezLqheFFIAd4gYKCwYV0sLjzcEutxnJkqDC2iFgLgjBqRFuU5eXP0NBgFkTzAI3lYWtxKJev\nUbCUIK4LZYV7cDiuO2YmtX58bW0LHs9PDedfh8redtHb68bvj8S0s7KmM3euOmdb27s4nStQrvh7\ngQySklbS359Mf/9cVDvQBcBfBj3/8cTxT0bYhYmFiLUgCJ8ZRqHdv99rGl6RkuKntHR92FocyuUb\n/QJgFLSKilePWfqkH19auo36eqM73A/00tX1KZr2C4wx7Vmz+sOu+VtvPURPzyaOHPkYv//HgA2f\nL5Kdrr94NDfnDCq2xxN3PpFua8LERMRaEITPDKPQKnd2RIxLSzEJ0VAu32gB7Orq5NVXf4guaD7f\nOh5/fNmAc+vH7dsXT0dHE93d8Zhd25OBa9G05zCKqdXqxeG4ElDiWVNzI2ZvwDVEZ6dDGna7e1Cx\ntdu1YcedJaFM0BGxFgRhTDhW/HWopKxoAUxMXItR0LZvjxtwzOHDxjnQG1HZ4CrrOzPTS3d3C8Hg\nTaG9zVOtiosThmy4EkloM2enT5q0i9Wrl/G9731EtNg+/XT04BBJKBOOjYi1IAhjwtEypI9GtGD2\n9+dgtpAPDzjmpptqQhncnahJWJF4dFzcM+Tnazidk0N7LwDuwGqdQXFxAqtWXRAuN2tr2wMUo2q5\nXUyatAuLxU1m5l407S7a2s5HDez4MZdc8is0bSrwGCpbXDVoOZ77loQyQUfEWhCEMWG43c2i9yso\n8Jmszby8Vlpa9PGSrfT2urDbN2GzHWDTpjJmzLDz8cdqAIfK1r4NYzwaDvPHP17OkiVr6OubgcXy\nMZdcks4f/nAlmgaXXfY4LS0/ALYA55KYuJaUlCx6erLwes8EvkZv72Ss1gdQHcwUfv+Foc9DN2g5\nFif6QiNMPESsBUEYE4abPBW934IFv+OKKx6hrs5CMHiY/v4errjiEIcPp/L++014vSo5rLfXRXHx\nL5k9+4vs21cPfA7owezG9jFvXjq//e1H9PWpnt2appGVtR5Ng5KS9bS0fAEl1KpEzO/vxu83J5Op\nuHU2Q3U0G6pBiyAMFxFrQRDGhOEmT0Xv19xcQFvbuwQCqrlKe7vGe+9VUl9/RSimq++7Ba/3Lhoa\nLMDVKFFNxyioU6Y0AqezdStE13qvXFkbcp13o4+1VEQnk6k1580LkpS0nrq6AG53K8aOZhJrFk6W\nURfrkpIS0tPTiYuLIyEhgT//+c+jfUpBEMaI4xncEZ08ZZyqZTx2sCSrDz4wj89U4zTBZjsQmtTV\nCfQxUFQvJTPzfk4/fSYdHXvp7k4JZXfrDVKeBRIpKPDQ1FRApGb6d6huZQNbnVqt71Nc7MLh+Ao2\nmxWXy80ttzzP9u1/ALKZNy+Iw/GVkXvIwinJqIu1xWJh/fr1TJ48+dg7C4IwrhnKtR0t4qtWzaG7\n20Ni4lr6+3PIzm7i7bcTaWubA3RTX78E2ExV1VJWrDiDmprb8fnsQCtvvHGE9HSLaXympn1ISclL\n+P3dJCSsIRBIAaZjdkt3A5NJTw8wa1ZPqO3p86HvazDXSa8LvSQsAZ4DJmOxrCEjYzpz5/aQlGRs\nxLLc9EJis1l5/PFvfRaPWziFGHWx1jSNYDA42qcRBCEGGMq1HS3iu3ZV4nTmoeK8GbS3dwA/wxgH\nbmrKZN++JhYufJ5gMNKk5PDhDcTH/4NJk9agaTPw+/fh9f6UhgYbEXd3MlBCxPX9D1QG94N0d/tp\nbEwNradPwUrH7GrPCZVYbWbfvgQ6OtxkZ5/HzJlHcDiWhsW5o8NNRcXwPAmCcDJ8Jpb1jTfeiMVi\nYdmyZXzzm98c7VMKgjBGDFUXHC3iym3dBhgbjAxsKnL11c8RDH4u6rsM+vtn0t9fTmFhJU7nl1FC\nrH/vB3YDS1HWsga8AawGLHg8GocP672+F6Ji1fuARabr1jOxy8s30dCwCqfTEuopHvEWqNptFdeu\nr19CX99fSE5OEvEWRpxRF+uNGzeSm5tLR0cH3/72t5k5cyYXXnjhaJ9WEIQxYKi64GgRV7HlwtBn\nN7AntIKKERcWNrBq1RIuvvgg8CEDZ0C3As/jdPaihHmx4ftEYAZKhDOIWM+6ld1FZmYuUGmYnnUd\ncB9wNoWFDTgc14Xv6WjeAn1spr7+jh1xuN3SHlQYeUZdrHNzcwHIysriyiuvZPfu3UcV69zcjNG+\npDFF7m98M5Hv70Tv7fBhNzfdVMPHH6czY0YXjz66hKwsszX56KNlrFixMbRPN2vXfotLLnmMlhYN\nFS+OuMCnTbuPd965iRUraggGVwGfALejYtDtKCs6H7gUaADsqIEaBSj394LQmkYSME7n6u6+j6lT\nz8XpXBzeIzW1kEWLjvDwwzeZrr+oqMf0olFU1EtubgZOp41ob4DF8qlpm9Np+8z+zkzkv5sw8e/v\nWIyqWPf29hIMBklLS6Onp4c33niDm2+++ajHtLd3jeYljSm5uRlyf+OYiXx/J3Nv5eXPhePRu3Zp\n9PUNZk3G8+tfLzJtOf/8PGpqNgD6HGkACzk5Z9LfH8/evSmh7XagAvgD8AVU/PkHRIu8Emz98/6o\n78wtSW222WRkfAw8hbK+D5OW9iF7987lO9+pNrmvf/zjL/DGG5W4XNOw2Q5w221ltLd3UVjYgdHi\nLyxs4ItftFJTY9zm+kz+zkzkv5twatzfsRhVsT506BA333wzFouF/v5+Fi9ezCWXXDKapxQEYYQY\nbhnWiQ6baG4uQLXhfAyj6L3zzm7OO28PZ51lrImeDEwFFpGfX09Ly2Sik8JU0xPlyk5IsBIIGL/L\nMp1j5swedu7sBH4Y3tbevoH29qvCCXB5eWdjt3fi8/nD7u7eXo21a9dTVWUfxOWvXOdJSdIeVBh5\nRlWsTzvtNKqrq0fzFIIgjBLD7TA2VFLZYOValZVvG9qGHgkd14MxvqxpBTidNxIM3kNZ2XoaG1Np\nb3+Pnh6NuLg/cs45mZx//jq2b+/A7Y4khcFeVCMSK3APA/uFb8Bq9VJcnIDDcTnnnVdLdOKa/nun\n8/M4nUuor9ewWv/IYC8jQ7UClRi1MBpIBzNBEAZluBbzUEllt9zyElu2qGzv+nqNF164g0DgrvDn\nBQvWsWDBOmpqfIbzgJpkZaGz024Yp9kTfnHYts1Fbu6DdHUB3IPFkk1c3Kf09/8EJdQagUAPkYSy\nbiwWK0uWBHA4rgx7B1SSmwvVSjQNleR2KcqKj7QKhUMYhV+6kQljwcBZcoIgCCiLWYkUgMb+/R9S\nXv4MLpfbtJ/NZuW++y7Hbvewb18ql1/+BCUlz7FtWwuqMxiAhUDAjlH8X3stgaSkROLjm4GvEmnr\n+S7gQtP2hs9lfnF4hvb2NPr7LwJmoWnX0N9fBNRgtT5KYWElKhltOSpLfDmTJ3cDsGzZW+F72LSp\njEmTfhnabwnwMzIz/5NJk+5AJao9BbiYNy+DsrL1nHfes5SVrT8h13ZHh5vy8k2Ulm4b9BkKwrEQ\ny1oQhEFxOEro61vHK68ECAS6cLu7qK6eyvbtf+Svf/22KX5tdJmDhtO5EZXBvQG4FiX6jRgt1N7e\nZKqrl2Ox3INxUIYS2Pvwem+jokJ1MTO72l1EN1BRMenFTJv2JJ98YkH913Ynykp2091dSHV1PHBZ\nKCb9MHl5ZzNp0gy83sgLRFzcJLzefwuvXVhYyUMPXXfStdLDDSkIwlCIWAuCMCg2m5Xk5CQCAWPj\nko20ta3hllseISkpNRx/bmxMY2AfbjXVCjaj6qKTgcdR86SnAN9ATbmaiu76jhxfAPyOLVuslJc/\nw+rVc9Bd7Q0NGVHJY16U21qjo6MJj2cFSvwvBHYBd4X214UdnE7V5ASexBzbzjZdR17e2SPS1ORE\nk/AEQUfEWhCEIYkWGV2Et2/vwu3+HrqlOGXKHShhzkANuuhFiV8Lqq92I5oWaRmqLG5QrmY/8L+Y\nG5skAT+jr+9RqqsTqav7G8XF8Tz99Bxuuul5tm0zCmwLaWk9/PM/r6exsQin02ilw8Dr7yISz+4h\nM/NeZs48C7vdg8/Xbyq9Gqn49FBJeIIwXESsBeEURs/YdjptFBZ2DCjPihYZFVceaIEePpyAcRBG\nQsJdBAIbUNnZk5k82YXbbRRNN/BrlKtcubaTk9fg988kGExBNTbRXd7fwe22UF2t3MdJSQA/B+ag\nLOrvEx9fFWoN+gy7dycSEWM9acyGPiHL6+3E6707fK3p6ZVs3apmTbtc7lEpvRoqCU8QhouItSCc\nwkTHmqNjqQ5HCUeOPMJrr2kEAk7i4rK45JKH2Lv3CGoaVTdwMYFAPkbxPuusc5g5s4emptcGtVhV\n1na84RgbfX3TgY+BuahxlQuAHAa6jzNRLvUl4evs6ckIX++WLY/Q16eL8SLgDmy2WcyfH4fDsZxv\nfGMnu3dH1szOLgqvM1Q51skyWusKpw4i1oJwCjBUg5NjxVJtNitPPfUvpm3l5ZtoabkFXXgtltvR\ntHOIxH5d7NnzNh99VITNtodHHilD0+CVV+7E75+NillfS2Lievx+o4A3Ar8wrZuRkYHHE+0+1qiv\nbzadr7+8jtPXAAAgAElEQVT/AKWl27DbO5k58wzee89oxV/A7NkJVFVdBsDMmUdCAzkiDVIEIdYR\nsRaEcc5wOo0NlY18tFjqcAVe07JQFnEVqnd3F8FgJb29quPX0qWVzJ07Db//34kI8wbmz8/kf/7n\nDrzeuSh39udN606ePJudO6/kllseYfv2LiAbn6+fn/98Hps3dxAMRlzdmvYL6uvVveXn/wJz0th7\nfPRRIeXlz+BwlIhLWhiXiFgLwjgnWojr6h6guDjPJNpDWdC6cKmYtcskXMMVeNU0pNLw+bemc7lc\nBQPOHxfnYc8eNxbLLJQrfSHK9R1Zd9IkJwBJSanhZLaaGo2kpPV85Svp1NQsN5wzsnZPTz6RjmgN\nwApcLls45l1VtXTYLunhtlwVhNFGxFoQxjnRQuh2n0l19SKM8eehLGg9ljrYoITIum6ghq1bCZdR\n9fWtY8eOOI4c2Y/ff6bp/Kq1Z+RcmtaI3T47dP5O4EWCwSAtLbOAr6FqoTcCC4iLu51g8MvAEVpa\nfkBFxeZBXzSefnoO77xTidM5DeVWj2SS9/a2ohLXQL0IbEHPAt+3L/64BFjqo4VYQcRaEMY5g2ds\nm+PPJ+L6LShopr7+KVRJViK9vUuorp7MSy+t4Z/+yYbb/R3gEeAAZrezF2OrT6/XRm1tVyi2nQXc\nZth3I3ANKSl9XHbZ0/zP/2Tg8fhQXczcvPiik/nzC03rt7Q0cMstzbhc01D/hX0/tE4asAO/f7ph\n//0YG6h89NEdfPnLTtzunzAcAZb6aCFWELEWhHGOLsR1dQHc7kkol7Kyns1WpMbTT885DjduIsZy\nLF1Yvd6LeP31N1EW60qUtfwEaiBHO6qLsdFFvQGP59rQPtEzoPuA5wgEGnj11UmGLO6rgY34/d+n\noeEOpky5k9bWTCCHlpZp1NT0oCzq7xPp7f0ukAt8E3gUVfaVazqf13s+Xm8iwxVgqY8WYgURa0EY\n5+iubJfLTUVFbbhcyuG4nIqKod24RiEvKurh7rsvNQl5c7O5bEpZyhpwhP5+O5GuY1ZUE5PridRG\n3wecg5o9nQc0AR+iyq6Mk7KSgCX4/Xpf8IENWDyeWSQnt2O2yO8AZgPVqBeEbuAW8vP/MzQ+MxX4\nDip2bbT6+1CW//AEWJLRhFhBxFoQJgjRtbwdHW7q6gIMZUVGx2O3blWJafooy/37WzAL3QcoUfwq\nmnY/0EYkVmxsF2pDCfWi0P7LUeJ9F8oK3xD6eRi4OXSMGo9pPp9qwKJp+4AZmIXc+HKgERdXyeLF\nm1m9+uusXbuerVuht9eC8jJsJDXVj9V6EKdzRegY87jM4T5TQRgrRKwFYYKycmUtbncyRgFsa3sX\nl2vOoCVYbvdsqqt7eeGFZwkEfoBqevI4CQkHiY930dcHyjLdiKaBGtDxBMpSNSd5KYu6m0gnsjwi\nVvi1obUnh36BalGqhFUJ//+GjrkTTZvOpEmfmu7DYslG0yLXnpmZHxbVqio75eXPhLK/rcByFi3a\nyN13XxdOWLPbzeMyBSHWEbEWhAmKEuPLiCR7fYDTuWKISVYaKuY7hUAgDlV+dROwhUDgCwQC/wvM\nAv7VsP8TKAs3A9iHxXIP8fF5BAKTQts04K+AhylTPqa11Xiut4F+LJZ7yMgoJDHxQw4f/hg4HdUi\nVG9peit9fRZaWlwUFlaSl3c2druH7m6LqT/4vHlB071Hu68ffngJ/f3xYiUL4xYRa0GYYOix6P37\nA8ALRMqjGgAL+/bFU16+iX37EigsrKS7Ow+PJxXoR8V6VwHPM3Bs5YOYXdFeIq7opSQn34HFkkIg\ncD1qulYkOe3ccx8hGFxDe7sVSEFlmF+IpnnxeBaQmPhbIuVWoCzvFoyu9by8s009vCsqjLHkr5ie\nQbT7OitrYGmaIIwnRKwFYYKgi3RdXWu4NElZqA8CU1GZ0y/S0dFEQ8Oq8PcLFqwjI8PCn/6Ui7KI\nLaj4cXTCVzbmmHIbcC9wJtCL16u3En0KJeSRY1991UJcXAoqSWwjymqPZJn7/YVRax9BxbUHTwST\nWLJwqiFiLQgThEjC2POYRfZzKMsYrFYv2dlFoVnO6vvm5hxefPEqtmypxOPJRAnkQuBhzHHoD4B7\nUOVQn6Cs9dNQ/43obvR7UWIcB/wB1VAlm2CwjWBwNsYs78j1pQE+EhPvxO+/ECXUXwX+jB7DLixs\nwOG4bljPYbCmJ7m5GcN/kIIQg4hYC8I4xihMjY3NKGs0Oqv6g9C2i0lNbaGpqdXwvYuWlgYuuiie\n1FQfHs8hIsM0koG1wNmo+dR+4N8M696DuQ57PxExVo1U4EbD9/eGfkZf35vArcyf/0fee68Bl6uA\nYPBBEhPzSEhwMW9eBg89dJ0pGexoXcgG6zr27LPXj+RjF4TPHBFrQRjHDBxxuQFlFW/AYulE0yaj\nksImAz/B6ZyDsnZ19/X7tLTcTkuLGiepYtjxeDyRrl9qNvVUVDmWbhF3hn4+jxLfhahY9FMoV/jn\nQvsaLegzQ9fXgYpPzwQ+4owzTufsszfj82XidN4aPm9f30ZgOe+8U3nU+46uH5euY8JEJG6sL0AQ\nhBMnWpiURfssYCEpCVSZlJVIzPkaVLz4Z6i4snnSVV7e2cTFTQltawLuIxA4DTW+8gPUCwGooRv/\nhnKTXwO8SE5OV+j35aiMbo9hfw34O6rL2b+EzvuvQCVnn51OVdXSIZqwdOJ0JvGlL71MefkzuFzu\nQe/bKMh2ux7rVueVrmPCREAsa0EYx+Tnt2N2KX+KsliXk529FqfT+F02A8XQQ3QS1/79h4hY6SsN\nx98R+mVHZY4b1+rC69Vbieq11A+jeocfRjVKuZWMjN/j9f4Wv/8H4WN1oR28x/mLwG243Zbw1Kyf\n/ewC3n//TZStoWq5jYIsXceEiYiItSCMA4aK0VosAZRL+xxUYtZNJCb+ioUL13PTTZdTVnYHXu8Z\nKBHXk8f0lqC7gHxgLSkpUykuDuDz+QkGe1FCrTcyIfTzDOA6VH11E+aXhAz6+oyNS14GvoDKLs9A\nxbxtBAIF5OYewOmcjHLHv0hjo5fzzvt/ZGbmUlhYSUdHPl7vx8BZKE+B2YK++urn8HrvDp970qQ7\ncDi+G35WkikuTERErAVhHDBUjLa5uQBlyR5BWco1zJ49i6qqpZSXb8LrvQtdnJOSKklIuJ2enjhg\nGiqurGqws7PvIzm5kOrqG9HHWMI+zILsDH13AGVdr0G9JAAsJCnpMIHAGjStCJVsdhvKotbLxzR6\nexPp7b2JwsJKenoScbt/gsdjwePRcDo3AuUUFq7F6bwRlQneHzpeXdP+/V48Hv2zcu9bLGdIJzJh\nwiNiLQjjgH374ol0IusKfdZdx06MYyA7OytDx6QSsUq34PPdh893H2bX9qNAKocPF1BX10JEBK8F\nqlCZ4fkogc5CDc643XD8htC+GoFAK5p2t+E7NaVLfc5An1kNVvLyzgagvn7g4I7U1Azi4n5PMHgm\nyoJ/mMREF37/atzugee12Q6OwBMWhNhGxFoQxgEdHU2ozmJKrDo6lCA7HCVs2/Ys3d0RIXc6M7nh\nho20tzehRk0aB20UYnZtu4Dv0NtrobdXL69agcoeTw/9Mo67/H3U8T5SUp6gtBS2bp0V9V1a6Pca\neXlO2toy0NuPFhR4SEpKHSRGrXHwYDvB4C8M2+8jIeE0/P7I2gkJXSQmPoHNdpBNm5aMwBMWhNhG\nxFoQxgHRjUy6u/MpLd2G3d5JWlob3d03YxS3mpofkJFRScQa34PK3Nbjyrqr20qk3MsKTCUh4X7i\n4vz4fBehksOMAtwedbwPTfuE1auXs2tXdUjw9VjyLs48Mxjq5Z3Ntm3Gmux14USwxsZUDh/eS1aW\nnVmz1g8i+oXYbAdMa3/taykSlxZOKUSsBWEcMHPmEXbvjoiVxzOJ+vqrqK/XyMy8n4Edyyx0dWWh\nOoFtQcWoVwE5KDd2GrAas8t6OZBIIHAPSsD/GZXR/RyRCVr5qASzT9AbpHi9LoqLf8mMGbPp6FiD\nxTILm62ZTZuWMWOGHYDS0m2ma2xuzglN7oL4+ATmzp2KwzEfm83Keef9P5Mwx8W9z2OPLeI3v5EM\nb+HURcRaEGIUYwZ4QcERFixYR3NzDvv3f4jbXR7ay0JcXA4DO5Y9h7Ki70fVNH+CyuZuAaaj/ukb\nBb6XSExZjzHXYIyFq45lFtRkrFzD8Vvweu/ivffUfmVl66mq+qHpXqLLsvLzD1FSsh6n8/NAN/X1\nSwA1DWzTpjKKi+/A650LHCEY/Cm/+c1msaSFUxoRa0GIUaIzwBcseCRUB52NcZqWGg+5jr/+tZfu\n7k+Bi1CW8OmoxiMbMVvRG1ATuIwCvxeoNHzuIjLUg9DPPOC7od8/aTg+zbSfPtXLWGbmcJTQ17eO\nHTvigMP8/e+dtLYas8U3huutZ8ywc+aZc0ICrpAuZMKpjoi1IMQI0bXU+/aZrd/t27twu7+HLqiZ\nmfeTnh7gwAE7s2YFuPRSqKkxCq4+0jJ6cEYGkEpy8hr6+opQJVnXAveRklLIxRd3smePO9yCNLJe\nu2Gdr6EsbTvwIcaBH8apXvX1Gjt33oPXO4nu7kwCgRRUh7PJRJLZrEAaBQVt4WcRbYlLFzLhVEfE\nWhDGEKNAt7Xtwem8CbBRX69RWFiJ2fo1dyCLi8vB6VyK07mFhgYbCQnNmEVZH2kZPTiji4yMOI4c\nyUEN2zgHZWmfTmlpAJhMS8vNqCSyDajGJMmobmcuVAw8DeU670MN67gPOJ1Jk97D5TIniLW0nAbc\nYDi/XtJ1DsrVvhyVABeplT6RLmRHG+4hCOMdEWtBGEPMgzjK0OueIZ3u7ngWLPgdzc0F2O0efL5+\namoiohsMtqISwFTcNxBIxizKHwLrAB8JCXeQmmqnt7cVvz+Prq6bQsdGyrL0TmDLlr2FuW3oY6hE\ntXZUDFwvq1qMst5/HfrcH2rCsiHqOj7GPPAjncjMaj9KvFfQ3Pxa+LmcSBeyW255iS1b1JSv+noN\nn28djz++7LjWEIRYRcRaEMaQgYM4VN0zWPB4FvHOO5U888ylVFa+zYEDqRQWVpKdXcTMmT3s2NGD\nx6N3KFPlUCrTuwg4hEokawb6uPLKqTz++DJKS7dRX38VqtXnFNO5LZaZVFS8SkGBL6r+OQnV4/t7\nqKYo0Znnt6EEWo9xL0QJsB/1wvBjIrHpDSi3ezfqBaAGZWWfvKtbxcONYQOZUyRMHESsBWEM0F22\n+/cHUMlaKlksISGDQCAiOE7n5/n615/D6VyFXtvc3d3K4cOddHbOwFwjnQccxOxyvg/4N/7+919Q\nWrqNtrY9QDHKlW22xHt7J1Fd/VVycx8gLq6SYPBclKguBJ4hMfGXTJ6sceiQUcg7iMTBdXe7FWWx\nbwQuQAm1up+MjB4uuSSN5uYUCgr+Avhpbn52hMqx9AEk+rUdPsn1BCF2ELEWhDEgeg611foAxcVT\n8PniTK5u2EN7+ySU8H0K3IbHsxGP5ybMMWA97mu2llUi1ye0tEyipUUD4oiLewzoIRj8FuamKWnA\nb2lvn40S/UuIWMQp+P134fGsJGJFd6GsZz0uvjD0nQc129ofWidyPxkZbTz00HWjEkueNy+dmprI\ntc2blz7i5xCEsULEWhDGgGj39/TpZ1BVdQUul5va2kiNMXyf/v77gVtRcd9OlGgbY8CHgZ8Dp6Hi\nwy4iItsM/BfG0q1g8FGUqNeiXNyXoBLMsgFzJzRlraeg11/7fJ9HxbHdKBd2H6rZyixUL3Er8+f3\n8NFHHQZvQCRJzelcQUXF6NRMP/TQYpKSamlq6sduD+BwLBrxcwjCWCFiLQhjwGClSR0dbm699QW8\n3lTgXVQf799hsVhD+3Whz3c210w7iSR96f29P48S4FuBNxgqLq72vxPlEg9gdqufTWLiLvx+Y1xc\nb1eqZ3FvBCJWfn7+PVRV/V+WLXsr1B5VT1J7At3CHq2aaRmNKUxkRKwF4SQYrFxI0zhmCdGqVXPY\ntasSl2saNtsBVq8uY+XKWmpqMlGCGBHI/v5VKKH7J1Ss2Si8XSir1rgtK7R/AcrCPow5lpsVtX8i\ng7ce3cP8+Vbee68y1GnsCPA14uJuJxiczWA13IcO5bFs2VuG2Lhu4SeG1tyA3R44mUcuCKckItaC\ncBJEdxnbseNOLJZEWlq+SHQbTSOVlW+H3MRq2tXatetDFmc8oAshqCztIpYsWU9dXStudwCz8LpQ\n7m/jtkmoGHRC6LNuMetx5vej9j8Ns3gfAdYQF5cNTGLTpq+wdu3bNDVl8v77/43X+wsi5VnmGu5A\noIv6+u8BZRQWqpeR3t5EdDe61erF4bhyJB69IJxSiFgLwkkQHXtubc3E7KbeyL598dxww5Ns394F\nZDNvXj8HD9pMx+lWeH19AsqtHRHA5OSPqaqqoKTkJdzuuURiyZ+gBnTEoVzZXyAu7m0SE0/Dau0h\nP9/PO++sQpVyAVyKcksfQrnN1QuFwije7cDdBIMWtm1TLxL6y4YqrzKWZx1Cud3PRDVJ0T0IFvLy\nzmbu3E6qqyO13MXFCdKoRBBOABFrQTgJomPPaqqV0UpN49Chf9DQMBNVp2yhpkajsHAtRoFsa3uX\nRx5ZwvPPr6e/H2ANMAP4kOeeUz2yOzo+QM2n/hmq3KsIVaOs1igsrKS2dkVYDE8/3YHRnR5xbx9C\nJY3prURdpKTcyec+d0FoSMh0ol8kdDIzP6S39ymUlR4EWiksTCMvz0Jb236czhWhPbVQOdbxdyIT\nBGEgItaCcBI4HCXs2mWM6YJRhAsLG+juzid6KEZW1nSCwXtCrTgP4XTm8POf/5WMjM/hdn8ntJ+b\nxMTf8KMfHaCx8QX6+rJRTU+CqCEdlqg1i/jRj14KNQc5hNc7jYHu7TtQGdr5JCSsITGxCJvtIK+/\nfiOZmVmUl3dSXR003YOxWcm5506ltdU8lzovL4etW6/A5ZpDRcVmkzBL0pcgjAwi1oJwEthsVmpr\nr6OiQh9l6QbWceCAlY6OvWRl2Wlvfx9lyUYEsKOjiba2eIwNTF555U5SUuyoEqgkQMPvn857730F\n+CbKMs4Grg8d86RpzY8+eoeGBqMlvQqze/sQytJ+ArievLxK6uuVkObmZtDe3oXDUUJX1yb++te1\n9PfnkJfXyurVXw/fb0tLtOcgKyzmgwmz9OsWhJFBxFo45TlZQRlMpMrLN9HQsCpUvuQCfoXqo51D\ncvJenM6fAq9hFD6//zz8/q8DT2F0b0cGX6SjMrvNk6+s1qmkprbgdJ6FWUhPJ9L0pBs1IUvPFrdw\n6JCV8877T7KzizjrLB93330pNpuVjAwrfv8PUUM4VMz6vvsms3JlLR988AnGF4BJk/6Ow/HdIZ9N\ndAIerBdLWxBOABFr4ZTnZAVlMLGPTjxLTEwmISEPm+0AaWmz+fBDGwOzsveimo34MIuuPviiG5X8\npR8zGYsFtmy5iNLSF4nUQOvrdaJGUBpFXwM+ADz4fB04nbfjdFrYvVtj69YHKC7OGzCas6kp0/CM\nzE1OZs8+86gvNtHPQeZSC8KJIWItnPKcrKAMJvYFBUeor9cTsRrw+1fj96syrbi421GiOR01ZcuF\nSkzTUIMyEjGKrsXyDzTtDSAXaMNYhlVSMpnKyrfxeH6KEtInAC8WSxslJalYLI/wt78l0Nv7CX7/\n5NCx/4pqQ/p703273WdSXb1owGhOu91jeEZ6k5PNwCJmzVp/1Gcjc6kFYWQQsRZOeU5WUAYT+4IC\nH2ZXduT7YLAIZeUeRDUNKUSJbyJKjL9NxH39Ppr2A5S4bgD+D/AUcXFTyM9vZu3aMr73vY9QQl2D\ncnHXk5CQQnp6DqtWzaGy8m2ami6goaGVQOBaw5V7MFvi3YCFtjYbmZn3Ehc3hXnzgjgcX6Gi4lXT\nM7Ja36e42HXM7G7JBheEkUHEWjjlOVlBGUzsm5qMiVjdmEVxHzAXZSk3oTK0dSt6DZo2GX1spGpu\nYkW5x53AfwM/Ixi04HSqeLLdrlFf/yKq8cgW4Iv4/X+jujqVHTueprVVTzp7LOo6JqFqtnWL/VpU\nYxM3Hk8u8G2SktZjs1kHeUbLhxXXl2xwQRgZRKyFU56TFZTBxN5siS5g0iR9OEcD5vnOD2O0ujVt\nKuaksNND3+k9wfVhHjVAOi+++AlPPTWX6upGlFDrDUgWAxtob3cb1r8KPclNZZtnYh7c8SAwFfg+\najZ2JCQgoisIY8uoi/Xrr7/O2rVr0TSNq6++mu9+d+jMUUGINYzJY0VFPeGM6ejv7HaNp5+eg6ZB\nRUUtH3zQR1LSSgKByWhaFgkJCeTkvMGhQzMwzneO7tsdH3+Q/v7lKOFNA/4GrEe5rI3DPJSL3e9f\nxHXX3YHqIKYnkaWH9rMQF5dNMOgCnkPVWbtRZWT7UF3QjIlsn0OJPOgxdIkxC0JsMKpiHQwGufvu\nu3nsscfIy8vjG9/4BldccQWzZs0azdMKwogRnTzW1xfJFDd/52LXrofp6cnH7U5GtQA9D11Uu7s1\nurvvZaBLvBdju87MzG5crl+i3OTdKGv6d8TFHSYYfCp0nC7cABb6+magyrgexNyx7A4CgWyUq/si\nlBv9bsP390ZdS1doTY3MzGYuv3y9xJgFIUYYVbH+xz/+gd1uZ+rUqQB87WtfY9u2bSLWwrjhaJni\n5u+2hAdzRFzKUzBbrlNR85/10qckoAKYTGbm/Vx+eT61tfmodqLGcqtzCAZ3EElYM2drJyU10tc3\nGbgg6nznh873o9Dn56K+n0JcXCWZmflcfHEQTfPT3PxsyJX/LWleIggxxKiKdWtrKwUFBeHPU6ZM\nYffu3aN5SkE4LvQZ0sYhGw899NWwUB0tU9z8XRpmIcxhYLb1p6h48JbQfsbM7CyqqpYye/bTUesk\nomZbzyYya/paVO/wmUAj55+vMWXKeurqWnC7jefrwzzCMtqqn0QwuBq3WyM9fSO//vWiE3+QgiCM\nKqMq1pqmjebygnDSRGZIR4ZsJCVFXN3G5LGiol7uvjviFjZ+19a2B6fzUpT1qgGfEB9/iJSU/Xi9\nOaSmdjB3bhJJSX/hwAErDQ31GIWzp6cJgNTUZjweo6C+jZqQFT2M42z07O3333+A555bSmNjE5dd\npiey7UG9GNQYzrMAWE1CwukEgx0Egz8I3YmFl1/uw+VyizUtCDHKqIp1fn4+Tqcz/Lm1tZW8vLyj\nHpObmzGalzTmyP3FFk6nMdlL/Xz5Zbj55s08/PBCiopO49lnrx/02I8//pitWz/C651OUpKb5OT7\n6euLCGt//4P097t5//2vMmuWPXzcsmUbaGiwY8z61rQscnMzyM+fTUuLMRu8ELOl3YNyg98U3max\n5JKbm8HNN+8OCfUSYD5KqA9jjImXlc3i2Wf/lWXLnuJPf5ocWkPD5UpizZo3ePrpa07mccY04+3v\n5vEi9zexGVWxPvfcc/nkk0/49NNPyc3N5YUXXuCXv/zlUY9pb+866vfjGX1YwkTls7y/kRoQUVjY\ngbI8jVauxp/+dA11dXdywQWn09ycg93eyaOPltHfHx8+trj4v/F6VUJXX9/AMiz4HL29izjnnDWc\nddaF4evcuzcF1RAl0go0Le1+PvjgAG1tjcBqIpb0PZhd1wdRLvGI0H75ywHa27tC6+qubiuwHKv1\nAdzuVeFrfuGFRzj33Cc57bROMjPvx+M5K3TMQvbufW3C/v2Uf3vjm1Ph/o7FqIp1fHw8a9as4Tvf\n+Q6apvGNb3xDksuEESE6S9vne4SkpNTjFm+Ho4QdO35Pa2ukhSf4AQutrZnU1NwYPseKFea4rsrC\nNoqzLvzmjmB9fRdRX78k3IpUNTHRO5KloHqEZ1BS8gRO57+gLO40EhPfRNM6CQSM1xYAesOJYfPm\nBXnooa8Aegx9Sfj4KVPexGJJQLnmu4EFBAIZNDRYaGhYQWHhWjyeReHrLShoobx8k0zIEoQYZNTr\nrOfPn8/8+fNH+zTCKUZ0lvb27V243SrufDzDOGw2K7m5X6S19RuGrZtRYmseB/nxx+mmYxMT38Pn\n08up9gNWLJZVaNoUVCb4wtA6R8JrNDVl8vTTc/D5nmf79k85csSH378aj8cSilXrE7bgnHOCfPCB\nJ6pF6BPAdSxePPD+VAxdnyftxuc7Pfyyoa7j58A09KSzrKzpzJ0bicd3dSXIhCxBiFGkg5kwLonO\n0lZzno8+jGMo13lHxweYLeJ/oCxRv2n71KkdpvXmzTuNurprUAKryq1UUuUToWP+GlrrJlQzkhfZ\nv99LRcWr3HnnpVRWvs3WreD361neVlRWOeiZ521tB+jtjVxDYuJHLFwYqX8+WjigtHQbZsv/QmAR\nqu5aY9as/rAY5+ZmcP75zx7zGQqCMDaIWAvjkugWnz5fPzU1Rx/GMdQozKys6TidxqQuG5BGUtLf\n8fkiLmhN8wMRgfzb3zKJuLKNomhDJXlp5OfXc/75f2H7dhdu909wuy1UV2vs2lUZVZetZ3nvITPz\nAOnpnTQ2FnHWWekEg/fQ2WnHZjvIpk3fZMaMSLLa0cZ75ucbx2lG3PLJyZlkZ1fS2FhEefkzOBwl\n5OZmyIQsQYhhRKyFcUl0r2qXy01SkhLv/PxD+Hx+Skqeo6OjiezsImbOPGKY0+wGati6FcrLn+G0\n047Q0BA993k+gUAzxlpop3MzYBbIwTuBvQtYsFrfp67u/2KzWSkt3UZ9fUTQW1ryMQu8L3TeFfT2\n/gaPZzVOp1qvrGxod/TRmrZYLAHMDViUWz47243TuSo8xxrW8+yz18uELEGIYUSshQmBUbzLyzdR\nXX0jSvwiohSZ01wDLKe3V1m5Cxaso6xsPXV1AdzuScDngQcJBmcCT6JaeU5mxoxuYKBAKsv7DlQ8\nuAOYDnhITvawbNlb2O2dFBT4TFZrMNiIWeCdwCpAw++fynDd0UezhpubC1DDO9TLSUrKc5SWQmNj\nUagP95kAAB2VSURBVOhFwLy+DOsQhNhFxFqYcETE1Ni9y0J2dhFz565n61bo7Y1s37LFD8SRlLSP\nSy9NYceO9/H7jT221wCn8cYbbXz8cdMQ8fJ/Ae4HvoxyN/8Tra1NtLbGU1+vYbM1oAR9BtCIGk/5\nKOACckhI8HHmmU9y8KATtzsXo5C3tb2Ly6WGhETHp49mDUeuU5VxlZYqC728/JmQRS3ubkEYL4hY\nCxOOiEh1YU4QcwNJJCe3mJK21Pzoa+nr0/jb39bQ3z8Ts+V8EbAEp1Nj6dJKamuvo69vHVu39hMM\ndqFqnp/D3GnsTuDfw59drmYiPb9dwC9DP28DLAQCGtOmraOjw4/bnYkS9rMAC07nCioqlAteud87\nqa9/kc2bXyQ//xCbNpWZ4tg6Qwm5uLsFYfwhYi1MOHQx2rcvno6OylDMugefzx9yj3cCG7Bavbjd\nzcC3ULHddPr6JqHKsIyWc6T0yuWahs1mJTk5iWBQj1u7gD9hFvjZUZ+Nru0tqOlYz5v22bEjLtTA\nxAIsxVjGFXGFW1Bu/GsIBi3hF4j6+h8OeA5DubXF3S0I44+4sb4AQTgROjrclJdvorR0G+Xlz+By\nucPf6WL05z/PZ+7cacTHJwAaBw7o7nErcC3Tp2eRnNwN/A8qE/tS1HCM6cDtKOt3FZAMPAW4sNkO\nAkZXuxvVuSwdJewQGdox1Gd96EdX1D6HMQu8uYzLbu8M7Wd277tc007gCQqCMJ4Qy1oYM06mZejR\nSpaG2ieSYBaJ1WZm5vL66z6MFmvEoq4M/VKfU1LuZNOmbwJGV3sNKiFtPpFe3/XAdeidxPLz/8E5\n56Tw1lsPEAza6OvbT1/fYlR2trLwi4sT8PnSTOVn8CbgJjHxQ1avXoamESr5ysKY+Ka/QAiCMHER\nsRbGjOEI7lAcrWRpqH2ys4v4whfWsWNHHHAYny8Nl+t0Iv20zRYrmMurZs/+AmvXvk1T00cUFPhY\nsOB3vPZaGr293ai49TWhdeopK3s93EnM4bjB9BLicrmpqNBjxgEcjiux2azh8jOVld4K3ArY8Ps1\n1q5dD2CqzY6LqyQ/HzZtWjKsZyYIwvhFxFoYM4YjuDC4BR6dkd3W9i6NjbOprHw7FKtuort7CkYL\n9PDhvRw4kIjb/RNAjcMsLFwL5KFi1p+iOnzplu1HGC3xDz98h927fwxsob5+Cvn573DxxbBt23J0\nK1qNpnRTVXXLkPetu+n1+9LLuxyOEqqqluJyufnSl17G7Y5MBDPHrNXPL3zhbLZuveL4HrogCOMS\nEWthzBhux6zBLHCHoyTkEv48cASncwVf//rDIctT1Vfr61qtD5Ca6sfpXAG8gVHwsrKm09PTh9t9\nLSr+vBHoBeJJT7fS3R3pbOb1TkElhy0HOmlp6aalxQU4UAlk76EakMwIdwY7mls/+r5eeukOzjjj\ni8yceYR587yDdGTTpMOYIJyiiFgLY8ZwSog6OtzU1bWiMqe7gIXU1QUAyMs7G6cz4gJWiVadKAs5\nsj9kk5WVHJpdbSznctHR0YT6Z2BM9Ipj0qQP+dKXckNWs3EQhje0dgORUiy9i9mtqBj2tVRXR9z6\nQ8Xmoz0LXu9cdu9ewu7dkUYtA5+NlFwJwqmIiLUwZgynhGjlytqw21qJ4gbc7klUVNSGRk1GLE2b\n7QC9vS+i1y4b909N3R/6HEnqSk1tCVnincCDoTOqY71eDYvlEcrK9CYqicBpgHGKlTG+fQ7K6s4I\nb9Nd10PF5gc2V4mUiG3fHsfOnZcPsMyl5EoQTk2kdEuIaQa29vQBC2lqysThKKGsbD3nnfcsZWXr\n2bSpDKvVO+j+2dlFoX1fo6wswM6dV5KXdzaRUq5CoMh07JtvJlFVtZTSUg3l+p5i+F5PSoOI0Kah\nLHe1TXUecw8am+/ocOPz9ZCYeCeqocq9wFfDx+ovJIIgCCCWtRCj6K7j/ftbMDcoSQYmY7d7BrXM\ni4vfCrmgzfvPnNkzYF+zZbsA1S50cfjYI0eacbnchvi4RiQBbQEqLn4xSqi/Sn7+b9C0Plpbn0OP\no1dUbB7gAbDbPaxcWUtNzfdRVv2LZGZm0Nv7K/z+81Gu9oU0Nb02sg9VEIRxi4i1EJNEXMeq21hm\nppf09BaysuzMmrWeVasuoLx804A4sB4Hb2xM5fDhvUPuv2LFGezc+Qlxcb9H09rQtGyU5RwZien3\n9/GlL71McXE8mzYtYfHi/6Kt7V5UMtmnXHppOllZ7tCam3E4bmDZsrdobdXj6CrePm1aIYWFkU5q\nDsflLFv2FsYGLTNnPovdnkF19VVE9wQfbu25IAgTFxFrIWYwJmIpi7oTo5ht3fp/wvuqyVqROPCO\nHXdywQX/v717D66yvvM4/s4dSAI5QIBEuiGAEay2TC11YVxCsY0SwKBopXWkRZuV0sEx7Qw3124t\n3VBTrbZDhyJip1AqWNYkUAhVA4RWKcvWTTEqZYg0CLmS5DQJhlzI2T8eTs41yUlyDufJyef1jyR5\n8jy/x4if/G7f379QVTWelBQb+/bdicVyj5frjbra+/f/DZttKvZtXUbFskSMFd23AGeBWVitVyks\nXAgcYO7cmRQUrMAepnFxO/rorR/qPsMabMye7VhwVlv7IcYsVAuw8PqCMc8V7mvXHtA8tYgorMU8\nPM+Jfg3jPGnPbUru88A1NaMpKjIWf5WW2igpeZ709AleVl4bVcpsNvszXgVGYdTyjsGo8x2O8yEc\nsIeKitFERUW4PLOqarzHO2zYcAenTm2msXEyHR2tdHZ67iNft+6oS3GT5OTN5OU9isWS4LHCvbfj\nMUVk+NACMzEN9wBOSLjavXjMfZuSo0421/853uV7rdYZFBau6F6k1VNdbSOclwMP4Dhw4zxGr95+\nTSy1tR9y7twZl2c6/wJhr1V+773/Q2VlCq2t99HZGef1evf3nDDh1u6hbvf30l5qEQH1rMVE3Lcy\n/eu/dhET00RFxWjWrj3iUmTEfcjY4LywrAXn3qx9LrukpBqr1blKWTze64I7evUjRpyisvJx4G3g\nBSIj4/nqVyPIy3MMs3uOCuwBFpGQ8DyTJ6fS0HCW8vIUsrPfICmpvcfiJjq+UkS8UViLaTiOthxF\nQ8NZ3n13NE1NI4H5lJaOwbl2uMWSwNGjj/LUUwc5caKZrq6RjBr1X3z6adL178nEOQjtK8dd63I3\n0dLSRXGxZ4979OirhIe/CtTT1TWRq1dPYN9j3dlp429/20xj4z9Zu9Y+x96Ja489DhhDevpE4FPK\nyjZQWRlGWZmNhQt/1UPBEx1fKSLeKazFNOxBlZ2dT1mZY07Xfq6z+/ytxZJAdPQorNYngDCamowg\njI6OoqLimNeeqc3m8hG5uf9Gbu4ujh6toqnJ0eP+9NOP6ezcdP3j3TiOtQQIo7LyNpYuzae6ehoQ\nAdTg3LNPSDhDenqj28pv43urqpJU01tE+kVhLUHR2/GYnoVQjLlfb/O37td6C0LnZ9XWfkBl5SPA\nCUpLLZw6Vcirr36ZoqIyjCpmxtx3Z2eM030XERb2PDabYw82NFJdzfW2NQNfJyrqP/nsZ79w/ZeE\n5S7z0KrpLSKDobCWoOjteEz3cHPupdo5iqZ0Ar/AmKOezJkzZzl/fjqpqSlenwVZwHPAOowe8hKW\nLv0B7e3P4dqTHwf8DmNOu4nY2Gu0tPwEo6zoFYzKaP/h8j2xsVO89pg1Dy0ig6WwlqDo7XhMz3Bb\n7lIYpKHByoIFu64vLmsB/oF9q9XVqzbuv38zpaVrenyWUVrU8XFbW6rb12MxVoR/B8ee6k20tKzC\nqP8dS2Rkk8u2LIhlzhz7QjdXmocWkcFSWEvA+XIetfPQcF/h5r5PGTbjHLaNjZNdnlldfRqoAyYB\nTURHv097u+PZMTEfc/Wq4+Pw8L8QE3Mzra2OeyYm3sq8eYc5e3YkKSlW2tvDXY6wTE4u46WXHvXr\nvyNVLhMRO4W1BFxP51E79557Kh/qjWtP+Z/ANYzDMIxqYBbLRS9D369h1P22MW9eM7Gxjmd/97uZ\nfOtbRiETi+Ui+fnfIDfXtcb41KmfsnfvCurqjIM6GhutREc79/4fHVS49jYtICKisJaA8zbk7d57\ndi8f2ltYTZpUh2Pl9SGc545HjPgB+fkP88QT53Ad2o4HrEAR77wziowMG3v3Oupul5be7vKMvDxj\nq1hP88z+HtrubVpAREQVzCTgfKnK5R5Wb74J2dlv0Nho7b7GXiXs3XerMHrKBzAWejm+b8aMO0hN\nTfFS4awZo/DJclpbV7hUN3O/f0ZG8fUiLF9mz547AFi27CSf+cxmFizY79Euf1DlMhHpjXrWEnB9\nrYb2drBFa2sUhYXLgV0899yXWbfuKCUlnVitMcDNGNXGwFix7Riurq39kIwMSEq6wsKFO6iqGk9S\n0mWgg2PHYl3mod17r84nfZWWHqKk5C1Gjap2mR+/eHEPZWUr8PcwtVaMi0hvFNYyIN4WRCUmxvf6\n9Z7mdD0XjD0HrMIeqJ6lPH+CUdP7MAAjRjzDzTfPor7+LJWV36Gy0kJpqY2srF28+ebd3W2Jiemk\ntXU39pO2ej4cxCg9arWGYbXux3PPt/+HqbViXER6o7CWAfG2IMo4PrLnr/cURp5bq27FOBrTGA72\n/PoM4JcYx1oa27UmT95BRMStVFZauq9zPuXKOewTEp4nPX2i18NBjLY6lx5twbPmuIapReTGUljL\ngHhbEFVfbyU7e7+X86h774m6b+OaNOk0V69eBuppb48lKanNrUjKOVpaEl32OZ84EU56uvftYO5t\nnTLlZrZv77l4iethHwtJTt7MuHFpNDaeIyHhM0yb5jgFTFuuRORGUFjLgHjbJ716dZHP51E7y8tb\nQFvbDv7yl3CgHputDat1JRBGUZG3gy+Wc+edr2G1Ovd468nLM+a43ed9fS332dNhH/ZtWYmJD3Zv\n3bLTlisRuREU1jIg3vZJL1z4v7ifRz1lSkGfC6YslgRiYqKxWpdgzEN3YAR9JpDgtd73nDlxFBW9\nhrElq5k5c+KwWBJYv/4LLFu2n7//PYk//nEbqak3M2WKY7GZL4u3+jN/rC1XInIjKKxlQLztk25s\njMJ5fjc9PdLrcLOd8xCyMWy+H1iBo7e8B1jutSf80ktLiI4+SkXFNVJSOsnLWwzAsmX7XRarffTR\nHj76aEX3YrPBcB7m96USm4iIvyisxS+MHuV8jIAdQVTU/1FefgvZ2W/0OI9rDCHbe9MzgCqce6kj\nR3aQkbGrx+pm3nq/jY2TXe7hz9XbzsP8PVVi05YrEQkEhbX4hdHDHIOx//l3dHQ8S1lZGGVlrvO4\nnr3p/wYex3FutKOXmpFB9/nWvs4LWyyf0NoamNXb5887rxL3XolNRCQQFNbiF3l5CwgL28mxY9do\nauqgq8v7PK7nnukXcD43Oioql8jIz2CxXGTjxvsAKC8fhXNIfvzxKI/n238JGDNmOg0Nz2CzJREW\nVk1q6nTS0nb5pcebmtrMqVMa8haRG09hLT5raLDy1FN/vL5q+zJz5sTx0ktLsFgSsFgSiI6Oxmpd\njrE4zHuouS/IioyMp7PTfu0YOjpS6ej4Bq2tNnJzd7F9ewoNDX93uV99/VngHpe2uf4S8DWysnax\nffsK/Gnr1kza2jTkLSI3nsJafLZu3VEOH7YPWdsoKnqN6Oij3cPAjmHiTGDP9TlnXELNfUHWqFEN\nxMUZ+5g/+eQ8Vms29gM39u/v5NSpXxAbm4QxFx4HtDB2bIpH227EquyxYzXkLSLBobAWn3lWEoun\nouJa99cdw8QJwHIyMjznlh2FRzqxWkfQ1PQdmprGMHv2LqZOnUBh4Rjsq8BttjAqK22MGPEMsAl7\nwE+btsujbVqVLSKhTGEtPjMC0V6TOxb4gKQkxypvX4aJ7QuyMjKKKS1d2v35iorR7N17B7CLwsJm\nHD3pZq5ds7gVRfG8r1Zli0goU1iLz/LyFnDy5C+prjZqcsMSYEf31/szTOzeE66t/ZCHH4aUFBsx\nMVW0ta3u/lpExA/Yvv3fe72fVmWLSChTWIvPLJYEJk26jepqx1B4VdX4Ad3LuSdcW/uhy2lZ8fHb\naWtzPCM19TZ/NF9EZMhSWIvPvJ07PdC5YeeecEYGLqdlRURYcV79nZbWNui2i4gMZQpr8Zn7udPJ\nyZvJy3t00Pd1HxKfMyee6GjNP4uI2CmsxWfuq8EnTLgViyWhuyBJZaWF5OSGfh8T6bk4bLGOmRQR\ncaKwHqYGcg5zT9ujPKuS9e+YyIEsDtM50iIynCish6mBnMPc0/aoG3VMpHNA19Z+QGXlasCic6RF\nJOQprIcJ957oxx/H0t+A7akH3FOP29+9X9cefBbGXuyv+9x+EZGhSmE9TLj3pJOTc+mpfndf3EN4\n40ajmIkxZ93Y3eMeSO+9N54V1GKv/1kVy0QktAUsrLds2cLrr7/OuHHjAMjJyWHevHmBepz0wT3o\nxo6dwuzZA1tx3VMIJybGU1fX3OMzB9v7de/BJyeXMWFCl1aMi0jIC2jPeuXKlaxcuTKQjxAfuQfd\ntGnXBtzL9TWE+1uvu69hc88580e1qExEhoWAhrXNZgvk7aUf/Fk729cQ7u8z+xo2V0lRERmuAhrW\nu3fvprCwkNtuu43169cTHx8fyMdJL/wZdL6GcH+feaNWlYuIDDVhtkF0f1euXMnly5c9Pp+Tk8Os\nWbOwWCyEhYXx4osvUldXR25u7qAaKwNXX29l9eoizp+PIzW1ma1bMxk71lxDyA8//Dtef91Y3Q02\nvva1Pezd+/VgN0tEJOgGFda+unTpEqtWreLAgQN9Xuu8QCnUuC/AupGys/NdCpdkZfl/X/Jg36+x\n0cratUddeuxmmpMO5s8v0EL53UDvN9QNh/frS8CGwevq6khMTATgrbfeIi0tLVCPEh84hpitQBFv\nvgnZ2W+YqvKX5qRFRLwLWFj/9Kc/5aOPPiI8PJybbrqJH/3oR4F6lPjAsSisCFhOa2sYhYUD3/vs\nbeW2L78diohI/wUsrPPy8gJ1axmADRvu4NSpzVRVTcJmG/wiLm8rtwsKVviruSIi4kQVzIaJzZvf\nu3685Wv0VrnM3mMuL4+goaGCcePSmDr1isdwuVZui4jcOArrYcIRrpnAHkaO7CAjA49tV44e8x5g\nA5WVYbz/vudwube91vX1VrKz9+skLBERP1NYDxOOcE0AlpOR4Qhf5/nnf/yjEyOA43DuOZeXjyI7\nO9+jHrh95faGDV9g1qxfcfHiOvpbC1zHXYqI9E5hPUz0VsjE9TSr3RjD5M04D5dfvnyGsrKnsQdx\ne/sOfvObh7vvkZ2dz8WLtzKQoXF/H/ghIhJqFNbDRG/bolznnxeRkPA8kycn09Cw+fqc9accPZqA\ncxCfOBHu5R4tDOQkL81/i4j0TmEtbvPPY0hPn8j27fe5XJOWthXnIIZ6L/e4D2OuO5bk5DLy8h4d\nwPN13KWIiDuFtfhU63vOnDiKil4D4oFm5syJ87hHTMxhzp4dSUqKtV8nYvnzkBERkVB0Q8qN9keo\nl5Qbqu/nSynQofx+vgjl9wvldwO931A3HN6vL+pZB0ioVfhSKVARkeBRWAeIKnyJiIi/hPd9iQyE\nVjiLiIi/KKwDJCXlnxirpkErnEVEZDA0DB4gWuGsymQiIv6isA6Q/i7ICsVgU2UyERH/UFibRCgG\nm+btRUT8Q3PWJhGKwaZ5exER/1DP2iRCseSm5u1FRPxDYW0SoRhsKqQiIuIfCmuTULCJiEhPNGct\nIiJicgprERERk1NYi4iImJzCWkRExOQU1iIiIiansBYRETE5hbWIiIjJKaxFRERMTmEtIiJicgpr\nERERk1NYi4iImJzCWkRExOQU1iIiIiansBYRETE5hbWIiIjJKaxFRERMTmEtIiJicgprERERk1NY\ni4iImJzCWkRExOQU1iIiIiansBYRETE5hbWIiIjJKaxFRERMTmEtIiJicgprERERk1NYi4iImJzC\nWkRExOQU1iIiIiansBYRETE5hbWIiIjJKaxFRERMTmEtIiJicoMK68OHD7N48WJmzpzJBx984PK1\nbdu2kZGRwcKFC/nzn/88qEaKiIgMZ4MK67S0NLZs2cLs2bNdPl9eXk5RURGHDh1i+/btPPvss9hs\ntkE1VEREZLgaVFhPnTqVKVOmeARxcXExmZmZREZGMnnyZFJSUjh9+vSgGioiIjJcBWTOuqamhqSk\npO6PJ06cSE1NTSAeJSIiEvIi+7pg5cqVXL582ePzOTk5LFiwwOv3eBvyDgsLG0DzREREpM+w/vWv\nf93vm06aNImqqqruj6urq5kwYYJP35uYGN/v5w0ler+hLZTfL5TfDfR+Q12ov19f/DYM7tybXrBg\nAYcOHaK9vZ1PPvmECxcu8LnPfc5fjxIRERlWwmyDWKb99ttvs2nTJhobGxk9ejQzZszglVdeAYyt\nW/v27SMyMpKnn36au+66y2+NFhERGU4GFdYiIiISeKpgJiIiYnIKaxEREZNTWIuIiJicacN6x44d\nzJgxA6vVGuym+NXPf/5z7rvvPpYuXcrjjz9OXV1dsJvkV3l5eSxcuJCsrCzWrFlDS0tLsJvkN73V\nwh/Kjh8/zr333ss999zDyy+/HOzm+NXGjRuZO3cuS5YsCXZTAqK6upoVK1aQmZnJkiVL2LlzZ7Cb\n5Dft7e089NBDLF26lCVLlrBly5ZgNykgurq6uP/++1m1alWv15kyrKurq3n33XdJTk4OdlP87tvf\n/jb79++noKCA+fPnh9x/gHfddRcHDx6ksLCQlJQUtm3bFuwm+U1PtfCHsq6uLjZt2sSOHTv4wx/+\nwMGDBykvLw92s/zmgQceYMeOHcFuRsBERESwYcMGDh06xJ49e9i9e3fI/Pyio6PZuXMnBQUFFBQU\ncPz48ZAsW71z506mTZvW53WmDOvc3FzWrl0b7GYERGxsbPefW1tbCQ835Y9gwObOndv9TrNmzaK6\nujrILfKfnmrhD2WnT58mJSWFm266iaioKBYtWkRxcXGwm+U3X/ziFxk9enSwmxEwiYmJzJw5EzD+\n3zJt2jRqa2uD3Cr/GTlyJGD0sjs7O4PcGv+rrq6mpKSEhx56qM9r+6xgdqMdOXKEpKQkbrnllmA3\nJWBefPFFCgsLiY+PD6lhK3f79u1j0aJFwW6G9MJbHf/3338/iC2Sgbp48SJnzpwJqQJUXV1dPPDA\nA1y4cIFHHnkkpN4NHB3T5ubmPq8NSlj3VG/8qaeeYtu2bbz66qvdnxuKvZi+6qnn5OSQk5PDyy+/\nzG9/+1vWrFkThFYOnC/14rdu3UpUVNSQmyscSC38oWwo/v0ST1euXOHJJ59k48aNLqN3Q114eDgF\nBQW0tLSwevVqzp07x/Tp04PdLL84duwY48ePZ+bMmZw8ebLP64MS1j3VGz979iyXLl0iKysLm81G\nTU0Ny5Yt4/e//z3jxo27wa0cOF/rqS9evJgnnnhiyIV1X++Xn59PSUnJkBw1GEgt/KFs0qRJVFZW\ndn9cU1Pjcx1/MYfOzk6efPJJsrKy+MpXvhLs5gREXFwcX/rSl/jTn/4UMmH93nvvceTIEUpKSmhr\na+PKlSusXbuWvLw8r9ebasI0LS2Nd955h+LiYo4cOcLEiRPJz88fUkHdl4qKiu4/FxcXM3Xq1CC2\nxv+OHz/OK6+8wtatW4mOjg52cwImVHqkt99+OxcuXODSpUu0t7dz8OBB7r777mA3y69C5WfVk40b\nNzJ9+nS++c1vBrspftXQ0NA9PHz16lVOnDgRUv+//N73vsexY8coLi7mZz/7GXfeeWePQQ0mnLN2\nFhYWFnJ/0V544QXOnz9PeHg4ycnJPPvss8Fukl/9+Mc/pqOjg8ceewyAz3/+8/zwhz8MbqP8xLkW\n/qpVq1xq4Q9VERERPPPMMzz22GPYbDYefPBBn1amDhXf//73OXnyJFarlfnz57NmzRqWLVsW7Gb5\nzV//+lcOHDhAWloaS5cuJSwsjJycHObNmxfspg1aXV0d69evp6uri66uLjIzM0lPTw92s4JGtcFF\nRERMzlTD4CIiIuJJYS0iImJyCmsRERGTU1iLiIiYnMJaRETE5BTWIiIiJqewFhERMTmFtYiIiMn9\nPyQ+uNKCpR6MAAAAAElFTkSuQmCC\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0xa813090\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "# Plot the Data (Optional)\n", + "\n", + "import matplotlib.pyplot as plt\n", + "\n", + "plt.scatter(inputs.numpy(), labels.numpy())\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "JaFHyAG9nDET" + }, + "source": [ + "## Step 2: Define our TensorFlow variables\n", + "\n", + "We'll use Keras's object-oriented [`Dense`](https://www.tensorflow.org/api_docs/python/tf/contrib/keras/layers/Dense) layer to create our variables. In this case, we'll create a `Dense` layer with a single weight and bias.\n", + "\n", + "(**Note**: We're using the implementation of `Dense` found in `tf.layers.Dense` though the documentation link is for `tf.contrib.keras.layers.Dense`. When TensorFlow 1.4 is released, the documentation will also be in `tf.layers.Dense`) " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 34, + "output_extras": [ + { + "item_id": 1 + } + ] + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 22, + "status": "ok", + "timestamp": 1505502830753, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "z9r-ZeyrXu3A", + "outputId": "6230a7a3-29fe-4d08-f101-da80425bad82" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 4, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# Create TensorFlow Variables using Keras's Dense layer.\n", + "\n", + "wb = tf.layers.Dense(units=1, use_bias=True)\n", + "\n", + "# We can access the underlying TensorFlow variables using wb.variables.\n", + "# However, the variables won't exist until the dimensions of the input\n", + "# tensors are known. Once the dimensions of the input tensors are known,\n", + "# Keras can create and initialize the variables. Until then, Keras will\n", + "# report the variables as an empty list: [].\n", + "\n", + "wb.variables" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "docKLUaonYG_" + }, + "source": [ + "## Step 3: Define our loss function\n", + "\n", + "Our loss function is the standard L2 loss (where we reduce the loss to its mean across its inputs)." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "0_w8ZJSCtuY7" + }, + "outputs": [], + "source": [ + "def loss_fn(inputs, labels, wb):\n", + " \"\"\"Calculates the mean L2 loss for our linear model.\"\"\"\n", + " predictions = wb(inputs)\n", + " return tf.reduce_mean(tf.square(predictions - labels))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 34, + "output_extras": [ + { + "item_id": 1 + } + ] + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 24, + "status": "ok", + "timestamp": 1505502830875, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "RkNbXoXkpjVH", + "outputId": "c36fc98d-3a57-4074-901d-c10ae017ae3f" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "\u003ctf.Tensor: id=40, shape=(), dtype=float32, numpy=7.3549819\u003e" + ] + }, + "execution_count": 6, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# Test loss function (optional).\n", + "\n", + "loss_fn(inputs, labels, wb)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 51, + "output_extras": [ + { + "item_id": 1 + } + ] + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 57, + "status": "ok", + "timestamp": 1505502830981, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "K_7beXoHOU7t", + "outputId": "1ad0856a-02ec-4117-a6c0-b41030981d87" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "w: tf.Tensor([[ 1.56891453]], shape=(1, 1), dtype=float32)\n", + "b: tf.Tensor([ 0.], shape=(1,), dtype=float32)\n" + ] + } + ], + "source": [ + "# At this point, the variables exist, and can now be queried:\n", + "\n", + "w, b = wb.variables\n", + "print(\"w: \" + str(w.read_value()))\n", + "print(\"b: \" + str(b.read_value()))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "YIlebeb_qYtC" + }, + "source": [ + "## Step 4: Create our gradients function using `implicit_value_and_gradients()`\n", + "\n", + "With a loss function defined, we can calculate gradients and apply them to our variables to update them.\n", + "\n", + "To calculate the gradients, we wrap our loss function using the `implicit_value_and_gradients()` function.\n", + "\n", + "`implicit_value_and_gradients()` returns a function that accepts the same inputs as the function passed in, and returns a tuple consisting of:\n", + "\n", + "1. the value returned by the function passed in (in this case, the loss calculated by `calculate_linear_model_loss()`), and\n", + "1. a list of tuples consisting of:\n", + " 1. The value of the gradient (a `tf.Tensor`) with respect to a given variable\n", + " 1. The corresponding variable (`tf.Variable`)\n", + "\n", + "Test it out below to get a feel for what it does. Notice how the first value of the returned tuple (the loss) is the same as the value returned in the cell above that tests our loss function." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "v1spZQ4NwW1U" + }, + "outputs": [], + "source": [ + "# Produce our gradients function. See description above for details about\n", + "# the returned function's signature.\n", + "\n", + "value_and_gradients_fn = tfe.implicit_value_and_gradients(loss_fn)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 153, + "output_extras": [ + { + "item_id": 1 + } + ] + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 46, + "status": "ok", + "timestamp": 1505502831114, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "21WMcpsmFFLd", + "outputId": "f51b3171-33f5-4f87-8bf7-0be2dc8edc8a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Outputs of value_and_gradients_fn:\n", + "Loss: tf.Tensor(7.35498, shape=(), dtype=float32)\n", + "\n", + "Gradient: tf.Tensor([[-3.00773573]], shape=(1, 1), dtype=float32)\n", + "Variable: \u003ctf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32\u003e\n", + "\n", + "Gradient: tf.Tensor([-4.06519032], shape=(1,), dtype=float32)\n", + "Variable: \u003ctf.Variable 'dense/bias:0' shape=(1,) dtype=float32\u003e\n" + ] + } + ], + "source": [ + "# Show outputs of value_and_gradients_fn.\n", + "\n", + "print(\"Outputs of value_and_gradients_fn:\")\n", + "\n", + "value, grads_and_vars = value_and_gradients_fn(inputs, labels, wb)\n", + "\n", + "print('Loss: {}'.format(value))\n", + "for (grad, var) in grads_and_vars:\n", + " print(\"\")\n", + " print('Gradient: {}\\nVariable: {}'.format(grad, var))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "JVDWpL9VYWdP" + }, + "source": [ + "## Step 5: Create an optimizer\n", + "\n", + "We'll use a `GradientDescentOptimizer` to fit our model." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "DudNEebMKDWN" + }, + "outputs": [], + "source": [ + "optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "YBeJYxY8YaiO" + }, + "source": [ + "### Step 5a: Test Our Optimizer\n", + "\n", + "Now we have everything needed to start fitting our variables to the data!\n", + "\n", + "In the next cell, we'll demo these capabilities. We'll:\n", + "\n", + "1. Print the current values of `w` and `b`\n", + "1. Calculate the loss and gradients\n", + "1. Apply the gradients\n", + "1. Print out the new values of `w` and `b`\n", + "\n", + "You can run the cell multiple times. Each time, you should see the values of `w` and `b` get closer to their true values of 3 and 2." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 102, + "output_extras": [ + { + "item_id": 1 + } + ] + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 103, + "status": "ok", + "timestamp": 1505502831285, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "diDZfrMJM3OC", + "outputId": "d585fff0-ecb3-4e98-9b33-bbae07a95d8c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Values of w, b, BEFORE applying gradients:\n", + "(array([[ 1.56891453]], dtype=float32), array([ 0.], dtype=float32))\n", + "()\n", + "Values of w, b, AFTER applying gradients:\n", + "(array([[ 1.86968815]], dtype=float32), array([ 0.40651903], dtype=float32))\n" + ] + } + ], + "source": [ + "# Test the optimizer.\n", + "\n", + "print(\"Values of w, b, BEFORE applying gradients:\")\n", + "w, b = wb.variables\n", + "print(w.read_value().numpy(), b.read_value().numpy())\n", + "print()\n", + "\n", + "# Calculate the gradients:\n", + "empirical_loss, gradients_and_variables = value_and_gradients_fn(\n", + " inputs, labels, wb)\n", + "optimizer.apply_gradients(gradients_and_variables)\n", + "\n", + "print(\"Values of w, b, AFTER applying gradients:\")\n", + "print(w.read_value().numpy(), b.read_value().numpy())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "61TgeLVlKEQp" + }, + "source": [ + "## Step 6: Create a training loop\n", + "\n", + "Of course, now we can simply turn all of this code into a self-standing training loop. We'll also capture our loss and approximations of `w` and `b` and plot them over time." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 397, + "output_extras": [ + { + "item_id": 1 + }, + { + "item_id": 2 + } + ] + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 225, + "status": "ok", + "timestamp": 1505502831550, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "VukGe-huNaJ4", + "outputId": "f0a8d665-1910-477c-d8ab-c94ccdc4afcd" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[2.111051321029663, 2.3047544956207275, 2.4602210521698, 2.5850086212158203, 2.6851789951324463, 2.7655951976776123, 2.830157995223999, 2.8819968700408936, 2.9236228466033936, 2.9570505619049072]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAd0AAAFXCAYAAADnFpTQAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAIABJREFUeJzs3Xd4FFUbBfAzu+m9koSShBQCSC+igIAgRRGkChJEiggo\nHURAEBQBQeADRcWCha50ULFLk6IivYRQQwskhPS6O/P9sckmm4Rkk2x2difn9zz7bLuZvC8JHO7M\n7FxBkiQJREREVOlUchdARERUVTB0iYiIzIShS0REZCYMXSIiIjNh6BIREZkJQ5eIiMhMjArdlJQU\njB8/Hk8//TS6d++OkydPVnZdREREiiMY8znd6dOno2XLlujbty80Gg0yMzPh4uJijvqIiIgUo9TQ\nTU1NRa9evfDbb7+ZqyYiIiJFKnX38s2bN+Hp6YkZM2agd+/emD17NjIzM81RGxERkaKUGroajQbn\nzp3DoEGDsH37djg4OOCzzz4zR21ERESKUmro+vv7w9/fHw0bNgQAdO3aFefOnSvxa3g5ZyIioqJs\nShvg4+ODgIAAXL16FbVr18aRI0cQGhpa4tcIgoC4uBSTFSkHX19Xq+8BUEYfSugBYB+WRAk9AMro\nQwk9ALo+jFFq6ALArFmzMHXqVGg0GtSqVQsLFy6sUHFERERVkVGhW7duXWzdurWyayEiIlI0XpGK\niIjITBi6REREZsLQJSIiMhOGLhERkZkwdImIiMyEoUtERCbRuXM7uUuweAxdIiIyCUEQ5C7B4hn1\nOV0iIqKy+OijFTh69BAEQYUhQ4ajU6fOuH8/HnPmzER6ehq0Wi2mTJmOJ59sgwUL3kZU1HkAArp3\n74nnn39B7vIrDUOXiEhh5s6dhd27d5h0mz169MLcue8aNXbv3t9x+XI01qz5Fg8eJODll4egadNm\n+PXXn9Cq1eN48cVhkCQJmZmZOH/+POLi7uGbbzYBANLSUk1at6Xh7mUiIjKp06dP4qmnugIAPD29\n0LRpc5w/fw716j2CH37Yha+++hyXLkXD0dERtWrVwp07t7F8+RIcPXoYTk7OMldfuTjTJSJSmLlz\n3zV6VloZCq80l/e8ceOm+Oijz3H48EEsWDAXAwcOxuDBA/D11xtx9Ohh7Ny5DX/88StmzHhLjrLN\ngjNdIiIyifxwbYbff/8VoijiwYMHOHXqBOrXfwSxsbHw8PDEs8/2wrPP9sLFixeQmJgIUdSiffsn\n8fLLoxEdHSVzF5WLM10iIjKJvLOX27d/EmfPnsbQoS9AEFR49dXx8PT0wp4932PjxrWwsbGBk5Mz\nZs16G7GxsXj99TcgSSIEQcDo0eNk7qJyCVIlrThv7esjKmmNR2vvQwk9AOzDkiihB0AZfSihB8D4\n9XS5e5mIiMhMGLpERERmwtAlIiIyE4YuERGRmTB0iYiIzIShS0REZCYMXSIismjHjx/DmTOn9M93\n7NiKn3/+0STbXrv2K5Nsx1gMXSIismjHjx/D6dP5odurV1907fqMSba9Zo15Q5dXpCIiogrbsGEN\n7O3t0bfvAHzwwVJcvnwJK1Z8gmPH/sGPP+7C7NnzDMZHRV3Ahx8ug0aTDWdnN7z55hx4eXlj8+ZN\n2LlzG2xsbBAcXBujR4/Fzp1boVbb4Ndf92DixNfx779/w8nJCQMHDsa4caNQp04ETp48gczMTMya\nNRdr136FK1cuo2PHzhg5cgwAYMaMqYiLu4fs7Cz07/8CevTohVWrViI7OwvDh0eidu0QzJ49D7/8\nsgebN2+CVqtB/foNMGXKdJOuE8zQJSJSGOe5s2Bv4qX9snr0QloJiyg0btwM3367Hn37DkBU1AXk\n5ORAq9Xi1KkTaNy4mcFYjUaD5csX4733liEsrBY2bdqGTz/9CDNmvIX167/Bli27YWNjg7S0VDg7\nu+C55/rqQxYA/v33b4Pt2dra4Ysv1mDz5k2YPn0KvvpqPVxcXDFgQC8MGBAJNzc3zJw5B66ursjK\nysLIkUPQvn1HjB49Ftu2bcaXX64HAFy/fg2///4LVq36Emq1GkuXLsIvv+wx2awaYOgSEZEJRETU\nRVTUeaSnp8PW1hYREXVx/vw5nDx5HJMmTTMYGxNzHVeuXMakSa9BrVYhO1sDHx9fAEBYWDjmzn0T\n7dp1wBNPdDDqe7dt2w4AEBoahpCQUHh6egEAqlevgXv37sLNzQ3ffbcBBw7sAwDcu3cPN2/GoH79\nBgYrIv3779+4eDEKI0cOgSRJyM7OhpeXV0X/aAwwdImIFCZt7rslzkorg42NDfz9A/Djj7vQsGFj\nhIWF4/jxf3H79i0EBQUXGi0hJCQUn3zyZZFrL7///gqcOPEfDh7cjzVrvsSaNd+W+r1tbe0A6BZc\nsLW11b8uCAK0Wi2OHz+G//77F5999jXs7OwwbtwoZGdnF7MlCd26dceoUa+V40/AODyRioiITKJx\n46bYuHEdmjRphkaNmmDHjq0ID69TZFxgYDAePEjEmTOnAeh2N1+9egUAcPduLJo2bY4xY8YhLS0N\nGRnpcHJyQlpaWrnrSktLhaurK+zs7HD9+jWcPXtG/56trS20Wi0AoHnzR7F37+948OABACA5ORmx\nsbHl/r7F4UyXiIhMonHjpli79is0aNAQ9vYOsLe3L3I8F9DNit99dxGWL38fy5cvQnZ2Dp5//gXU\nqhWId96ZnRuwEvr3HwhnZxe0adMOs2a9gb/+2o+JE183OLGppJOc8t5r1ao1duzYisGDn0dgYBAa\nNGioH9OzZ2+89NJARETUxezZ8/Dyy2MwefJrEEUJtra2mDx5Gvz9/U32Z8Sl/R5CSctNWXsfSugB\nYB+WRAk9AMroQwk9AFzaj4iIyOIwdImIiMyEoUtERGQmDF0iIiIzYegSERGZCUOXiIjITBi6RERk\ndt99txFZWVlyl2F2DF0iIjK7zZs3Iisrs9j3RFE0czXmw9AlIqIK27BhDbZu1V0n+YMPlmLCBN2S\neseO/YN582YbjN2yZRPi4+MwbtxovPTSSwCAzp3bYeXK5Rg2bBDOnDmF/v17Ijk5CQBw4cJ5jBs3\nCgCQmZmJhQvfwciRL2H48ME4eHC/uVo0CV4GkohIgbyaNyj29YRjZ4p9vazjCyvL0n79+g3Et99u\nxIcfforQ0BqIi0tBZmYGGjRoiLFjJ+aOMry8Y94lHb/5ZjWaN38UM2a8hdTUVIwcOQQtWz4Ke3sH\no+qUG0OXiIgqrCxL++lIuTcdtVqN9u07Fnq/qH/+OYpDhw5g48Y1AHSLJdy9G4vAwGCT9VKZGLpE\nRApk7Ay1vOMLK9vSfkXZ2dkbLF6gVqshirrgzc7OP+FKkiS8++5i1KoVWKF65cJjukREZBLGLu0H\nAE5OzgbL9RVeeycgoDqios4DAPbt+0P/+qOPPoYtWzbpn0dHR5myhUpn1Ey3Y8eOcHFxgUqlgo2N\nDbZs2VLZdRERkZUxdmk/AOjZsxemTh2PgAB/LFmyssgSfUOHjsR7770DFxcXNG3avMDrL+ODD5bi\npZcGAgD8/QOwaNH/Kq8pEzNqab9OnTph27ZtcHd3N2qjFy9ehKdnQIWLk5OSlpuy9j6U0APAPiyJ\nEnoAlNGHEnoATLy0nyRJZfrc1IABA5CTk2P0eCIioqrAqNAVBAEjRoxA37598d1335U6/sSJE/jw\nQ+uZ7hMREZmDUcd0N23aBF9fXyQkJGDYsGEICQlBixYtHjq+Ro0aWLp0Ebp164769R8xWbFERETW\nzKhjugWtXLkSzs7OGDZs2EPH/PDDD3j22WfRvHlzHDlyBDY2/GQSERFRqWmYkZEBURTh7OyM9PR0\nHDx4EGPHji3xa7p3747nn38B3323EXPnvosJE6aYrGBzUdLBfWvvQwk9AOzDkiihB0AZfSihB8D4\nE6lKDd34+HiMHTsWgiBAq9WiR48eaNu2bakbfvfd97Bv3594//2F6NatOyIi6hpVEBERkVKVeiJV\nrVq1sHPnTuzYsQO7d+/GK6+8YtSGPTw88f77y5GdnY0JE8ZAo9FUuFgiIrJMsbF3MGTIAJNuMzr6\nIg4f/kv//ODB/Vi//huTbFuupQUr9YpU3bo9g759n8d//x3DqlUfVea3IiIimRW+wEVFXbp0EUeO\n5Idu27btEBn5kkm2XdLSgpWp0s9wmj9/Efbv34tFi95F165PP/SSYEREZN00Gg3eeWc2Ll68gNq1\nQzFr1tuwt7c3GHPr1k0sW7YYSUmJcHBwwHvvLYCLiw/++OM3fP3151Cr1XB2dsHy5R/jiy9WITs7\nG6dPn8TgwcOQlZWJCxfOYdKkaViw4G3Y2dkjOjoKiYkPMGPGW9iz53ucPXsa9es3wMyZcwAAS5a8\nh6ioc8jKykKHDp0wfPgrBksLenh4YMWKT/D330fw5ZefIScnBzVq1MTMmXPg4GD6lYsqPXS9vLyx\nePH/MGxYJCZMeBW7d/8MtVpd2d+WiKjKmjvXHrt3m/af9x49NJg7t+TdsTEx1zFjxhw0aNAQCxe+\ng+3bN2PgwMEGYxYvXoBp02aiRo2aOHfuDObOnYslS1bim2++wLJlH8HHxwdpaamwsbHByy+PRlTU\neUyc+DoAYM+e7w1m06mpKfj0069w8OA+vPHGJKxa9RVq1w7BiBEv4tKlaISFhWPUqNfg6uoKURQx\nYcIYXLlyyWBpQTc3NyQlJWLNmi+xYsXHsLd3wPr132DTpnUYOvRlk/4ZAmZaZah79x7o1asPduzY\nhs8//wSjR5d89jMREVkfPz9/NGjQEADQtesz2LLlW4PQzcjIwJkzJzF79hsFFjjQ3Tds2Bjz589B\nx46d0b79k0Z9vzZtngAAhISEwcvLG7VrhwAAatcOQWzsbYSFheP333/Grl07oNVqkZBwH1evXkVI\nSBgKLi149uwZXLt2BWPGjIAkSdBoNGjQoFHF/0CKYbYP0C5YsAQHD+7HggXvoEuXbrlNExGRqc2d\nm1XqrLQyFD6mW/gQrySJcHV1w5dfrte/lveRoalTZ+D8+bM4dOggRox4EatXryv1+9nZ2QEAVCqV\n/nHec61Wizt3bmPTpvVYvXotnJ1dsGDB2wbLBObXJaFly8cwZ867ZWm3XMy2tJ+Pjw/ee28pMjMz\nMWHCa2W6ljMREVm+2Ng7OHtWty7vr7/+jEaNmhi87+TkjICA6vjzz9/0r124cAGA7lhvvXqPYMSI\nUfDw8MS9e3fh5ORksPxfSYq7zlNaWhocHR3h5OSMhIT7OHLkkEEtedt+5JGGOH36JG7dugkAyMrK\nxI0bMWXo3HhmvVRUz5690aPHduzevQOrV3+KkSPHmPPbExFRJQoKCsa2bd9h4cK3ERwcgl69+hUZ\nM2fOu3j//YX45psvodVq0LNnD/Tv/yI+/ngFbt68AQBo3rwlwsLCUa2aH9at+xrDh0di8OCHXwUR\nKP7M6bCwcISHRyAysh+qVfNDo0aN9e/lLS3o4+OLFSs+wcyZczB37kxkZ+dAEASMHDkGtWoFVvBP\npJg6y3oZSGM97AojcXFxeOKJlsjMzMSffx7S74O3NEq6Soq196GEHgD2YUmU0AOgjD6U0ANg4qX9\nTMnX1xcLFy5Beno6Jk0ay93MRERUZZg9dAGgV6++ePrpZ3Ho0EF8/fVqOUogIiIyO1lCVxAELF78\nP3h4eOCdd97C9evX5CiDiIjIrGQJXQDw8/PD/PmLkZ6ehsmTxxV75hkREZGSyBa6ANCv3wB06dIN\nBw7sw5o1X8lZChERUaWTNXQFQcCSJSvg7u6BuXNnVdrnooiIiCyBrKELAP7+AZg3byHS0lK5m5mI\nyEoZu7Tfnj3f4/79eDNUZJlkD10AGDBgEDp16ox9+/7Ehg1r5S6HiIjKwZil/X78cTfi4uKKfa8q\nfITUIkJXEAQsXfoBXF3d8NZbM3H79i25SyIiojLKW9pv8OD+mD17epFF4vfu/R0XLpzHvHmzMXx4\nJLKystCxY0d88smHGDHiRfz5528YN24UoqJ0l4ZMSkpE//49AegC+eOPV2DkyJcwdOgg7Nq13ez9\nmYJFhC4AVK9eA++8swApKcmYMmU8dzMTEVVA8+bOxd5MNb44MTHX0afP81i3bjOcnJywfftmg/c7\ndOiEevXqY86cd/Hll+v1a+26u3tg9eq16NSpSzFb1c2ev/9+J1xcXPH559/g88+/wa5d2xEbe6dM\n9VkCiwldABg06EV06NARv//+K779doPc5RARURkUXtrv1KmTRcZIkoTCc6pOnTqXuu2//z6Cn376\nAcOGDcIrr7yE5OQkqzz51qwLHpRGEAQsW/Yh2rV7DLNnz0CHDh3h7x8gd1lERFbn2DHjVucp7/ji\nlLa038M4OjrqH6vVakiS7thudnZ2gVESJk16HS1bPlbRMmVlUTNdAKhZsxbmzJmHpKRETJ06gbuZ\niYisRGlL+wGAs7Mz0tJSH7qNgIAauHDhHAAYLAH46KOPY9u2LdBoNACAGzdikJWVacryzcLiQhcA\nhgwZhieeaI9ffvkJW7Z8K3c5RERkhLyl/QYP7o+UlORil/Z7+ulnsWTJQv2JVIVnxy+8EInt27di\n+PDBSE5O1r/eo0cvBAfXxogRgzFkyAAsWbIQWq220nsyNbMv7WesmJjraNfuMdjZ2eLAgX/g5+dn\nosqMo6Tlpqy9DyX0ALAPS6KEHgBl9KGEHgALXtrPWIGBQZg9+20kJiZi2rRJ3M1MRERWz2JDFwCG\nDXsZrVu3xZ4932PHjq1yl0NERFQhFh26KpUK//vfSjg5OWHGjKm4d++e3CURERGVm0WHLgDUrh2C\nN9+cg4SEBMyYMVXucoiIiMrN4kMXAEaMGIVWrR7H7t07rPbSX0RERFYRuiqVCitWfAQHBwdMnz4F\n8fFVd4UKIiKyXlYRugAQEhKGGTPeQnx8PGbO5G5mIiKyPlYTugDwyitj0KLFo9ixYxt++GG33OUQ\nERGViVWFrlqtxooVH8Pe3h7Tpk1CQsJ9uUsiIiIymlWFLgCEh9fBG2/MQlzcPbz55htyl0NERGQ0\nqwtdABgzZiyaNWuOrVu/w08//Sh3OUREREaxytDV7Wb+BHZ2dnj99YlITHwgd0lERESlssrQBYCI\niLp4/fUZuHs3FrNnz5C7HCIiolJZbegCwGuvTUDjxk3x7bcb8OuvP8ldDhERUYmsOnRtbGzwwQef\nwNbWFlOnTkRSUqLcJRERET2UVYcuANSrVx+TJ0/DnTu3MWfOm3KXQ0RE9FBWH7oAMH78ZDRo0Agb\nNqzFH3/8Jnc5RERExVJE6Nra2uKDDz6BjY0NJk8eh5SUZLlLIiIiKkIRoQsADRo0xMSJU3H79i3M\nnTtb7nKIiIiKUEzoAsDEiVNRv34DrF37Ffbt+1PucoiIiAwYHbqiKKJ3794YPXp0ZdZTIXZ2dvjg\ng4+hVqsxefI4pKamyF0SERGRntGhu2bNGoSGhlZmLSbRqFETjB8/CTduxGDevDlyl0NERKRnVOjG\nxsZi37596N+/f2XXYxKTJ7+BunXr4auvvsDBg/vlLoeIiAiAkaG7YMECTJs2DYIgVHY9JmFvb48V\nKz6GSqXCxIljkZaWJndJREREsCltwN69e+Hj44N69erh6NGjRm/Y19e1QoVVVJcuHTBt2jS89957\nWLZsAT744IMyb0PuHkxFCX0ooQeAfVgSJfQAKKMPJfRgLEGSJKmkAcuWLcOuXbugVquRlZWFtLQ0\ndO7cGYsXLy5xw3Fx8p/ElJmZiaeeegIXL0Zh5849ePzxNkZ/ra+vq0X0UFFK6EMJPQDsw5IooQdA\nGX0ooQfA+P84lLp7efLkydi7dy9+//13LFu2DK1atSo1cC2Fg4MDli//CCqVChMmvIr09HS5SyIi\noipMUZ/TLU6LFo9i9OixuHbtKhYunCd3OUREVIWVKXQfffRRrFq1qrJqqTRvvPEmQkPD8NlnH+Po\n0SNyl0NERFWU4me6AODo6Ijlyz8GAEyc+CoyMjJkroiIiKqiKhG6ANCq1WN45ZUxuHz5EhYtmi93\nOUREVAVVmdAFgBkz3kJwcG2sWrUS//77t9zlEBFRFVOlQtfJyQkrVnwMURQxYcKryMzMlLskIiKq\nQqpU6ALA44+3wcsvj0J09EUsWfKe3OUQEVEVUuVCFwDefHMuAgODsXLlchw/fkzucoiIqIqokqHr\n7OyM5ctX6nczZ2VlyV0SERFVAVUydAGgbdt2GDp0BC5cOI///c86rrBFRETWrcqGLgC89dY7qFUr\nECtWLMOpUyfkLoeIiBSuSoeui4srli37EFqtFuPHv4rs7Gy5SyIiIgWr0qELAO3bP4kXXxyGc+fO\nYPnyJXKXQ0REClblQxcA5s6dhxo1amL58iU4c+a03OUQEZFCMXQBuLq6YenSD6DRaDB+/Bjk5OTI\nXRIRESkQQzdXx45PYdCgF3HmzCl8+OH/5C6HiIgUiKFbwNtvz4e/fwCWLl2E06e5m5mIiEyLoVuA\nu7sHli5dgZycHAwdOhSpqalyl0RERArC0C2kc+duiIwcgv/++w8DBvRGcnKS3CUREZFCMHSL8f77\nyzFo0CD8889R9O3bEwkJ9+UuiYiIFIChWwwbGxusWbMGgwa9iJMnj6N372cRFxcnd1lERGTlGLoP\noVarsWzZhxg+fCTOnz+LXr2exp07t+Uui4iIrBhDtwQqlQoLFy7Bq6+OR3T0RfTs2Q03bsTIXRYR\nEVkphm4pBEHAnDnzMGXKG7h+/Rp69uyGK1cuy10WERFZIYauEQRBwBtvvIlZs+bi1q2beO65pxEV\ndUHusoiIyMowdMtg/PjJmD9/Ee7ejUWvXk/j9OlTcpdERERWhKFbRiNHjsGSJSuQkJCAPn2exfHj\nx+QuiYiIrARDtxyGDBmGDz9chZSUZPTt2xNHjhyWuyQiIrICDN1yev75F/DZZ18hMzMDAwf2xoED\n++QuiYiILBxDtwJ69uyNr75aD41Gg0GD+uG3336WuyQiIrJgDN0K6tr1aaxb9x1UKhVeemkQfvhh\nt9wlERGRhWLomkCHDh2xceNW2NnZ4+WXh2Dbts1yl0RERBaIoWsirVu3xebNO+Ds7IIxY17Ghg1r\n5S6JiIgsDEPXhFq0eBTbtu2Gp6cnJk58DatXfyZ3SUREZEEYuibWqFETbN/+I3x9q2HGjKn4+OMP\n5S6JiIgsBEO3EtSrVx87d+5BQEB1zJ37JpYuXQRJkuQui4iIZMbQrSRhYeHYuXMPAgODsGjRfCxY\n8A6Dl4ioimPoVqLg4NrYuXMPQkJCsWLFUsyePZ3BS0RUhTF0K1mNGjWxc+dPqFu3Hj777BNMnToR\noijKXRYREcmAoWsGfn5+2L79RzRo0Ahr136FceNGQ6PRyF0WERGZGUPXTLy9vbFt2240b94Cmzdv\nwujRI5CTkyN3WUREZEYMXTPy8PDE5s078fjjbbBr13YMHz4YmZmZcpdFRERmwtA1MxcXV2zcuBXt\n2z+Jn3/egyFDBiI9PV3usoiIyAwYujJwcnLC2rXfokuXbti79w8MGtQPqakpcpdFRESVrNTQzc7O\nRv/+/dGrVy/06NEDK1euNEddiufg4IAvv1yHHj164dChg+jfvxeSkhLlLouIiCqRTWkD7OzssGbN\nGjg6OkKr1eKFF15Au3bt0KhRI3PUp2h2dnb49NMvYW9vjy1bvkWfPj3w3Xc74O3tLXdpRERUCYza\nvezo6AhAN+vlR11My8bGBitXfooXXxyK06dPok+f7rh7967cZRERUSUodaYLAKIook+fPoiJiUFk\nZGTps9zgYHiJRa+8lHDsTLHDvZo3KPZ1WcerhCI9VGY9XwFwGDkan3++Cr16PY2tW3ejevUaFd9+\ngT6s6s+/oNweLKaeco5HzHWLqofjOd4SxisiL4CH/v0uzKjQValU2LFjB1JTU/Hqq6/i0qVLCAsL\nK/Fr1CqhyGu+vq4P+QZFx1rC+MI9VHY9n376Mby83LFo0SL07v0M/vjjDwQHB1d4+3l9yP3nWZHx\napVgUfWUZ/xDv8ZK6i843uBrLaCe8ozXP7eQeso7vrh/a+Wsp8zjoYy8MJYglfFiwCtXroSzszOG\nDRtW4ri4OOs+G9fX11WWHiRJwtKli7B48QJUr14D27btRkhIyf/BKYlcfZiSEnoA2IclUUIPgDL6\nsPgeRBHIzISQmQEh9x4ZmRCyMiFkZgKZGRAyMuE+dJBRmyt1ppuQkABbW1u4uroiMzMThw8fxiuv\nvFLhPqh4giBg6tTpcHBwxDvvzEbPnk9jy5ZdqFu3ntylERHJy8gAzHtfN7bg89z3szLzt5M7Hpm5\n28nIHZf3fna2cbWZKnTj4uIwffp0iKIIURTxzDPPoH379sYVQeU2duwEODo6YMaM19G79zP47rsd\naNiwsdxlEREVJUlAdjaE9DQI6em5tzT9PdLTIaQ95D1JA9fEFMMAzMrSPy9XAJa1fEEAHB0hOThA\ncnCE5OICyccXkoM9JAdHIO91BwdIjo6Avb3hcwcHuBj5vUoN3YiICGzfvr2CLVF5jBgxCg4Ojpg8\neRz69OmBTZu2onnzlnKXRUTWSJKAjIwioWd4nw4UfC2taEgWHZf7ulZb7tIcCpZZXAB6+0BydCga\ngA4ORQPRwQGSfe57jo4FxjoCDvaGz/O2aWsLCGU7NluYyUKX5BUZOQQODg4YO3YU+vV7Dhs2bMbj\nj7eRuywiqkyiqAuylBQIqakQUpINH6emQJWaCkg5cI5/UCQQ9Y/T8h8jIx2CCdbzllQqSE7OkJyc\nACcniN4+kJyc9K9JTk6QnAs8dnIGDN43fM+rpi/i00VdANo7AHZ2FQ5AS8bQtQJ9+z4POzt7jB49\nHAMH9sGaNZvQvv2TcpdFRAVJku64YEpKbiimFB+aqbrHKv17KbrX8h6npEBISzU6IJ2KK8XWVh9u\nors7pIDqucH38PArGJYlhSTs7U0bir6ukCz5RCoTY+haiR49noODw3oMH/4iBg9+HqtXr0GXLk/L\nXRaR9cvJyZ095oeeKi0lPwALhmZaam5gFgzRlPyvL+fFgyQ7O0iurpBcXCEGBUN0dc197gLJxS3/\nsasrJFfc0+i4AAAgAElEQVQ3iC4ukFxc4FHTDwlZAJxzw9HRUReMtram/TMik2HoWpHOnbth/frN\nGDJkIIYOjcSnn36JHj16yV0WkbxEEUJyEoTERKiSEiEkJkJISoQqMbH415ISgdRkeCcl6YKynMtr\nSmo1JBddOIoB1SE560JRdHXLD0gXV/2Y/OB0g+iS/1hycdHNHsvD1xXaKjRLVAKGrpVp164DNm3a\nhkGD+mPkyKH48MNV6N9/oNxlEVWMkcGpSnxQJECF5KQyHauUnJwAd3eInl6QagUWmUnqQzMvLAuG\npqsrRGfdPRwdFX3skSoHQ9cKPfZYa2zZshMDBvTB2LGjkJWVhcGDX5K7LKrqSgzOB/qQzAvSigan\n6O4BsXp1iPXqQ/LwgOTuAbHQveThAdHDE5KHJ0R3D0ju7oC9PXx9XfGAM0SSAUPXSjVr1gLbtn2P\n559/DpMnj0NmZgZefnm03GWRUmRkQBUfB9X9eKjux0OIj4fq/n2o7scDmalwi40zTXB6eEKsXgNi\n/UfyQ1Iflh6FXvPUvwc7u0psnqjyMHStWMOGjbBjxx707dsDM2dOQ0ZGJsaNmyh3WWSJ0tL0AaoP\n0fgCz+/H54bsfaji43UXLShB3hFIyckZoocHg5PISAxdKxcRURe7du1B3749MW/eW8jISMfrr8+A\nwGNNyiVJhiEaHwchNyzzn+cFqm52KqSnl75Ze3uI3j7QhIVD8vaG6O2ju/n4QPLxzX3uDc/QWojX\n2up21TI4icqEoasAISFh2LlTN+NdsuQ9ZGZmYvbstxm81kKSdB9FiYszDMr4eMNdvPfv658bc8at\n5OCgC9HwiEIh6gvJx0cfonnPJWcX404MqmKfqyQyJYauQgQGBmHXrp/Qt28PrFy5HBkZ6Zg/f7Hc\nZVVdkqQ7eSg2Fqo7t6G6GwukJcL5+q1Cx0lzH2dllb5JR0eIPr7Q1K2nuwpQboDqZ6O5AZoXrnB2\n5tm1RBaGoasgAQHVsWPHHvTv/xxWr/4MWVlZ+Prr1XKXpTzp6VDF3oH6bm6g6oP1DtR37kAVeweq\nu7HFzkYLXj1IcnKG6OMDTf1HdCFaIDAfGqJEZNUYugpTrVo1bN/+PQYM6IN1677BvXt3MG/eYtSu\nHSJ3aZZPo4Hq3l1daBYIT/Wd27rHsXd0AZuU+NBNSCoVRN9qutmof4D+pg2oDrewIDywdc4PUafi\nLuBHRErG0FUgLy9vbN26C6+8Mgy//PIL9u/fjwkTpmDs2ImwL++Vb6yZJEF4kKAL0oKz0dhYqGIL\nzFTj7pX4kRfRwwNiQAA0TZvlBmkARL8AiAHVIfr76+59fAGbh/y18nWFhsdCiao0hq5Cubm5Y+PG\nrfjzzz2YMGEiFi2ajy1bvsWiRcvQrl0HucsznbQ0qO8WmJnmBqsqNm+GGgvV3TslHjOVHBwg+gcg\np9XjEIsJUq2fP0T/AN0ViIiIKoChq2CCIGDAgAFo0aINFi2aj9WrP0O/fj3Rp08/vP32Qvj5+cld\n4sPlnoikvhEDJMXB4eKV/BlqgWBVJSc9fBMqFUQ/f90xU78AXaDm7uoV/fz1wSq5e/CEIyIyC4Zu\nFeDm5o758xdjwIBBmDZtErZt24Jff/0FM2fOxtChL0OtVstTWFoa1DdioI65BlXMdaivX4c6RndT\nxVyHKiVZP9S10JeKnp4Qa9SEpnkLaP0Dip2hij6+gFy9EREVg6FbhTRq1AQ//PAb1q79GvPnv40Z\nM17Hpk0b8P77/0OTJs1M/w2zs6G6dVMfpLowvaZ7fP06VPFxxX6Z5OQEbWAQcgJbQxsYBKd6dZDs\n6gWtf26g+gcADg6mr5eIqJIxdKsYtVqNoUNH4JlneuDtt2dh8+ZN6Nr1SQwdOgIzZ74Fd3cP4zcm\nirqPzsRch+r6NYNZqjrmOlR3bkMQxSJfJtnaQluzFjSPNIA2MAjawCCIuffawGBIPj4Gu3udfF2R\nxROQiEgBGLpVVLVq1fDRR59h0KAXMW3aJHz11Rf4/vtdePvt+ejb93nd1awkCUJCAtS5s1OVfvdv\n7u7gmzcgZGcX2bYkCBADqiPn0ccKhGkQxKBg3b1/AHf7ElGVxNCt4to2boIDH32OXz/7GCd3bkPW\nqyNxcdZ0NPX0hGNsLFRpqcV+nejjkztTDS4UrEHQ1qhV/kW5iYgUjKGrdFlZUF+OLjBLzdv9mzt7\nTUgAAAzOvQEAEu4jOeE+Yn184dGmLVA7JDdYdTNVba1AwMVFro6IiKwWQ1cJJAlCfDxsoqOgvhgF\ndXQUbC5GQX0pGrh9C17FXPBBsreHtlYgNI2b5odpkC5Qf4m+iCnz38btO7cReOEC3hs6Ak891VWG\nxoiIlIWha01EEapbN2Fz8QLUFy/mh2t0FFQPHhQZrq1eA2jfHhkBNQ1OVBKDgiBW8wNUqmK/Taem\nzXHwmR5YunQRPv30Iwwa1B/du/fEu+++hxo1alZ2l0REisXQtUQ5OVBfvQL1xagCs9eLsLl0sci6\nqJJKBW1wbeS0ehza8Aho6kRAWycC2vA6kFxc4evritRynPnr4uKCOXPm4fnnX8C0aZPwww+78Oef\nv2PatJkYOXI0bG1tTdUtEVGVwdCVU1oabC5dzA/V3Fmr+uoVCBqNwVDJwQHa0HBo6tTJD9fwCGhD\nQiv1pKV69epj5849+PbbDXj77VmYO/dNfPvtBixe/D+0avVYpX1fIiIlYuiagXD/ftHjrdEXob55\no8hY0d0DmibN8kO1Th1owiMg1gqU7WM2KpUKL7wwGF27Po13352Ldeu+QY8eXRAZOQSzZ78NLy9v\nWeoiIrI2DF1TkSSobt8qsEv4ItQXL8AmOgqq+/eLDNf6+SP7iQ76UNXWiYAmPAJStWoWex1gLy9v\nLFv2IQYMiMS0aZOwfv0a7NnzPd56ax4GDoyE6iHHiImISIehW1YaDdTXrhaatUZBHR1d5DOtkkoF\nMTAIWc1bFtglXAfaOhGQ3NxlaqDiWrV6DL/9th9ffPEpFi2aj4kTX8OGDWuxePH/UL/+I3KXR0Rk\nsRi6D5OeDpvTJwuEq+5sYfWVyxBycgyGSnZ20IaGI7tAqGrCI6ANDVPsNYJtbW0xZsxYPPdcb8ya\nNR3ff78TnTq1xahRr2Hq1Olw4ed4iYiKYOgCEFKSYXPqJGxOnoDNqeOwOXkCuHIZnoU+3yq6uELT\nsBG0deoW2CVcB2JQcJW9rGH16jXw5Zdr8dtvP2P69Nfx8ccfYMeOrZg/fzGeeeZZ3eUkiYgIQBUM\nXSE5CTanT+kC9uR/uvsrlw3GiG7uQLt2yKgdVuCEpgjdNYMZIsV66qmuOHCgHVasWIIPP1yOYcMi\n0blzVyxY8D6CgoLlLo+IyCIoOnSF5KQiM9giAevugewn2kPTqAk0TZoip1ETiMG14VvNrVyfb63K\nHB0dMX36bPTtOwBvvDEZv/76Mw4e3I9Jk17Hq6+Oh52dndwlEhHJSjGhW6aAbdwUmsZN9AHL2atp\nhYfXwdatu7F163eYM+dNLFjwDjZv3oRFi5ahbdt2cpdHRCQbqwzdIgF74jhsrl4xGKML2A7QNG7C\ngJWBIAjo128AOnfuioUL5+Grr75Anz7Pol+/AZg7dz6qVasmd4lERGZn8aGrD9gTx/NnsKUFbOOm\nupObGLCyc3f3wHvvLcWAAYMwbdpkbNnyLX755Se8+eYcDBkyDOoqegIaEVVNFhW6QlJi0V3EJQRs\nTpOm0DRqwoC1Ak2bNsdPP/2Br79ejQUL3sEbb0zGpk3r8P77y9GoURO5yyMiMgvZQteogPXwQHa7\nJ3Nnr00YsFZOrVZjxIhX8OyzPTFnzkxs27YFXbp0wPDhIzF9+iy4WfEFQ4iIjGGW0DUI2JPHYXvy\nONTXrhqMYcBWHX5+/li16ku88MKLmD59Cr744lPs2rUD8+YtRK9effnZXiJSrMoJ3T/+gOPev2Bz\n6oRxAdu4KcTAIAZsFdO+/ZPYu/cwVq5cjuXLl2DUqOFYv34tFi1agtDQcLnLIyIyucoJ3U6dkHcR\nQIOAzTsGy4ClXPb29pgy5Q306dMfM2ZMxR9//Ib27R/HuHGTMGHCFDgo9DKaRFQ1VU7oTp+OpPD6\nDFgyWu3aIdi4cSu+/34XZs16A0uXLsLWrd/lnvncW+7yiIhMotS12GJjYzFkyBA888wz6NGjB9as\nWVP6VhcuRHaPXjwmS2UiCAJ69HgOf/31D0aNeg03bsRg4MA+6NevH44d+wdSoWthExFZm1JDV61W\nY8aMGfjxxx+xadMmrF+/HpcvXy7ty4jKzcXFFfPmLcSvv+5HixaPYuvWrXj66U7o0OFxfPbZx0hI\nKLo+MRGRNSg1dH19fVGvXj0AgLOzM0JDQ3Hv3r1KL4yoQYOG+P77X/DTTz+hZ8/euHQpGrNmTUej\nRhEYNWoY9u/fC1EU5S6TiMhoZTqme/PmTVy4cAGNGjWqrHqIDKhUKnTt2hXNmrVGfHw8Nm/ehHXr\nvsb27VuxfftWBAUFIzJyCAYOjIS/f4Dc5RIRlUiQjDxQlpaWhhdffBGvvvoqnnrqqRLHBgej2BnI\nsWNpxY5v3ty52NflHK9SqYr0YE315ynYhyXUU57xeT3kjZckCX//fRTr13+DXbu2Iz39LADAwcER\nLi4ucHBwgCAIFlN/npgYFeKKWbnK0v/8C4/39XU16EPuesozvmAPllBPecf7+roiMLD4vT3WUD8A\ntGzpavV5Aej+fhvDqJmuRqPB+PHj8dxzz5UauHlUqqIF+Pq6PmRs8duQe3zhHuSup7zj8/qwlHrK\nM16lUhmMf/bZznj22c5ISkpCSIgKqakpyMzMQGZmBtRqNVxcXJCUdB9hYWEWUX9JX2MNf/6Fxxd8\nbAn1lGd83nNLqaf844v/AmupX/c11p8XxjJqpjtt2jR4enpixowZRm+4uP/RW5PC/5u3Vkrow9ge\nTp8+hQ0b1mDLlu+QlJQIAGjbth0iI4ege/eesn/mVwk/C0AZfSihB0AZfSihB6Dk/1QUVGpGHzt2\nDLt378aRI0fQq1cv9O7dG/v3769wgUSm1rBhIyxcuASnTkXh448/R5s2T+Dgwf0YM+ZlNGpUBzNn\nvo6zZ8/IXSYRVWFGH9MtK2v/n4uS/vdl7X1UpIcrVy5hw4Z12LhxHeLidGfdN2vWHJGRL6F3775w\ncTHuf6emoISfBaCMPpTQA6CMPpTQA2DCmS6RNQsJCcOsWXNx4sR5fPPNRnTp0g0nThzHlCnj0aBB\nHUyc+Br++ecoL7xBRGbB0KUqwdbWFk8/3R3r1n2H48fPYcaM2fDx8cWGDWvRvXtntGvXCqtWrcT9\n+7zwBhFVHoYuVTkBAdUxadLr+PvvE9i8eSd69eqDq1ev4K23ZqJRozoYOXIo9u79gxfeICKTk20R\neyK5qVQqtG//JNq3fxL379/Hli2bsH79GuzcuQ07d25DYGAQXnhhMF54YTCqV68hd7lEVMmys4H0\ndCA9XShwr3uclpb/WkZG0TEbNxr3PXgi1UMo6eC+tfdhzh4kScKxY/9g/fo12L59K9LT06BSqdCx\n41OIjHwJXbp0g62tbbm2rYSfBaCMPpTQA6CMPsrSgyiimMDLv8/IKDksC79XOEA1mvIv0GNsknKm\nS1SAIAho0eJRtGjxKObNW4gdO7Zh/fpv8Ntvv+C3336Br281DBgwCIMHD0FISNELbxCRjiQBaWlA\naqqAlBQBKSmGj9PSdI+1WiA+3v6hQVg4LE1BpZLg5AQ4OenuvbzEAs91rzk7S3B0zB9jeF/0NehX\nkS8ZZ7oPoYT/QQLK6MMSejh37iw2bFiDzZs34cGDBwCA1q3bIjJyCJ599jk4OjqWug1L6MMUlNCH\nEnoATN+HJAFZWXnhmB+SqanIDUvd4/zwzH+vuK+RpPKHpL19ySFX8N7R0XCMs/PDw9LREbC3N/2q\ns8Z+ZIih+xD8S2k5LKmHzMxM7NnzPdatW4MDB/YCANzc3NGv3/OIjHwJDRs+fDEQS+qjIpTQhxJ6\nAPL70GhQKAx1j4ubZRYOybzHea/n5JQvjezsJLi6SnBxAVxcdI9dXQFXVwnOzvmPdWN0z11cJNSs\n6YTs7LQiM0y12sR/WJWMoVtBSvtLac0stYdr165i48a12LhxPWJj7wAAGjduisjIIejTpx/c3NwN\nxltqH2WlhD4srQdJ0h2rTEwUkJgoICkp7x548CD/eeH309JUSE6Wyr3bVRAMw9DZuWAwokBA5j/P\nC1NdkOaHp719+Xq3tJ9FeTF0K0hJvwjW3oel96DRaPD7779i/fpv8OuvP0Or1cLR0RE9e/ZGZORL\naNXqMQiCYPF9GEsJfVRWD5mZKBSQMAjJwqFZ8P3sbOOD09ZWgru7BC8vFRwdtUVmjwXDMO/1ggGa\n956Tk+l3s5aVEn6fAIZuhSnpF8Ha+7CmHmJj7+Dbbzdg/fo1uHbtKgAgLCwckZEvYeTIobCzc5O5\nwoqzpp/Hw5TUQ04ODELz4YFZ9P3MTOMTTK2W4OEhwd0d8PCQ9Dd3d6nQ86Lv54Wl0n8W1oShW0FK\n+kWw9j6ssQdRFHHo0EGsW/cNfvhhF7KysgAAoaFhaN26rf4WEFBd5krLzlp+Hnlnz8bHC7h/X0B8\nvID4eBXu3xeQnm6PO3dy9KFZcBduWXbVCoIuFN3dJXh65gem4fOi73t46HbXVnSWaS0/i5IooQeA\noVthSvpFsPY+rL2HBw8SsG3bZhw48Cf27z+A1NT8XkJCQvUB3KbNE1YRwnL+PDIzDUM0Li7vsapQ\nuOoeZ2QYl2pubg+bZepC82GzUFfXsq+nakrW/ncDUEYPAEO3wpT0i2DtfSihB0DXx507D3DmzCn8\n9ddBHDp0AEeOHEZKSrJ+TO3aIWjT5gk8/ngbtGnzhEVeCcuUP4+cHCAhQReexYWmLlhV+sepqaWH\nqL29BB8fw5u3twQfH1H/PDTUCZKUCg8PCW5ugI2VXrFACX83lNADwNCtMCX9Ilh7H0roASi+D61W\naxDChw8fMgjh4ODaBiFco0ZNc5ddREk/D61Wd7Zt4QDNn5EWfE+FxMTSQ9TGJi808wPU17domOa9\n7uxc+m5bJf9OWRsl9AAwdCtMSb8I1t6HEnoAjOtDq9Xi7NnTBiGcnJykfz8oKNgghGvWrFXZZUOj\nAe7dE3DnjoDYWBUyMx1x7VpWkRlpfLyAhAQBolhy4qlUEry8Cs9Ciz7OC1N398q5kEFV+Z2ydEro\nAWDoVpiSfhGsvQ8l9ACUrw+tVotz587gr78O4NChgzh8+BCSkhL17wcGBqNNm/wTs2rVCizT9tPS\ngNhYAbdvq/SheueOgNu38x/fu1d6kHp46EKy5Bmp7ubpKcl+4YOq/DtlaZTQA8DQrTAl/SJYex9K\n6AEwTR+6ED6LQ4cO4K+/DuLw4b8KhXAQWrdui8cea4v69dtDrQ7MDVEVYmMF3LmjC1LdTYXk5IeH\nqZ2dBH9/CQEBIgICJAQESPD3FxEW5gBb23T4+OhC1ctLQjnXgJANf6cshxJ6AIwPXSs9fYCoalKr\n1ahTpxHc3BqjcePx6NVLwokT93DqVDyuXMnErVu22LTJD5s21QBg99DtuLtLqFFDRPPmulD195dQ\nvXr+44AA3ey0uN26vr4OiIvTVl6TRArG0CWyEJIEJCejwK5e3Wy04K7e2FjdCUiGaufedMdLfXyy\n4eAQj5yca0hMPI2srCsAbgG4CT8/EW3ahKB9+1Zo3botAgODIMh9SSKiKoShS2QGWi1w6xZw5oyq\nwK7egrt7da+VdGEGJyfdDLRuXU3uzFTM3eWrm6FWr67b3as7XuoKoCFE8RGcP38Ohw8fxF9/peDw\n4YPYtu0Atm37BgBQo0ZN/WeEW7dui6CgYIYwUSXiMd2HUNJxBmvvw1p6SEkBrl9X4do1Fa5fF3D9\nukr//ObNkldv8fExPG4aEKAL1bxdvQEBItzcKn4WryiKuHDhfG4IH8Thwwdx//59/fs1atTUnxnd\nunVbBAfXLhLC1vLzKIkSegCU0YcSegB4IlWFKekXwdr7sJQetFrdmb7Fher16wLu3y/+0kQ+PiKC\ngiSEhqrh5ZVtcGJSQIAIP7/yr9BSUaIoIirqAg4dOoBDh/7CoUMHDEK4evUaBiFcu3YIqlVzs4if\nR0VYyu9URSmhDyX0ADB0K0xJvwjW3oc5e0hNhT5M84JVF6oq3LhR/EowtrYSAgMlBAWJ+ltwcP5z\nFxfz91FekiQhKuoC/vrrAA4f1oVwfHy8/n1//wA0bdoEgYEhqFMnAuHhEQgPrwNvb28Zqy47a/hZ\nGEMJfSihB4BnLxMVSxR1s9W8UL12LT9Ur18v7iQlHW9vEQ0aFAxV3ew1KEg3a5X7c6emIggC6tat\nh7p162HEiFcgSRIuXozSh/CRI4ewZ8+eIl/n7e2tD+D8WwRq1qwFlZwXJyayMAxdUpz0dBjMVAvu\nAo6JUSErq+hs1cZGQq1aEho00BQJ1aAg3fHUqkgQBERE1EVERF0MHz4SAGBjo8GRI/8hOvoiLl6M\nwqVLuvu//z6CI0cOGXy9o6MjQkPDUadOnQKhHIGQkFDYy7VPnUhGDF2yOpIE3L1reGy14Gz13r3i\nZ1aenhLq1Ss6Uw0K0p35a60XvTc3T09PtGjxKFq0eNTg9czMTFy9egXR0VEFwvgiLl+OxpkzpwzG\nqlQqBAUFo06dCISF1cndVa2bIbu7e5izHSKz4j8zZJEkSbcb+MIFFWJjgTNn7PWhGhOjKnbJNrVa\nQs2aEtq10+hDVXevu7m7y9BIFeLg4IB69eqjXr36Bq+LooibN2/khvFF/cw4OjoKP/+8Bz//bLi7\nulo1v9wwDjc4bhwQUJ0fZyKrx9AlWUmS7mL6Fy6oEBWlu124oMbFiyokJRX8B1Z3dSU3Nwnh4WKB\nMJX0M9caNThbtUQqlQqBgUEIDAxCp05dDN67f/8+oqOj9Luqo6OjcOlSNA4e3I+DB/cbjHV2dkF4\neDjCwyMMZsjBwbVha23XoaQqi/9EkdnExeWHa37Iqoss76ZWSwgJEfHEEyIiIkS0bGkPb+80BAWJ\n8OCeR0Xx9vaGt3drPPZYa4PX09PTcflydIEw1s2Qz507ixMnjhuMtbGxQe3aIQXCOFx/7+Ji3Bml\nRObC0CWTu39fKBSsulvhz7GqVBJq15bQurUGdevqAjYiQkRoqGjwuVVfX3vExYlm7oLk5OTkhIYN\nG6Nhw8YGr2s0GsTEXEN0dLTBSVzR0RcRHX0RP/6422B89eo1DM6mzpsh+/i4mLMdIj2GLpXbgwdA\nVJS60K5hVZGP3QiChKAgCS1b5hiEa1iYCAcHmYonq2RjY4OQkDCEhISha9en9a9LkoR79+4VOYkr\nOjoK+/b9iX37/jTYjouLC/z9A+DvHwA/P//cez+D1/z8/OHk5GTuFknhGLpUqqQk4MIFtUGwRkWp\nij1LODBQRJcuGkREaBERIaJuXV248t8uqkyCIMDPzw9+fn5o27adwXupqSkFPt6kmyHfvHkdt2/f\nxqVL0SVu193dA/7+/vDzC4C/v39uKPvnhnL+Y378iYzF0CW9lBTkBqraIFxjY4uGa61aIp56SpM7\na9Wibl0R4eEinJ1lKJyoBC4urmjatDmaNm2ufy3vKkhZWVm4d+8u7t6NRWxsLO7evYPY2FjExt5B\nbOyd3NfvICrqQonfw8vLq5hgDjAI6WrV/HjCFzF0q6LUVODixfwzhfPC9fbtouFao4aIjh01ubNW\n3ey1Tp38SxsSWTN7e3vUqhWIWrUCSxyXkZGBe/fuFgjm/HDOC+Zbt27i/PmzD92GIAjw9vbRB3HB\nXdsFw9nHxxc2PA1fsfiTVbD0dODff4HDh230s9eoKBVu3CgargEBIjp00Oh3CeftHnblyZ9EcHR0\nRFBQMIKCgkscl5aWhrt3Y/VBnB/M+Y+vXLlc5GIhBalUKvj6Vis0Yy46g7a2612TDkNXITQa3a7h\n48fVOH5chf/+081gRREAHPXjqlUT8cQT+WcL581eeeEIoopzdnZGSEgoQkJCSxyXmppisBu74K7t\n/F3a53Hy5PGHbsPGxgZeXl5wc3OHu7s73N09Ctx76J97eHjAza3ovVopFwy3MgxdKyRJQEyMgOPH\n1fjvP13InjqlNrhKk5OThJYttWjZ0gaBgZn62aunp4yFExEA3XHmsDBXhIWFP3SMJElITk4q9hjz\n3bt3ERt7B8nJiUhIeICYmOvIzs4uUw2urm7FhLW7QVjnv+ZpEOCOjo68Olg5MXStQEICcOJEXsDq\nQrbgx3JUKgl164po1kyLZs1ENG2qm73a2OSdMJIjY/VEVB6CIOhnrBERdYsdk3dCmCRJyMjIQHJy\nEhITE5GUlISkpAe597rniYmJ+vcL3sfEXEdKSnKZarOzs9PPmjnLLhuGroXJyABOn87bTawL2mvX\nDI/BBgaKeO65HDRtqgvZhg21PGuYqAoTBAFOTk5wcnKCv39Amb9eq9UiOTnJIKSTkhILBHhigZth\nkF+/fg05OWX7j72rq5s+gH18vGBraw9HR139jo6OcHJy1t87ORV87lRgXP69s7Pu3hrCnKErI60W\niI5W4b//VPpZ7PnzKmg0+bttPD0ldOyoyQ1YLZo0EeHrK8lYNREpjVqthqenFzw9vcr8tWWdZRcM\n7piY6zh79rTJ+rC3ty8S2oXDOu/2sJA3DHvDcLezs6vwbnWGrplIEnD7tqA/Bnv8uBonTqiRlpb/\nA7S3l9CkiW43cdOmulvt2hJ46ISILFVFZ9ne3s6IibmH9PR0ZGSkF3ufd8vIyEB6elqh+4LjdK+l\npaUjOTkZsbGxyMhIhyia5jKyarW6SFjnzcT3799r1DZKDd2ZM2di79698Pb2xu7du0sbTrmSkqDf\nRZx3NnHBKzgJgoSICBFNm4r6WWzduiLs7GQsmojIzFQqFZydneFcScfIJElCVlZWgSDXBXZ6enEB\nXtmry4cAAAsRSURBVFyQFwx9w/vExESkp6eVafd6qaHbp08fvPjii5g2bVqFGleyrCzg7FmVwdnE\nly4ZHluoXl1E9+45aNpUN5Nt3FjLz8ASEVUyQRDg4OAABweHcu0+N4ZJQ7dFixa4detWhQpSElEE\nLl/WHYfNm8meOaNCTk7+PmBXV91C6rrdxLqZrL8/j8MSESlRWS7vyWO6pbh7V3ccNu9kpxMn1EhJ\nyQ9YW1sJDRqI+mOwzZrplqZTFb3oExERVXEM3QIkCTh/XoUDB9Q4fhw4csS5yPWIw8K06NYt/2Sn\nRx4xXPuViIjoYSotdH19reOA5Y0bwG+/6W6//w7cvZv/np+fCj17Ao8+CrRqBbRoAXh4qAGoAVjP\naiHW8rMoiRJ6ANiHJVFCD4Ay+lBCD8YyKnQlqezHI+PiUsr8NeaQlAQcPGiD/fvV2L/fBpcv589k\nq1UT0a+fFu3aadCzpyMcHVMMPq6TkwPExclQdAXkXbHGmimhB4B9WBIl9AAoow8l9AAY/x+HUkN3\nypQpOHr0KBITE9GhQweMGzcOffv2rXCB5pKVBfzzj1ofsidOqCCKuiR1dpbQpYsG7dpp0K6d7tKJ\neSHr62t9AUtERJat1NBdunSpOeowGVHUfXxn3z5dyB49mr8QgI2NbhGAdu10t2bNtOCa0kREZC6K\nOJHq+nUB+/frdhkfOKBGQkL+LuN69fJCVoPHH9dy8XUiIpKNVYZuQoLuuGzebPb69fyQrV5dxMCB\nOWjXToMnntDCz4+fjyUiIstgFaGbkQEcPZp/XPb0aRUkSbfL2M1NwjPP5Ohns6GhvFYxERFZJosM\nXa0WOHVKpd9l/PffamRl6ZLUzk5Cmzb5u4wbNdKtG0tERGTpLCKuJAm4elXAvn26kD140AZJSfnT\n1YYN80O2VSstnJxkLJaIiKicZAvde/cEHDyYv8v45s3847KBgSJ69tTtMm7TRgsfHx6XJSIi62e2\n0E1NBY4cUetns+fP56/C4+kp6UO2XTsNgoMZskREpDyVFro5OcDx4/nHZf/9Vw2NRrfL2MFBQvv2\nugtStG+vQYMGXCCAiIiUr1JCt2dP4M8/XZCaqgtZQZDQpImov/JTy5ZaODhUxncmIiKyXJUSurt3\nAyEhEvr10+0ybttWAw+PyvhORERE1qNSQvfaNcDJKa0yNk1ERGS1KuVIalBQZWyViIjIuvH0JSIi\nIjNh6BIREZkJQ5eIiMhMGLpERERmwtAlIiIyE4YuERGRmTB0iYiIzIShS0REZCYMXSIiIjNh6BIR\nEZkJQ5eIiMhMGLpERERmwtAlIiIyE4YuERGRmTB0iYiIzIShS0REZCYMXSIiIjNh6BIREZkJQ5eI\niMhMGLpERERmwtAlIiIyE4YuERGRmTB0iYiIzIShS0REZCYMXSIiIjNh6BIREZkJQ5eIiMhMGLpE\nRERmwtAlIiIyE4YuERGRmRgVuvv370e3bt3QtWtXfPbZZ5VdExERkSKVGrqiKGLevHlYvXo1vv/+\ne/zwww+4fPmyOWojIiJSlFJD99SpUwgKCkKNGjVga2uL7t274/fffzdHbURERIpSaujevXsXAQEB\n+ud+fn64d+9epRZFRESkRKWGriRJ5qiDiIhI8WxKG+Dv74/bt2/rn9+9exfVqlUrdcO+vq4Vq8wC\nKKEHQBl9KKEHgH1YEiX0ACijDyX0YKxSZ7oNGzZETEwMbt26hezsbPzwww/o1KmTOWojIiJSlFJn\numq1GrNnz8bw4cMhSRL69euH0NBQc9RGRESkKILEg7ZERERmwStSERERmQlDl4iIyEwYukRERGZS\n6olUZbF//34sWLAAkiShb9++eOWVV0y5ebOYOXMm9u7dC29vb+zevVvucsolNjYW06ZNQ3x8PNRq\nNfr3748hQ4bIXVaZZWdnIzIyEjk5OdBqtejatSvGjh0rd1nlIooi+vbtCz8/P6xatUrucsqlY8eO\ncHFxgUqlgo2NDbZs2SJ3SeWSkpKCN998E9HR0VCpVFiwYAEaN24sd1lGu3r1KiZNmgRBECBJEm7c\nuIEJEyZY5d/xr7/+Glu2bIEgCKhTpw4W/r+9u3mJag8DOP6dHKRQexElCyzIjCySFr1AEyamSTXV\nxGCLNiVRbdIow14oghYJLfoHWkREEBEaRG1EszGmQiuGYIgwIhhMKkRT5yXPnOcu4l64G+89x7nz\na7rPZz1n+A6HmYcznHmmo4P8/HzTWY7cunXrr/fCv/qslQxJp9NSX18vsVhMfvz4IXv37pWhoaFM\nPX3WDAwMSDQaFb/fbzrFtS9fvkg0GhURkcnJSdmxY0dOngsRkXg8LiIilmVJU1OTRCIRw0Xu3Lx5\nU9ra2uT48eOmU1yrq6uTsbEx0xmzdvbsWbl//76IiExPT8vExIThIvfS6bT4fD4ZHh42neLYyMiI\n1NXVSSqVEhGRkydPSldXl+EqZ96/fy9+v19SqZRYliWHDx+WT58+zXhMxr5e/l12NG/YsIH58+eb\nzpiV0tJSqqqqACgoKKCioiJnV3fOmzcP+HnVa1mW4Rp3RkZGePr0KU1NTaZTZkVEsG3bdMasTE5O\nMjg4SDAYBMDr9VJYWGi4yr1wOMyyZcv+tqo3l9i2TSKRwLIsksnkv1q89Cv58OED69evJz8/n7y8\nPDZu3Eh3d/eMx2Rs6OqO5l9TLBbj3bt3VFdXm05xxbZtAoEAPp8Pn8+Xk6/j6tWrtLe34/F4TKfM\nisfj4ciRIwSDQe7du2c6x5VYLMaiRYs4f/48+/fv59KlSySTSdNZrj1+/Jjdu3ebznBl8eLFNDc3\nU1tbS01NDUVFRWzZssV0liOVlZUMDAwwPj5OIpEgFArx+fPnGY/J2NAV/bnvL2dqaorW1lYuXLhA\nQUGB6RxX5syZw4MHDwiFQkQiEYaGhkwnOdLX10dJSQlVVVU5/x65e/cunZ2d3Lhxgzt37jA4OGg6\nyTHLsohGoxw8eJCuri7mzp2bs/8RPj09TW9vLzt37jSd4sr379/p6enhyZMn9Pf3E4/Hc+4+moqK\nCo4ePUpzczPHjh1j9erVeL0z3yqVsaHrdkez+m9YlkVrayv79u2jvr7edM6sFRYWsmnTJvr7+02n\nOPL69Wt6e3vZvn07bW1tvHz5kvb2dtNZrpSWlgJQXFxMQ0MDb9++NVzkXFlZGWVlZaxbtw6AxsZG\notGo4Sp3QqEQa9eupbi42HSKK+FwmPLychYuXEheXh4NDQ28efPGdJZjwWCQzs5Obt++zYIFC1i+\nfPmMj8/Y0P2ddjTn+hUJ/LwLe+XKlRw6dMh0imujo6NMTEwAkEwmef78OStWrDBc5czp06fp6+uj\np6eH69evs3nzZq5du2Y6y7FEIsHU1BQA8XicZ8+eUVlZabjKuZKSEpYsWcLHjx8BePHiRc6utX30\n6BF+v990hmtLly4lEomQSqUQkZw9F6OjowAMDw/T3d39j+ckYz8Z+l12NP95NTI2NkZtbS0tLS1/\n3XSRK169esXDhw9ZtWoVgUAAj8fDqVOnqKmpMZ3myNevXzl37hy2bWPbNrt27WLbtm2ms/6Xvn37\nxokTJ/B4PKTTafbs2cPWrVtNZ7ly8eJFzpw5g2VZlJeX09HRYTrJsWQySTgc5sqVK6ZTXKuurqax\nsZFAIIDX62XNmjUcOHDAdJZjLS0tjI+P4/V6uXz5MkVFM/9jku5eVkoppbJEN1IppZRSWaJDVyml\nlMoSHbpKKaVUlujQVUoppbJEh65SSimVJTp0lVJKqSzRoauUUkpliQ5dpZRSKkv+AO2e4yf8wTuC\nAAAAAElFTkSuQmCC\n", + "text/plain": [ + "\u003cmatplotlib.figure.Figure at 0xc1dc310\u003e" + ] + }, + "metadata": { + "tags": [] + }, + "output_type": "display_data" + } + ], + "source": [ + "# Train our variables.\n", + "\n", + "# numpy is used for its asscalar() function.\n", + "import numpy as np\n", + "\n", + "num_training_steps = 10\n", + "\n", + "def train_model(inputs, labels, wb, optimizer, num_training_steps):\n", + " loss_at_step = []\n", + " w_at_step = []\n", + " b_at_step = []\n", + " for step_num in range(num_training_steps):\n", + " loss, gradients_and_variables = value_and_gradients_fn(inputs, labels, wb)\n", + " loss_at_step.append(np.asscalar(loss.numpy()))\n", + " \n", + " optimizer.apply_gradients(gradients_and_variables)\n", + " w, b = wb.variables\n", + " w_at_step.append(np.asscalar(w.read_value().numpy()))\n", + " b_at_step.append(np.asscalar(b.read_value().numpy()))\n", + "\n", + " print(w_at_step)\n", + " t = range(0, num_training_steps)\n", + " plt.plot(t, loss_at_step, 'k',\n", + " t, w_at_step, 'r',\n", + " t, [true_w] * num_training_steps, 'r--',\n", + " t, b_at_step, 'b',\n", + " t, [true_b] * num_training_steps, 'b--')\n", + " plt.legend(['loss', 'w estimate', 'w true', 'b estimate', 'b true'])\n", + " plt.show()\n", + "\n", + "train_model(inputs, labels, wb, optimizer, num_training_steps)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "UNurY9VJ-hpH" + }, + "source": [ + "## Other Ways to Compute Gradients\n", + "\n", + "Using our loss function as an example (`calculate_linear_model_loss()`), there are several other ways we could compute gradients:\n", + "\n", + "1. `tfe.implicit_gradients()`\n", + "1. `tfe.gradients_function()`\n", + "1. `tfe.implicit_value_and_gradients()`\n", + "1. `tfe.value_and_gradients_function()`\n", + "\n", + "Each of these functions does the following:\n", + "* Wraps a function.\n", + "* Returns a function with the same input signature as the wrapped function.\n", + "\n", + "They differ only in what information they return.\n", + "\n", + "### Gradients-only functions\n", + "\n", + "The following two functions return a function that returns only the variables' gradients:\n", + "\n", + "1. `tfe.gradients_function()`: Returns the partial derivatives of the function `f()` with respect to the parameters of `f()`.\n", + "1. `tfe.implicit_gradients()`: Returns the partial derivatives of the function `f()` with respect to the trainable parameters (`tf.Variable`) used by `f()`.\n", + "\n", + "In our example above, the `tf.layers.Dense` object encapsulates the trainable parameters.\n", + "\n", + "### Value and gradients functions\n", + "\n", + "The following two functions are identical to their counterparts above, except that they also return the value of the wrapped function.\n", + "\n", + "1. `tfe.implicit_value_and_gradients()`\n", + "1. `tfe.value_and_gradients_function()`\n", + "\n", + "### Gradient demos\n", + "\n", + "In the demos below, we show examples for the `implicit_*` functions, since our existing loss function works seamlessly with these versions. (The other versions require that your parameters are tensors and tensors only; in our example, we're using a `Dense` layer.)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 85, + "output_extras": [ + { + "item_id": 1 + } + ] + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 100, + "status": "ok", + "timestamp": 1505502831671, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "aEoCftnfAIH5", + "outputId": "72f1c1dc-a574-463f-f860-c4e5f48fcdaa" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[(\u003ctf.Tensor: id=673, shape=(1, 1), dtype=float32, numpy=array([[-0.26846504]], dtype=float32)\u003e,\n", + " \u003ctf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32\u003e),\n", + " (\u003ctf.Tensor: id=671, shape=(1,), dtype=float32, numpy=array([-0.32890949], dtype=float32)\u003e,\n", + " \u003ctf.Variable 'dense/bias:0' shape=(1,) dtype=float32\u003e)]" + ] + }, + "execution_count": 13, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# tfe.implicit_gradients() demo\n", + "gradients_fn = tfe.implicit_gradients(loss_fn)\n", + "\n", + "# Returns only gradients and variables:\n", + "gradients_fn(inputs, labels, wb)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 102, + "output_extras": [ + { + "item_id": 1 + } + ] + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 88, + "status": "ok", + "timestamp": 1505502831785, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 240 + }, + "id": "bbgCUdCzAVhH", + "outputId": "152aa9b6-9e42-4b7e-848a-9423c0b1929c" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(\u003ctf.Tensor: id=688, shape=(), dtype=float32, numpy=1.0623235\u003e,\n", + " [(\u003ctf.Tensor: id=720, shape=(1, 1), dtype=float32, numpy=array([[-0.26846504]], dtype=float32)\u003e,\n", + " \u003ctf.Variable 'dense/kernel:0' shape=(1, 1) dtype=float32\u003e),\n", + " (\u003ctf.Tensor: id=718, shape=(1,), dtype=float32, numpy=array([-0.32890949], dtype=float32)\u003e,\n", + " \u003ctf.Variable 'dense/bias:0' shape=(1,) dtype=float32\u003e)])" + ] + }, + "execution_count": 14, + "metadata": { + "tags": [] + }, + "output_type": "execute_result" + } + ], + "source": [ + "# tfe.implicit_value_and_gradients() demo\n", + "value_gradients_fn = tfe.implicit_value_and_gradients(loss_fn)\n", + "\n", + "# Returns only gradients:\n", + "value_gradients_fn(inputs, labels, wb)" + ] + } + ], + "metadata": { + "colab": { + "default_view": {}, + "last_runtime": { + "build_target": "", + "kind": "local" + }, + "name": "Eager Execution Tutorial: Working with Gradients", + "provenance": [], + "version": "0.3.2", + "views": {} + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb b/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..ebcc7027c1d34c47a339a49ede1d80e58ad43780 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/notebooks/3_datasets.ipynb @@ -0,0 +1,218 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "U9i2Dsh-ziXr" + }, + "source": [ + "# Eager Execution Tutorial: Importing Data\n", + "\n", + "This notebook demonstrates the use of the [`tf.contrib.data.Dataset` API](https://www.tensorflow.org/programmers_guide/datasets) to build pipelines to feed data to your program. It covers:\n", + "\n", + "* Creating a `Dataset`.\n", + "* Iteration over a `Dataset` with eager execution enabled.\n", + "\n", + "We recommend using the `Dataset`s API for building performant, complex input pipelines from simple, re-usable pieces that will feed your model's training or evaluation loops.\n", + "\n", + "If you're familiar with TensorFlow graphs, the API for constructing the `Dataset` object remains exactly the same when eager execution is enabled, but the process of iterating over elements of the dataset is slightly different. You will use a Pythonic `Iterator()` class instead of using `make_one_shot_iterator()` and `get_next()`. As a result, the discussion on iterators in the [Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is not relevant when eager execution is enabled." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "z1JcS5iBXMRO" + }, + "source": [ + "# Setup: Enable eager execution\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "RlIWhyeLoYnG" + }, + "outputs": [], + "source": [ + "# Import TensorFlow.\n", + "import tensorflow as tf\n", + "\n", + "# Import TensorFlow eager execution support (subject to future changes).\n", + "import tensorflow.contrib.eager as tfe\n", + "\n", + "# Enable eager execution\n", + "tfe.enable_eager_execution()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "H9UySOPLXdaw" + }, + "source": [ + "# Step 1: Create a source `Dataset`\n", + "\n", + "Create a _source_ dataset using one of the factory functions like [`Dataset.from_tensors`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#from_tensors), [`Dataset.from_tensor_slices`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#from_tensor_slices) or using objects that read from files like [`TextLineDataset`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/TextLineDataset) or [`TFRecordDataset`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/TFRecordDataset). See the [Programmer's Guide](https://www.google.com/url?sa=D\u0026q=https%3A%2F%2Fwww.tensorflow.org%2Fprogrammers_guide%2Fdatasets%23reading_input_data) for more information." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "WPTUfGq6kJ5w" + }, + "outputs": [], + "source": [ + "ds_tensors = tf.contrib.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6])\n", + "\n", + "# Create a CSV file\n", + "import tempfile\n", + "_, filename = tempfile.mkstemp()\n", + "with open(filename, 'w') as f:\n", + " f.write(\"\"\"Line 1\n", + "Line 2\n", + "Line 3\n", + " \"\"\")\n", + "ds_file = tf.contrib.data.TextLineDataset(filename)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "twBfWd5xyu_d" + }, + "source": [ + "# Step 2: Apply transformations\n", + "\n", + "Use the transformations functions like [`map`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#map), [`batch`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#batch), [`shuffle`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset#shuffle) etc. to apply transformations to the records of the dataset. See the [API documentation for `tf.contrib.data.Dataset`](https://www.tensorflow.org/api_docs/python/tf/contrib/data/Dataset) for details." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "cellView": "code", + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + } + }, + "colab_type": "code", + "id": "ngUe237Wt48W" + }, + "outputs": [], + "source": [ + "ds_tensors = ds_tensors.map(tf.square).shuffle(2).batch(2)\n", + "ds_file = ds_file.batch(2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "IDY4WsYRhP81" + }, + "source": [ + "# Step 3: Iterate\n", + "\n", + "Use `tfe.Iterator` on the `Dataset` object to get a Python iterator over the contents of the dataset.\n", + "\n", + "If you're familiar with the use of `Dataset`s in TensorFlow graphs, note that this process of iteration is different. Here there are no calls to `Dataset.make_one_shot_iterator()` and no `get_next()` calls." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "colab": { + "autoexec": { + "startup": false, + "wait_interval": 0 + }, + "height": 153, + "output_extras": [ + { + "item_id": 1 + } + ] + }, + "colab_type": "code", + "executionInfo": { + "elapsed": 201, + "status": "ok", + "timestamp": 1505952405928, + "user": { + "displayName": "", + "photoUrl": "", + "userId": "" + }, + "user_tz": 420 + }, + "id": "lCUWzso6mbqR", + "outputId": "ec027d30-96c6-4ea4-9ee1-ef74ec1ae29a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Elements of ds_tensors:\n", + "tf.Tensor([4 9], shape=(2,), dtype=int32)\n", + "tf.Tensor([16 25], shape=(2,), dtype=int32)\n", + "tf.Tensor([36 1], shape=(2,), dtype=int32)\n", + "\n", + "Elements in ds_file:\n", + "tf.Tensor(['Line 1' 'Line 2'], shape=(2,), dtype=string)\n", + "tf.Tensor(['Line 3' ' '], shape=(2,), dtype=string)\n" + ] + } + ], + "source": [ + "print('Elements of ds_tensors:')\n", + "for x in tfe.Iterator(ds_tensors):\n", + " print(x)\n", + "\n", + "print('\\nElements in ds_file:')\n", + "for x in tfe.Iterator(ds_file):\n", + " print(x)" + ] + } + ], + "metadata": { + "colab": { + "default_view": {}, + "last_runtime": { + "build_target": "", + "kind": "local" + }, + "name": "Eager Execution Tutorial: Importing Data", + "provenance": [], + "version": "0.3.2", + "views": {} + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/tensorflow/contrib/eager/python/examples/resnet50/BUILD b/tensorflow/contrib/eager/python/examples/resnet50/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..536cad998d94e45187d30fce3be0d7a57178e0c1 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/resnet50/BUILD @@ -0,0 +1,44 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +py_library( + name = "resnet50", + srcs = ["resnet50.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + ], +) + +cuda_py_test( + name = "resnet50_test", + size = "large", + srcs = ["resnet50_test.py"], + additional_deps = [ + ":resnet50", + "//tensorflow/contrib/summary:summary_test_util", + "//tensorflow/contrib/eager/python:tfe", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "resnet50_graph_test", + size = "large", + srcs = ["resnet50_graph_test.py"], + additional_deps = [ + ":resnet50", + "//tensorflow/contrib/summary:summary_test_util", + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], + tags = [ + "noasan", + "nomsan", + "notsan", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/resnet50/README.md b/tensorflow/contrib/eager/python/examples/resnet50/README.md new file mode 100644 index 0000000000000000000000000000000000000000..db023e6c976c8eda09ef0dee7eecb144678773c4 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/resnet50/README.md @@ -0,0 +1,45 @@ +Image classification using the ResNet50 model described in +[Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385). + +Contents: + +- `resnet50.py`: Model definition +- `resnet50_test.py`: Sanity unittests and benchmarks for using the model with + eager execution enabled. +- `resnet50_graph_test.py`: Sanity unittests and benchmarks when using the same + model code to construct a TensorFlow graph. + +# Benchmarks + +Using a synthetic data, run: + +``` +# Using eager execution +python resnet50_test.py --benchmarks=. + +# Using graph execution +python resnet50_graph_test.py --benchmarks=. +``` + +The above uses the model definition included with the TensorFlow pip +package. To build (and run benchmarks) from source: + +``` +# Using eager execution +bazel run -c opt --config=cuda :resnet50_test -- --benchmarks=. + +# Using graph execution +bazel run -c opt --config=cuda :resnet50_graph_test -- --benchmarks=. +``` + +(Or remove the `--config=cuda` flag for running on CPU instead of GPU). + +On October 31, 2017, the benchmarks demostrated comparable performance +for eager and graph execution of this particular model when using +a single NVIDIA Titan X (Pascal) GPU on a host with an +Intel Xeon E5-1650 CPU @ 3.50GHz and a batch size of 32. + +| Benchmark name | batch size | images/second | +| --------------------------------------- | ------------- | ------------- | +| eager_train_gpu_batch_32_channels_first | 32 | 171 | +| graph_train_gpu_batch_32_channels_first | 32 | 172 | diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py new file mode 100644 index 0000000000000000000000000000000000000000..b302a87e0e8a61d2456db1eba847f31bd70f552e --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50.py @@ -0,0 +1,324 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""ResNet50 model definition compatible with TensorFlow's eager execution. + +Reference [Deep Residual Learning for Image +Recognition](https://arxiv.org/abs/1512.03385) + +Adapted from tf.keras.applications.ResNet50. A notable difference is that the +model here outputs logits while the Keras model outputs probability. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +import tensorflow as tf +import tensorflow.contrib.eager as tfe + + +class _IdentityBlock(tfe.Network): + """_IdentityBlock is the block that has no conv layer at shortcut. + + Args: + kernel_size: the kernel size of middle conv layer at main path + filters: list of integers, the filters of 3 conv layer at main path + stage: integer, current stage label, used for generating layer names + block: 'a','b'..., current block label, used for generating layer names + data_format: data_format for the input ('channels_first' or + 'channels_last'). + """ + + def __init__(self, kernel_size, filters, stage, block, data_format): + super(_IdentityBlock, self).__init__(name='') + filters1, filters2, filters3 = filters + + conv_name_base = 'res' + str(stage) + block + '_branch' + bn_name_base = 'bn' + str(stage) + block + '_branch' + bn_axis = 1 if data_format == 'channels_first' else 3 + + self.conv2a = self.track_layer( + tf.layers.Conv2D( + filters1, (1, 1), + name=conv_name_base + '2a', + data_format=data_format)) + self.bn2a = self.track_layer( + tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')) + + self.conv2b = self.track_layer( + tf.layers.Conv2D( + filters2, + kernel_size, + padding='same', + data_format=data_format, + name=conv_name_base + '2b')) + self.bn2b = self.track_layer( + tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')) + + self.conv2c = self.track_layer( + tf.layers.Conv2D( + filters3, (1, 1), + name=conv_name_base + '2c', + data_format=data_format)) + self.bn2c = self.track_layer( + tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')) + + def call(self, input_tensor, training=False): + x = self.conv2a(input_tensor) + x = self.bn2a(x, training=training) + x = tf.nn.relu(x) + + x = self.conv2b(x) + x = self.bn2b(x, training=training) + x = tf.nn.relu(x) + + x = self.conv2c(x) + x = self.bn2c(x, training=training) + + x += input_tensor + return tf.nn.relu(x) + + +class _ConvBlock(tfe.Network): + """_ConvBlock is the block that has a conv layer at shortcut. + + Args: + kernel_size: the kernel size of middle conv layer at main path + filters: list of integers, the filterss of 3 conv layer at main path + stage: integer, current stage label, used for generating layer names + block: 'a','b'..., current block label, used for generating layer names + data_format: data_format for the input ('channels_first' or + 'channels_last'). + strides: strides for the convolution. Note that from stage 3, the first + conv layer at main path is with strides=(2,2), and the shortcut should + have strides=(2,2) as well. + """ + + def __init__(self, + kernel_size, + filters, + stage, + block, + data_format, + strides=(2, 2)): + super(_ConvBlock, self).__init__(name='') + filters1, filters2, filters3 = filters + + conv_name_base = 'res' + str(stage) + block + '_branch' + bn_name_base = 'bn' + str(stage) + block + '_branch' + bn_axis = 1 if data_format == 'channels_first' else 3 + + self.conv2a = self.track_layer( + tf.layers.Conv2D( + filters1, (1, 1), + strides=strides, + name=conv_name_base + '2a', + data_format=data_format)) + self.bn2a = self.track_layer( + tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')) + + self.conv2b = self.track_layer( + tf.layers.Conv2D( + filters2, + kernel_size, + padding='same', + name=conv_name_base + '2b', + data_format=data_format)) + self.bn2b = self.track_layer( + tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')) + + self.conv2c = self.track_layer( + tf.layers.Conv2D( + filters3, (1, 1), + name=conv_name_base + '2c', + data_format=data_format)) + self.bn2c = self.track_layer( + tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')) + + self.conv_shortcut = self.track_layer( + tf.layers.Conv2D( + filters3, (1, 1), + strides=strides, + name=conv_name_base + '1', + data_format=data_format)) + self.bn_shortcut = self.track_layer( + tf.layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '1')) + + def call(self, input_tensor, training=False): + x = self.conv2a(input_tensor) + x = self.bn2a(x, training=training) + x = tf.nn.relu(x) + + x = self.conv2b(x) + x = self.bn2b(x, training=training) + x = tf.nn.relu(x) + + x = self.conv2c(x) + x = self.bn2c(x, training=training) + + shortcut = self.conv_shortcut(input_tensor) + shortcut = self.bn_shortcut(shortcut, training=training) + + x += shortcut + return tf.nn.relu(x) + + +class ResNet50(tfe.Network): + """Instantiates the ResNet50 architecture. + + Args: + data_format: format for the image. Either 'channels_first' or + 'channels_last'. 'channels_first' is typically faster on GPUs while + 'channels_last' is typically faster on CPUs. See + https://www.tensorflow.org/performance/performance_guide#data_formats + name: Prefix applied to names of variables created in the model. + trainable: Is the model trainable? If true, performs backward + and optimization after call() method. + include_top: whether to include the fully-connected layer at the top of the + network. + pooling: Optional pooling mode for feature extraction when `include_top` + is `False`. + - `None` means that the output of the model will be the 4D tensor + output of the last convolutional layer. + - `avg` means that global average pooling will be applied to the output of + the last convolutional layer, and thus the output of the model will be + a 2D tensor. + - `max` means that global max pooling will be applied. + classes: optional number of classes to classify images into, only to be + specified if `include_top` is True. + + Raises: + ValueError: in case of invalid argument for data_format. + """ + + def __init__(self, + data_format, + name=None, + trainable=True, + include_top=True, + pooling=None, + classes=1000): + super(ResNet50, self).__init__(name='') + + valid_channel_values = ('channels_first', 'channels_last') + if data_format not in valid_channel_values: + raise ValueError('Unknown data_format: %s. Valid values: %s' % + (data_format, valid_channel_values)) + self.include_top = include_top + + def conv_block(filters, stage, block, strides=(2, 2)): + l = _ConvBlock( + 3, + filters, + stage=stage, + block=block, + data_format=data_format, + strides=strides) + return self.track_layer(l) + + def id_block(filters, stage, block): + l = _IdentityBlock( + 3, filters, stage=stage, block=block, data_format=data_format) + return self.track_layer(l) + + self.conv1 = self.track_layer( + tf.layers.Conv2D( + 64, (7, 7), + strides=(2, 2), + data_format=data_format, + padding='same', + name='conv1')) + bn_axis = 1 if data_format == 'channels_first' else 3 + self.bn_conv1 = self.track_layer( + tf.layers.BatchNormalization(axis=bn_axis, name='bn_conv1')) + self.max_pool = self.track_layer( + tf.layers.MaxPooling2D((3, 3), strides=(2, 2), data_format=data_format)) + + self.l2a = conv_block([64, 64, 256], stage=2, block='a', strides=(1, 1)) + self.l2b = id_block([64, 64, 256], stage=2, block='b') + self.l2c = id_block([64, 64, 256], stage=2, block='c') + + self.l3a = conv_block([128, 128, 512], stage=3, block='a') + self.l3b = id_block([128, 128, 512], stage=3, block='b') + self.l3c = id_block([128, 128, 512], stage=3, block='c') + self.l3d = id_block([128, 128, 512], stage=3, block='d') + + self.l4a = conv_block([256, 256, 1024], stage=4, block='a') + self.l4b = id_block([256, 256, 1024], stage=4, block='b') + self.l4c = id_block([256, 256, 1024], stage=4, block='c') + self.l4d = id_block([256, 256, 1024], stage=4, block='d') + self.l4e = id_block([256, 256, 1024], stage=4, block='e') + self.l4f = id_block([256, 256, 1024], stage=4, block='f') + + self.l5a = conv_block([512, 512, 2048], stage=5, block='a') + self.l5b = id_block([512, 512, 2048], stage=5, block='b') + self.l5c = id_block([512, 512, 2048], stage=5, block='c') + + self.avg_pool = self.track_layer( + tf.layers.AveragePooling2D( + (7, 7), strides=(7, 7), data_format=data_format)) + + if self.include_top: + self.fc1000 = self.track_layer( + tf.layers.Dense(classes, name='fc1000')) + else: + reduction_indices = [1, 2] if data_format == 'channels_last' else [2, 3] + reduction_indices = tf.constant(reduction_indices) + if pooling == 'avg': + self.global_pooling = functools.partial( + tf.reduce_mean, + reduction_indices=reduction_indices, + keep_dims=False) + elif pooling == 'max': + self.global_pooling = functools.partial( + tf.reduce_max, reduction_indices=reduction_indices, keep_dims=False) + else: + self.global_pooling = None + + def call(self, input_tensor, training=False): + x = self.conv1(input_tensor) + x = self.bn_conv1(x, training=training) + x = tf.nn.relu(x) + x = self.max_pool(x) + + x = self.l2a(x, training=training) + x = self.l2b(x, training=training) + x = self.l2c(x, training=training) + + x = self.l3a(x, training=training) + x = self.l3b(x, training=training) + x = self.l3c(x, training=training) + x = self.l3d(x, training=training) + + x = self.l4a(x, training=training) + x = self.l4b(x, training=training) + x = self.l4c(x, training=training) + x = self.l4d(x, training=training) + x = self.l4e(x, training=training) + x = self.l4f(x, training=training) + + x = self.l5a(x, training=training) + x = self.l5b(x, training=training) + x = self.l5c(x, training=training) + + x = self.avg_pool(x) + + if self.include_top: + return self.fc1000(tf.layers.flatten(x)) + elif self.global_pooling: + return self.global_pooling(x) + else: + return x diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..736a75332ff6403ea1b21387211df6b8fb6034f3 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -0,0 +1,163 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests and benchmarks for ResNet50 under graph execution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +import time + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.eager.python.examples.resnet50 import resnet50 +from tensorflow.contrib.summary import summary_test_util + + +def data_format(): + return 'channels_first' if tf.test.is_gpu_available() else 'channels_last' + + +def image_shape(batch_size): + if data_format() == 'channels_first': + return [batch_size, 3, 224, 224] + return [batch_size, 224, 224, 3] + + +def random_batch(batch_size): + images = np.random.rand(*image_shape(batch_size)).astype(np.float32) + num_classes = 1000 + labels = np.random.randint( + low=0, high=num_classes, size=[batch_size]).astype(np.int32) + one_hot = np.zeros((batch_size, num_classes)).astype(np.float32) + one_hot[np.arange(batch_size), labels] = 1. + return images, one_hot + + +class ResNet50GraphTest(tf.test.TestCase): + + def testApply(self): + batch_size = 64 + with tf.Graph().as_default(): + images = tf.placeholder(tf.float32, image_shape(None)) + model = resnet50.ResNet50(data_format()) + predictions = model(images) + + init = tf.global_variables_initializer() + + with tf.Session() as sess: + sess.run(init) + np_images, _ = random_batch(batch_size) + out = sess.run(predictions, feed_dict={images: np_images}) + self.assertAllEqual([64, 1000], out.shape) + + def testTrainWithSummary(self): + with tf.Graph().as_default(): + images = tf.placeholder(tf.float32, image_shape(None), name='images') + labels = tf.placeholder(tf.float32, [None, 1000], name='labels') + + tf.train.get_or_create_global_step() + logdir = tempfile.mkdtemp() + with tf.contrib.summary.always_record_summaries(): + with tf.contrib.summary.create_summary_file_writer( + logdir, max_queue=0, + name='t0').as_default(): + model = resnet50.ResNet50(data_format()) + logits = model(images, training=True) + loss = tf.losses.softmax_cross_entropy( + logits=logits, onehot_labels=labels) + tf.contrib.summary.scalar(name='loss', tensor=loss) + optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) + train_op = optimizer.minimize(loss) + + init = tf.global_variables_initializer() + self.assertEqual(321, len(tf.global_variables())) + + batch_size = 32 + with tf.Session() as sess: + sess.run(init) + sess.run(tf.contrib.summary.summary_writer_initializer_op()) + np_images, np_labels = random_batch(batch_size) + sess.run([train_op, tf.contrib.summary.all_summary_ops()], + feed_dict={images: np_images, labels: np_labels}) + + events = summary_test_util.events_from_file(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].tag, 'loss') + + +class ResNet50Benchmarks(tf.test.Benchmark): + + def _report(self, label, start, num_iters, batch_size): + avg_time = (time.time() - start) / num_iters + dev = 'gpu' if tf.test.is_gpu_available() else 'cpu' + name = 'graph_%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format()) + extras = {'examples_per_sec': batch_size / avg_time} + self.report_benchmark( + iters=num_iters, wall_time=avg_time, name=name, extras=extras) + + def benchmark_graph_apply(self): + with tf.Graph().as_default(): + images = tf.placeholder(tf.float32, image_shape(None)) + model = resnet50.ResNet50(data_format()) + predictions = model(images) + + init = tf.global_variables_initializer() + + batch_size = 64 + with tf.Session() as sess: + sess.run(init) + np_images, _ = random_batch(batch_size) + num_burn, num_iters = (3, 30) + for _ in range(num_burn): + sess.run(predictions, feed_dict={images: np_images}) + start = time.time() + for _ in range(num_iters): + # Comparison with the eager execution benchmark in resnet50_test.py + # isn't entirely fair as the time here includes the cost of copying + # the feeds from CPU memory to GPU. + sess.run(predictions, feed_dict={images: np_images}) + self._report('apply', start, num_iters, batch_size) + + def benchmark_graph_train(self): + for batch_size in [16, 32, 64]: + with tf.Graph().as_default(): + np_images, np_labels = random_batch(batch_size) + dataset = tf.data.Dataset.from_tensors((np_images, np_labels)).repeat() + (images, labels) = dataset.make_one_shot_iterator().get_next() + + model = resnet50.ResNet50(data_format()) + logits = model(images, training=True) + loss = tf.losses.softmax_cross_entropy( + logits=logits, onehot_labels=labels) + optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) + train_op = optimizer.minimize(loss) + + init = tf.global_variables_initializer() + with tf.Session() as sess: + sess.run(init) + (num_burn, num_iters) = (5, 10) + for _ in range(num_burn): + sess.run(train_op) + start = time.time() + for _ in range(num_iters): + sess.run(train_op) + self._report('train', start, num_iters, batch_size) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d6389f2e385b3637b178d49fc56e8baf913eccaa --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py @@ -0,0 +1,234 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests and benchmarks for the ResNet50 model, executed eagerly.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gc +import tempfile +import time + +import tensorflow as tf + +import tensorflow.contrib.eager as tfe +from tensorflow.contrib.eager.python.examples.resnet50 import resnet50 +from tensorflow.contrib.summary import summary_test_util +from tensorflow.python.client import device_lib + + +def device_and_data_format(): + return ('/gpu:0', 'channels_first') if tfe.num_gpus() else ('/cpu:0', + 'channels_last') + + +def random_batch(batch_size): + _, data_format = device_and_data_format() + + shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3) + shape = (batch_size,) + shape + + num_classes = 1000 + images = tf.random_uniform(shape) + labels = tf.random_uniform( + [batch_size], minval=0, maxval=num_classes, dtype=tf.int32) + one_hot = tf.one_hot(labels, num_classes) + + return images, one_hot + + +def train_one_step(model, images, labels, optimizer): + + def model_loss(): + logits = model(images, training=True) + loss = tf.losses.softmax_cross_entropy( + logits=logits, onehot_labels=labels) + tf.contrib.summary.scalar(name='loss', tensor=loss) + return loss + + optimizer.minimize(model_loss) + + +class ResNet50Test(tf.test.TestCase): + + def test_apply(self): + device, data_format = device_and_data_format() + model = resnet50.ResNet50(data_format) + with tf.device(device): + images, _ = random_batch(2) + output = model(images) + self.assertEqual((2, 1000), output.shape) + + def test_apply_no_top(self): + device, data_format = device_and_data_format() + model = resnet50.ResNet50(data_format, include_top=False) + with tf.device(device): + images, _ = random_batch(2) + output = model(images) + output_shape = ((2, 2048, 1, 1) + if data_format == 'channels_first' else (2, 1, 1, 2048)) + self.assertEqual(output_shape, output.shape) + + def test_apply_with_pooling(self): + device, data_format = device_and_data_format() + model = resnet50.ResNet50(data_format, include_top=False, pooling='avg') + with tf.device(device): + images, _ = random_batch(2) + output = model(images) + self.assertEqual((2, 2048), output.shape) + + def test_train(self): + device, data_format = device_and_data_format() + model = resnet50.ResNet50(data_format) + tf.train.get_or_create_global_step() + logdir = tempfile.mkdtemp() + with tf.contrib.summary.create_summary_file_writer( + logdir, max_queue=0, + name='t0').as_default(), tf.contrib.summary.always_record_summaries(): + with tf.device(device): + optimizer = tf.train.GradientDescentOptimizer(0.1) + images, labels = random_batch(2) + train_one_step(model, images, labels, optimizer) + self.assertEqual(320, len(model.variables)) + events = summary_test_util.events_from_file(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].tag, 'loss') + + def test_no_garbage(self): + device, data_format = device_and_data_format() + model = resnet50.ResNet50(data_format) + optimizer = tf.train.GradientDescentOptimizer(0.1) + with tf.device(device): + images, labels = random_batch(2) + gc.disable() + # Warm up. Note that this first run does create significant amounts of + # garbage to be collected. The hope is that this is a build-only effect, + # and a subsequent training loop will create nothing which needs to be + # collected. + train_one_step(model, images, labels, optimizer) + gc.collect() + previous_gc_debug_flags = gc.get_debug() + gc.set_debug(gc.DEBUG_SAVEALL) + for _ in range(2): + # Run twice to ensure that garbage that is created on the first + # iteration is no longer accessible. + train_one_step(model, images, labels, optimizer) + gc.collect() + # There should be no garbage requiring collection. + self.assertEqual(0, len(gc.garbage)) + gc.set_debug(previous_gc_debug_flags) + gc.enable() + + +class MockIterator(object): + + def __init__(self, tensors): + self._tensors = [tf.identity(x) for x in tensors] + + def next(self): + return self._tensors + + +class ResNet50Benchmarks(tf.test.Benchmark): + + def _train_batch_sizes(self): + """Choose batch sizes based on GPU capability.""" + for device in device_lib.list_local_devices(): + if 'GPU:0' in device.name: + # Avoid OOM errors with larger batch sizes, which seem to cause errors + # later on even if caught. + # + # TODO(allenl): Base this on device memory; memory limit information + # during the test seems to exclude the amount TensorFlow has allocated, + # which isn't useful. + if 'K20' in device.physical_device_desc: + return (16,) + if 'P100' in device.physical_device_desc: + return (16, 32, 64) + return (16, 32) + + def _report(self, label, start, num_iters, device, batch_size, data_format): + avg_time = (time.time() - start) / num_iters + dev = 'cpu' if 'cpu' in device else 'gpu' + name = '%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format) + extras = {'examples_per_sec': batch_size / avg_time} + self.report_benchmark( + iters=num_iters, wall_time=avg_time, name=name, extras=extras) + + def _force_gpu_sync(self): + # If this function is called in the context of a GPU device + # (e.g., inside a 'with tf.device("/gpu:0")' block) + # then this will force a copy from CPU->GPU->CPU, which forces + # a sync. This is a roundabout way, yes. + tf.constant(1.).cpu() + + def benchmark_eager_apply(self): + device, data_format = device_and_data_format() + model = resnet50.ResNet50(data_format) + batch_size = 64 + num_burn = 5 + num_iters = 30 + with tf.device(device): + images, _ = random_batch(batch_size) + for _ in xrange(num_burn): + model(images).cpu() + gc.collect() + start = time.time() + for _ in xrange(num_iters): + model(images).cpu() + self._report('eager_apply', start, num_iters, device, batch_size, + data_format) + + def _benchmark_eager_train(self, label, make_iterator): + device, data_format = device_and_data_format() + for batch_size in self._train_batch_sizes(): + (images, labels) = random_batch(batch_size) + num_burn = 3 + num_iters = 10 + model = resnet50.ResNet50(data_format) + optimizer = tf.train.GradientDescentOptimizer(0.1) + + with tf.device(device): + iterator = make_iterator((images, labels)) + for _ in xrange(num_burn): + (images, labels) = iterator.next() + train_one_step(model, images, labels, optimizer) + self._force_gpu_sync() + gc.collect() + + start = time.time() + for _ in xrange(num_iters): + (images, labels) = iterator.next() + train_one_step(model, images, labels, optimizer) + self._force_gpu_sync() + self._report(label, start, num_iters, device, batch_size, data_format) + + def benchmark_eager_train(self): + self._benchmark_eager_train('eager_train', MockIterator) + + def benchmark_eager_train_datasets(self): + + def make_iterator(tensors): + with tf.device('/device:CPU:0'): + ds = tf.data.Dataset.from_tensors(tensors).repeat() + return tfe.Iterator(ds) + + self._benchmark_eager_train('eager_train_dataset', make_iterator) + + +if __name__ == '__main__': + tfe.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..b657d31f35bafd6624ac7e4d6a6f6b2db362649d --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/BUILD @@ -0,0 +1,26 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +py_binary( + name = "rnn_colorbot", + srcs = ["rnn_colorbot.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + "@six_archive//:six", + ], +) + +cuda_py_test( + name = "rnn_colorbot_test", + srcs = ["rnn_colorbot_test.py"], + additional_deps = [ + ":rnn_colorbot", + "//tensorflow/contrib/eager/python:tfe", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md b/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md new file mode 100644 index 0000000000000000000000000000000000000000..fabd7b3e206d3a1954893a2b75361146d4709d00 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/README.md @@ -0,0 +1,26 @@ +RNN Colorbot: An RNN that predicts colors using eager execution. + +To train and generate colors, run: + +``` +python rnn_colorbot.py +``` + +This example shows how to: + 1. read, process, (one-hot) encode, and pad text data via the + Datasets API; + 2. build a trainable model; + 3. implement a multi-layer RNN using Python control flow + constructs (e.g., a for loop); + 4. train a model using an iterative gradient-based method; and + 5. log training and evaluation loss for consumption by TensorBoard + (to view summaries, use: tensorboard --log_dir=/summaries). + +The data used in this example is licensed under the Creative Commons +Attribution-ShareAlike License and is available at + https://en.wikipedia.org/wiki/List_of_colors:_A-F + https://en.wikipedia.org/wiki/List_of_colors:_G-M + https://en.wikipedia.org/wiki/List_of_colors:_N-Z + +This example was adapted from + https://github.com/random-forests/tensorflow-workshop/tree/master/extras/colorbot diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py new file mode 100644 index 0000000000000000000000000000000000000000..609cbd28772c3ae8da70648ca5b1b264a8a255e2 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot.py @@ -0,0 +1,338 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""TensorFlow Eager Execution Example: RNN Colorbot. + +This example builds, trains, and evaluates a multi-layer RNN that can be +run with eager execution enabled. The RNN is trained to map color names to +their RGB values: it takes as input a one-hot encoded character sequence and +outputs a three-tuple (R, G, B) (scaled by 1/255). + +For example, say we'd like the RNN Colorbot to generate the RGB values for the +color white. To represent our query in a form that the Colorbot could +understand, we would create a sequence of five 256-long vectors encoding the +ASCII values of the characters in "white". The first vector in our sequence +would be 0 everywhere except for the ord("w")-th position, where it would be +1, the second vector would be 0 everywhere except for the +ord("h")-th position, where it would be 1, and similarly for the remaining three +vectors. We refer to such indicator vectors as "one-hot encodings" of +characters. After consuming these vectors, a well-trained Colorbot would output +the three tuple (1, 1, 1), since the RGB values for white are (255, 255, 255). +We are of course free to ask the colorbot to generate colors for any string we'd +like, such as "steel gray," "tensorflow orange," or "green apple," though +your mileage may vary as your queries increase in creativity. + +This example shows how to: + 1. read, process, (one-hot) encode, and pad text data via the + Datasets API; + 2. build a trainable model; + 3. implement a multi-layer RNN using Python control flow + constructs (e.g., a for loop); + 4. train a model using an iterative gradient-based method; and + +The data used in this example is licensed under the Creative Commons +Attribution-ShareAlike License and is available at + https://en.wikipedia.org/wiki/List_of_colors:_A-F + https://en.wikipedia.org/wiki/List_of_colors:_G-M + https://en.wikipedia.org/wiki/List_of_colors:_N-Z + +This example was adapted from + https://github.com/random-forests/tensorflow-workshop/tree/master/extras/colorbot +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import functools +import os +import sys +import time + +import six +import tensorflow as tf + +from tensorflow.contrib.eager.python import tfe +from tensorflow.python.eager import context + +try: + import matplotlib.pyplot as plt # pylint: disable=g-import-not-at-top + HAS_MATPLOTLIB = True +except ImportError: + HAS_MATPLOTLIB = False + + +def parse(line): + """Parse a line from the colors dataset.""" + + # Each line of the dataset is comma-separated and formatted as + # color_name, r, g, b + # so `items` is a list [color_name, r, g, b]. + items = tf.string_split([line], ",").values + rgb = tf.string_to_number(items[1:], out_type=tf.float32) / 255. + # Represent the color name as a one-hot encoded character sequence. + color_name = items[0] + chars = tf.one_hot(tf.decode_raw(color_name, tf.uint8), depth=256) + # The sequence length is needed by our RNN. + length = tf.cast(tf.shape(chars)[0], dtype=tf.int64) + return rgb, chars, length + + +def load_dataset(data_dir, url, batch_size): + """Loads the colors data at path into a PaddedDataset.""" + + # Downloads data at url into data_dir/basename(url). The dataset has a header + # row (color_name, r, g, b) followed by comma-separated lines. + path = tf.contrib.learn.datasets.base.maybe_download( + os.path.basename(url), data_dir, url) + + # This chain of commands loads our data by: + # 1. skipping the header; (.skip(1)) + # 2. parsing the subsequent lines; (.map(parse)) + # 3. shuffling the data; (.shuffle(...)) + # 3. grouping the data into padded batches (.padded_batch(...)). + dataset = tf.data.TextLineDataset(path).skip(1).map(parse).shuffle( + buffer_size=10000).padded_batch( + batch_size, padded_shapes=([None], [None, None], [])) + return dataset + + +# pylint: disable=not-callable +class RNNColorbot(tfe.Network): + """Multi-layer (LSTM) RNN that regresses on real-valued vector labels. + """ + + def __init__(self, rnn_cell_sizes, label_dimension, keep_prob): + """Constructs an RNNColorbot. + + Args: + rnn_cell_sizes: list of integers denoting the size of each LSTM cell in + the RNN; rnn_cell_sizes[i] is the size of the i-th layer cell + label_dimension: the length of the labels on which to regress + keep_prob: (1 - dropout probability); dropout is applied to the outputs of + each LSTM layer + """ + super(RNNColorbot, self).__init__(name="") + self.label_dimension = label_dimension + self.keep_prob = keep_prob + + # Note the calls to `track_layer` below; these calls register the layers as + # network components that house trainable variables. + self.cells = [ + self.track_layer(tf.nn.rnn_cell.BasicLSTMCell(size)) + for size in rnn_cell_sizes + ] + self.relu = self.track_layer( + tf.layers.Dense(label_dimension, activation=tf.nn.relu, name="relu")) + + def call(self, chars, sequence_length, training=False): + """Implements the RNN logic and prediction generation. + + Args: + chars: a Tensor of dimension [batch_size, time_steps, 256] holding a + batch of one-hot encoded color names + sequence_length: a Tensor of dimension [batch_size] holding the length + of each character sequence (i.e., color name) + training: whether the invocation is happening during training + + Returns: + A tensor of dimension [batch_size, label_dimension] that is produced by + passing chars through a multi-layer RNN and applying a ReLU to the final + hidden state. + """ + # Transpose the first and second dimensions so that chars is of shape + # [time_steps, batch_size, dimension]. + chars = tf.transpose(chars, [1, 0, 2]) + # The outer loop cycles through the layers of the RNN; the inner loop + # executes the time steps for a particular layer. + batch_size = int(chars.shape[1]) + for l in range(len(self.cells)): + cell = self.cells[l] + outputs = [] + state = cell.zero_state(batch_size, tf.float32) + # Unstack the inputs to obtain a list of batches, one for each time step. + chars = tf.unstack(chars, axis=0) + for ch in chars: + output, state = cell(ch, state) + outputs.append(output) + # The outputs of this layer are the inputs of the subsequent layer. + chars = tf.stack(outputs, axis=0) + if training: + chars = tf.nn.dropout(chars, self.keep_prob) + # Extract the correct output (i.e., hidden state) for each example. All the + # character sequences in this batch were padded to the same fixed length so + # that they could be easily fed through the above RNN loop. The + # `sequence_length` vector tells us the true lengths of the character + # sequences, letting us obtain for each sequence the hidden state that was + # generated by its non-padding characters. + batch_range = [i for i in range(batch_size)] + indices = tf.stack([sequence_length - 1, batch_range], axis=1) + hidden_states = tf.gather_nd(chars, indices) + return self.relu(hidden_states) + + +def loss(labels, predictions): + """Computes mean squared loss.""" + return tf.reduce_mean(tf.square(predictions - labels)) + + +def test(model, eval_data): + """Computes the average loss on eval_data, which should be a Dataset.""" + avg_loss = tfe.metrics.Mean("loss") + for (labels, chars, sequence_length) in tfe.Iterator(eval_data): + predictions = model(chars, sequence_length, training=False) + avg_loss(loss(labels, predictions)) + print("eval/loss: %.6f\n" % avg_loss.result()) + with tf.contrib.summary.always_record_summaries(): + tf.contrib.summary.scalar("loss", avg_loss.result()) + + +def train_one_epoch(model, optimizer, train_data, log_interval=10): + """Trains model on train_data using optimizer.""" + + tf.train.get_or_create_global_step() + + def model_loss(labels, chars, sequence_length): + predictions = model(chars, sequence_length, training=True) + loss_value = loss(labels, predictions) + tf.contrib.summary.scalar("loss", loss_value) + return loss_value + + for (batch, (labels, chars, sequence_length)) in enumerate( + tfe.Iterator(train_data)): + with tf.contrib.summary.record_summaries_every_n_global_steps(log_interval): + batch_model_loss = functools.partial(model_loss, labels, chars, + sequence_length) + optimizer.minimize( + batch_model_loss, global_step=tf.train.get_global_step()) + if log_interval and batch % log_interval == 0: + print("train/batch #%d\tloss: %.6f" % (batch, batch_model_loss())) + + +SOURCE_TRAIN_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/train.csv" +SOURCE_TEST_URL = "https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/extras/colorbot/data/test.csv" + + +def main(_): + data_dir = os.path.join(FLAGS.dir, "data") + train_data = load_dataset( + data_dir=data_dir, url=SOURCE_TRAIN_URL, batch_size=FLAGS.batch_size) + eval_data = load_dataset( + data_dir=data_dir, url=SOURCE_TEST_URL, batch_size=FLAGS.batch_size) + + model = RNNColorbot( + rnn_cell_sizes=FLAGS.rnn_cell_sizes, + label_dimension=3, + keep_prob=FLAGS.keep_probability) + optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate) + + if FLAGS.no_gpu or tfe.num_gpus() <= 0: + print(tfe.num_gpus()) + device = "/cpu:0" + else: + device = "/gpu:0" + print("Using device %s." % device) + + log_dir = os.path.join(FLAGS.dir, "summaries") + tf.gfile.MakeDirs(log_dir) + train_summary_writer = tf.contrib.summary.create_summary_file_writer( + os.path.join(log_dir, "train"), flush_millis=10000) + test_summary_writer = tf.contrib.summary.create_summary_file_writer( + os.path.join(log_dir, "eval"), flush_millis=10000, name="eval") + + with tf.device(device): + for epoch in range(FLAGS.num_epochs): + start = time.time() + with train_summary_writer.as_default(): + train_one_epoch(model, optimizer, train_data, FLAGS.log_interval) + end = time.time() + print("train/time for epoch #%d: %.2f" % (epoch, end - start)) + with test_summary_writer.as_default(): + test(model, eval_data) + + print("Colorbot is ready to generate colors!") + while True: + try: + color_name = six.moves.input( + "Give me a color name (or press enter to exit): ") + except EOFError: + return + + if not color_name: + return + + _, chars, length = parse(color_name) + with tf.device(device): + (chars, length) = (tf.identity(chars), tf.identity(length)) + chars = tf.expand_dims(chars, 0) + length = tf.expand_dims(length, 0) + preds = tf.unstack(model(chars, length, training=False)[0]) + + # Predictions cannot be negative, as they are generated by a ReLU layer; + # they may, however, be greater than 1. + clipped_preds = tuple(min(float(p), 1.0) for p in preds) + rgb = tuple(int(p * 255) for p in clipped_preds) + print("rgb:", rgb) + data = [[clipped_preds]] + if HAS_MATPLOTLIB: + plt.imshow(data) + plt.title(color_name) + plt.show() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--dir", + type=str, + default="/tmp/rnn_colorbot/", + help="Directory to download data files and save logs.") + parser.add_argument( + "--log_interval", + type=int, + default=10, + metavar="N", + help="Log training loss every log_interval batches.") + parser.add_argument( + "--num_epochs", type=int, default=20, help="Number of epochs to train.") + parser.add_argument( + "--rnn_cell_sizes", + type=int, + nargs="+", + default=[256, 128], + help="List of sizes for each layer of the RNN.") + parser.add_argument( + "--batch_size", + type=int, + default=64, + help="Batch size for training and eval.") + parser.add_argument( + "--keep_probability", + type=float, + default=0.5, + help="Keep probability for dropout between layers.") + parser.add_argument( + "--learning_rate", + type=float, + default=0.01, + help="Learning rate to be used during training.") + parser.add_argument( + "--no_gpu", + action="store_true", + default=False, + help="Disables GPU usage even if a GPU is available.") + + FLAGS, unparsed = parser.parse_known_args() + tfe.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py new file mode 100644 index 0000000000000000000000000000000000000000..75b342ba78bd5de5c2827296f6fba01ffa86d560 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/rnn_colorbot/rnn_colorbot_test.py @@ -0,0 +1,71 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib.eager.python import tfe +from tensorflow.contrib.eager.python.examples.rnn_colorbot import rnn_colorbot + + +LABEL_DIMENSION = 5 + + +def device(): + return "/device:GPU:0" if tfe.num_gpus() else "/device:CPU:0" + + +def random_dataset(): + batch_size = 64 + time_steps = 10 + alphabet = 50 + chars = tf.one_hot( + tf.random_uniform( + [batch_size, time_steps], minval=0, maxval=alphabet, dtype=tf.int32), + alphabet) + sequence_length = tf.constant( + [time_steps for _ in range(batch_size)], dtype=tf.int64) + labels = tf.random_normal([batch_size, LABEL_DIMENSION]) + return tf.data.Dataset.from_tensors((labels, chars, sequence_length)) + + +class RNNColorbotTest(tf.test.TestCase): + + def testTrainOneEpoch(self): + model = rnn_colorbot.RNNColorbot( + rnn_cell_sizes=[256, 128, 64], + label_dimension=LABEL_DIMENSION, + keep_prob=1.0) + optimizer = tf.train.AdamOptimizer(learning_rate=.01) + dataset = random_dataset() + with tf.device(device()): + rnn_colorbot.train_one_epoch(model, optimizer, dataset) + + def testTest(self): + model = rnn_colorbot.RNNColorbot( + rnn_cell_sizes=[256], + label_dimension=LABEL_DIMENSION, + keep_prob=1.0) + dataset = random_dataset() + with tf.device(device()): + rnn_colorbot.test(model, dataset) + + +if __name__ == "__main__": + tfe.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..db2587bf2cb548ae37e58597691e96ae2c2e8177 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/BUILD @@ -0,0 +1,35 @@ +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//tensorflow:internal"]) + +load("//tensorflow:tensorflow.bzl", "cuda_py_test") + +py_binary( + name = "rnn_ptb", + srcs = ["rnn_ptb.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/eager/python:tfe", + ], +) + +cuda_py_test( + name = "rnn_ptb_test", + srcs = ["rnn_ptb_test.py"], + additional_deps = [ + ":rnn_ptb", + "//tensorflow/contrib/eager/python:tfe", + "//tensorflow:tensorflow_py", + ], +) + +cuda_py_test( + name = "rnn_ptb_graph_test", + srcs = ["rnn_ptb_graph_test.py"], + additional_deps = [ + ":rnn_ptb", + "//third_party/py/numpy", + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md b/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md new file mode 100644 index 0000000000000000000000000000000000000000..743ebb68ee5bba5635899267cc4839828f7e4e2f --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/README.md @@ -0,0 +1,54 @@ +Recurrent Neural Network model. + +Implements a language modeling network described in +https://www.tensorflow.org/tutorials/recurrent +that is compatible with (and idiomatic for) eager execution. + +To run: + +- Download and extract the Penn Treebank dataset from + http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz + + ```sh + tar xvzf simple-examples.tgz -C /tmp + ``` + +- Run: `python rnn_ptb.py --data-dir=/tmp/simple-examples/data` + + +Benchmarks (using synthetic data): + +``` +# Using eager execution +python rnn_ptb_test.py --benchmarks=. + +# Using graph execution +python rnn_ptb_graph_test.py --benchmarks=. +``` + +The above uses the model definition included with the TensorFlow pip +package. To build (and run benchmarks) from source: + + +``` +# Using eager execution +bazel run -c opt --config=cuda :rnn_ptb_test -- --benchmarks=. + +# Using graph execution +bazel run -c opt --config=cuda :rnn_ptb_graph_test -- --benchmarks=. +``` + +(Or remove the `--config=cuda` flag for running on CPU instead of GPU). + +On October 31, 2017, the benchmarks demostrated slightly better performance +(3-6%) for graph execution over eager execution for this particular model when +using a single NVIDIA Titan X (Pascal) GPU on a host with an Intel Xeon E5-1650 +CPU @ 3.50GHz and a batch size of 32. + +| Benchmark name | examples/second | +| ------------------------------------ | --------------- | +| eager_cudnn_train_large_gpu_batch_20 | 938 | +| graph_cudnn_train_large_gpu_batch_20 | 971 | +| eager_cudnn_train_small_gpu_batch_20 | 2433 | +| graph_cudnn_train_small_gpu_batch_20 | 2585 | + diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py new file mode 100644 index 0000000000000000000000000000000000000000..30bb3c8ad33d38453bd96a76c7770071e24bb034 --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb.py @@ -0,0 +1,359 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Penn Treebank RNN model definition compatible with eager execution. + +Model similar to +https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb + +Usage: python ./rnn_ptb.py --data-path= + +Penn Treebank (PTB) dataset from: +http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz +""" +import argparse +import os +import sys +import time + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.cudnn_rnn.python.layers import cudnn_rnn +from tensorflow.contrib.eager.python import tfe + + +class RNN(tfe.Network): + """A static RNN. + + Similar to tf.nn.static_rnn, implemented as a tf.layer.Layer. + """ + + def __init__(self, hidden_dim, num_layers, keep_ratio): + super(RNN, self).__init__() + self.keep_ratio = keep_ratio + for _ in range(num_layers): + self.track_layer(tf.nn.rnn_cell.BasicLSTMCell(num_units=hidden_dim)) + + def call(self, input_seq, training): + batch_size = int(input_seq.shape[1]) + for c in self.layers: + state = c.zero_state(batch_size, tf.float32) + outputs = [] + input_seq = tf.unstack(input_seq, num=int(input_seq.shape[0]), axis=0) + for inp in input_seq: + output, state = c(inp, state) + outputs.append(output) + + input_seq = tf.stack(outputs, axis=0) + if training: + input_seq = tf.nn.dropout(input_seq, self.keep_ratio) + return input_seq, None + + +class Embedding(tf.layers.Layer): + """An Embedding layer.""" + + def __init__(self, vocab_size, embedding_dim, **kwargs): + super(Embedding, self).__init__(**kwargs) + self.vocab_size = vocab_size + self.embedding_dim = embedding_dim + + def build(self, _): + self.embedding = self.add_variable( + "embedding_kernel", + shape=[self.vocab_size, self.embedding_dim], + dtype=tf.float32, + initializer=tf.random_uniform_initializer(-0.1, 0.1), + trainable=True) + + def call(self, x): + return tf.nn.embedding_lookup(self.embedding, x) + + +class PTBModel(tfe.Network): + """LSTM for word language modelling. + + Model described in: + (Zaremba, et. al.) Recurrent Neural Network Regularization + http://arxiv.org/abs/1409.2329 + + See also: + https://github.com/tensorflow/models/tree/master/tutorials/rnn/ptb + """ + + def __init__(self, + vocab_size, + embedding_dim, + hidden_dim, + num_layers, + dropout_ratio, + use_cudnn_rnn=True): + super(PTBModel, self).__init__() + + self.keep_ratio = 1 - dropout_ratio + self.use_cudnn_rnn = use_cudnn_rnn + self.embedding = self.track_layer(Embedding(vocab_size, embedding_dim)) + + if self.use_cudnn_rnn: + self.rnn = cudnn_rnn.CudnnLSTM( + num_layers, hidden_dim, dropout=dropout_ratio) + else: + self.rnn = RNN(hidden_dim, num_layers, self.keep_ratio) + self.track_layer(self.rnn) + + self.linear = self.track_layer( + tf.layers.Dense( + vocab_size, + kernel_initializer=tf.random_uniform_initializer(-0.1, 0.1))) + self._output_shape = [-1, embedding_dim] + + def call(self, input_seq, training): + """Run the forward pass of PTBModel. + + Args: + input_seq: [length, batch] shape int64 tensor. + training: Is this a training call. + Returns: + outputs tensors of inference. + """ + y = self.embedding(input_seq) + if training: + y = tf.nn.dropout(y, self.keep_ratio) + y, _ = self.rnn(y, training=training) + return self.linear(tf.reshape(y, self._output_shape)) + + +def clip_gradients(grads_and_vars, clip_ratio): + gradients, variables = zip(*grads_and_vars) + clipped, _ = tf.clip_by_global_norm(gradients, clip_ratio) + return zip(clipped, variables) + + +def loss_fn(model, inputs, targets, training): + labels = tf.reshape(targets, [-1]) + outputs = model(inputs, training) + return tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=labels, logits=outputs)) + + +def _divide_into_batches(data, batch_size): + """Convert a sequence to a batch of sequences.""" + nbatch = data.shape[0] // batch_size + data = data[:nbatch * batch_size] + data = data.reshape(batch_size, -1).transpose() + return data + + +def _get_batch(data, i, seq_len): + slen = min(seq_len, data.shape[0] - 1 - i) + inputs = data[i:i + slen, :] + target = data[i + 1:i + 1 + slen, :] + return tf.constant(inputs), tf.constant(target) + + +def evaluate(model, data): + """evaluate an epoch.""" + total_loss = 0.0 + total_batches = 0 + start = time.time() + for _, i in enumerate(range(0, data.shape[0] - 1, FLAGS.seq_len)): + inp, target = _get_batch(data, i, FLAGS.seq_len) + loss = loss_fn(model, inp, target, training=False) + total_loss += loss.numpy() + total_batches += 1 + time_in_ms = (time.time() - start) * 1000 + sys.stderr.write("eval loss %.2f (eval took %d ms)\n" % + (total_loss / total_batches, time_in_ms)) + return total_loss + + +def train(model, optimizer, train_data, sequence_length, clip_ratio): + """training an epoch.""" + + def model_loss(inputs, targets): + return loss_fn(model, inputs, targets, training=True) + + grads = tfe.implicit_gradients(model_loss) + + total_time = 0 + for batch, i in enumerate(range(0, train_data.shape[0] - 1, sequence_length)): + train_seq, train_target = _get_batch(train_data, i, sequence_length) + start = time.time() + optimizer.apply_gradients( + clip_gradients(grads(train_seq, train_target), clip_ratio)) + total_time += (time.time() - start) + if batch % 10 == 0: + time_in_ms = (total_time * 1000) / (batch + 1) + sys.stderr.write("batch %d: training loss %.2f, avg step time %d ms\n" % + (batch, model_loss(train_seq, train_target).numpy(), + time_in_ms)) + + +class Datasets(object): + """Processed form of the Penn Treebank dataset.""" + + def __init__(self, path): + """Load the Penn Treebank dataset. + + Args: + path: Path to the data/ directory of the dataset from from Tomas Mikolov's + webpage - http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz + """ + + self.word2idx = {} # string -> integer id + self.idx2word = [] # integer id -> word string + # Files represented as a list of integer ids (as opposed to list of string + # words). + self.train = self.tokenize(os.path.join(path, "ptb.train.txt")) + self.valid = self.tokenize(os.path.join(path, "ptb.valid.txt")) + + def vocab_size(self): + return len(self.idx2word) + + def add(self, word): + if word not in self.word2idx: + self.idx2word.append(word) + self.word2idx[word] = len(self.idx2word) - 1 + + def tokenize(self, path): + """Read text file in path and return a list of integer token ids.""" + tokens = 0 + with tf.gfile.Open(path, "r") as f: + for line in f: + words = line.split() + [""] + tokens += len(words) + for word in words: + self.add(word) + + # Tokenize file content + with tf.gfile.Open(path, "r") as f: + ids = np.zeros(tokens).astype(np.int64) + token = 0 + for line in f: + words = line.split() + [""] + for word in words: + ids[token] = self.word2idx[word] + token += 1 + + return ids + + +def small_model(use_cudnn_rnn): + """Returns a PTBModel with a 'small' configuration.""" + return PTBModel( + vocab_size=10000, + embedding_dim=200, + hidden_dim=200, + num_layers=2, + dropout_ratio=0., + use_cudnn_rnn=use_cudnn_rnn) + + +def large_model(use_cudnn_rnn): + """Returns a PTBModel with a 'large' configuration.""" + return PTBModel( + vocab_size=10000, + embedding_dim=650, + hidden_dim=650, + num_layers=2, + dropout_ratio=0.5, + use_cudnn_rnn=use_cudnn_rnn) + + +def test_model(use_cudnn_rnn): + """Returns a tiny PTBModel for unit tests.""" + return PTBModel( + vocab_size=100, + embedding_dim=20, + hidden_dim=20, + num_layers=2, + dropout_ratio=0., + use_cudnn_rnn=use_cudnn_rnn) + + +def main(_): + tfe.enable_eager_execution() + + if not FLAGS.data_path: + raise ValueError("Must specify --data-path") + corpus = Datasets(FLAGS.data_path) + train_data = _divide_into_batches(corpus.train, FLAGS.batch_size) + eval_data = _divide_into_batches(corpus.valid, 10) + + have_gpu = tfe.num_gpus() > 0 + use_cudnn_rnn = not FLAGS.no_use_cudnn_rnn and have_gpu + + with tfe.restore_variables_on_create( + tf.train.latest_checkpoint(FLAGS.logdir)): + with tf.device("/device:GPU:0" if have_gpu else None): + # Make learning_rate a Variable so it can be included in the checkpoint + # and we can resume training with the last saved learning_rate. + learning_rate = tfe.Variable(20.0, name="learning_rate") + sys.stderr.write("learning_rate=%f\n" % learning_rate.numpy()) + model = PTBModel(corpus.vocab_size(), FLAGS.embedding_dim, + FLAGS.hidden_dim, FLAGS.num_layers, FLAGS.dropout, + use_cudnn_rnn) + optimizer = tf.train.GradientDescentOptimizer(learning_rate) + + best_loss = None + for _ in range(FLAGS.epoch): + train(model, optimizer, train_data, FLAGS.seq_len, FLAGS.clip) + eval_loss = evaluate(model, eval_data) + if not best_loss or eval_loss < best_loss: + if FLAGS.logdir: + tfe.Saver(model.trainable_weights + [learning_rate]).save( + os.path.join(FLAGS.logdir, "ckpt")) + best_loss = eval_loss + else: + learning_rate.assign(learning_rate / 4.0) + sys.stderr.write("eval_loss did not reduce in this epoch, " + "changing learning rate to %f for the next epoch\n" % + learning_rate.numpy()) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-path", + type=str, + default="", + help="Data directory of the Penn Treebank dataset from " + "http://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz") + parser.add_argument( + "--logdir", type=str, default="", help="Directory for checkpoint.") + parser.add_argument( + "--epoch", type=int, default=20, help="Number of epoches.") + parser.add_argument("--batch-size", type=int, default=20, help="Batch size.") + parser.add_argument( + "--seq-len", type=int, default=35, help="Sequence length.") + parser.add_argument( + "--embedding-dim", type=int, default=200, help="Embedding dimension.") + parser.add_argument( + "--hidden-dim", type=int, default=200, help="Hidden layer dimension.") + parser.add_argument( + "--num-layers", type=int, default=2, help="Number of RNN layers.") + parser.add_argument( + "--dropout", type=float, default=0.2, help="Drop out ratio.") + parser.add_argument( + "--clip", type=float, default=0.25, help="Gradient clipping ratio.") + parser.add_argument( + "--no-use-cudnn-rnn", + action="store_true", + default=False, + help="Disable the fast CuDNN RNN (when no gpu)") + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..63b5c4c54d13e9c2448ec1f572ca1389f2443bef --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_graph_test.py @@ -0,0 +1,164 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for PTBModel used for graph construction.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gc +import time + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.eager.python.examples.rnn_ptb import rnn_ptb + + +class PTBTest(tf.test.TestCase): + + def testTrain(self): + batch_size = 20 + sequence_length = 35 + with tf.Graph().as_default(), tf.device(tf.test.gpu_device_name()): + inputs_ph = tf.placeholder(tf.int64, [sequence_length, batch_size], + "inputs") + labels_ph = tf.placeholder(tf.int64, [sequence_length, batch_size], + "labels") + + inputs = np.ones(inputs_ph.shape.as_list(), dtype=np.int64) + labels = np.ones(labels_ph.shape.as_list(), dtype=np.int64) + + model = rnn_ptb.test_model(tf.test.is_gpu_available()) + optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) + loss = rnn_ptb.loss_fn(model, inputs_ph, labels_ph, training=True) + grads = rnn_ptb.clip_gradients(optimizer.compute_gradients(loss), 0.25) + train_op = optimizer.apply_gradients(grads) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + sess.run(train_op, feed_dict={inputs_ph: inputs, labels_ph: labels}) + sess.run( + [train_op, loss], feed_dict={ + inputs_ph: inputs, + labels_ph: labels + }) + + +class PTBBenchmark(tf.test.Benchmark): + + BATCH_SIZE = 20 + SEQ_LEN = 35 + + def _report(self, label, start, num_iters, device, batch_size): + wall_time = (time.time() - start) / num_iters + dev = "cpu" if "cpu" in device.lower() else "gpu" + name = "%s_%s_batch_%d" % (label, dev, batch_size) + examples_per_sec = batch_size / wall_time + self.report_benchmark( + iters=num_iters, + wall_time=wall_time, + name=name, + extras={ + "examples_per_sec": examples_per_sec + }) + + def _benchmark_apply(self, label, model): + num_iters = 100 + num_warmup = 10 + dataset = tf.data.Dataset.from_tensors( + tf.ones( + [PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE], + dtype=tf.int64)).repeat(num_iters + num_warmup) + inputs = dataset.make_one_shot_iterator().get_next() + + with tf.device(tf.test.gpu_device_name()): + outputs = model(inputs, training=True) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + for _ in range(num_warmup): + sess.run(outputs) + gc.collect() + + start = time.time() + for _ in range(num_iters): + sess.run(outputs) + self._report(label, start, num_iters, + tf.test.gpu_device_name(), PTBBenchmark.BATCH_SIZE) + + def benchmark_apply_small(self): + self._benchmark_apply("graph_apply_small", rnn_ptb.small_model(False)) + + def benchmark_apply_large(self): + self._benchmark_apply("graph_apply_large", rnn_ptb.large_model(False)) + + def benchmark_cudnn_apply_small(self): + if not tf.test.is_gpu_available(): + return + self._benchmark_apply("graph_cudnn_apply_small", rnn_ptb.small_model(True)) + + def benchmark_cudnn_apply_large(self): + if not tf.test.is_gpu_available(): + return + self._benchmark_apply("graph_cudnn_apply_large", rnn_ptb.large_model(True)) + + def _benchmark_train(self, label, model): + num_iters = 100 + num_warmup = 10 + dataset = tf.data.Dataset.from_tensors( + tf.ones( + [PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE], + dtype=tf.int64)).repeat(num_iters + num_warmup) + # inputs and labels have the same shape + dataset = tf.data.Dataset.zip((dataset, dataset)) + (inputs, labels) = dataset.make_one_shot_iterator().get_next() + + with tf.device(tf.test.gpu_device_name()): + optimizer = tf.train.GradientDescentOptimizer(learning_rate=1.0) + loss = rnn_ptb.loss_fn(model, inputs, labels, training=True) + grads = rnn_ptb.clip_gradients(optimizer.compute_gradients(loss), 0.25) + train_op = optimizer.apply_gradients(grads) + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + for _ in range(num_warmup): + sess.run(train_op) + gc.collect() + start = time.time() + for _ in range(num_iters): + sess.run(train_op) + self._report(label, start, num_iters, + tf.test.gpu_device_name(), PTBBenchmark.BATCH_SIZE) + + def benchmark_train_small(self): + self._benchmark_train("graph_train_small", rnn_ptb.small_model(False)) + + def benchmark_train_large(self): + self._benchmark_train("graph_train_large", rnn_ptb.large_model(False)) + + def benchmark_cudnn_train_small(self): + if not tf.test.is_gpu_available(): + return + self._benchmark_train("graph_cudnn_train_small", rnn_ptb.small_model(True)) + + def benchmark_cudnn_train_large(self): + if not tf.test.is_gpu_available(): + return + self._benchmark_train("graph_cudnn_train_large", rnn_ptb.large_model(True)) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_test.py b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b279bc4a7c3510b6a59bc618b531141beebdfaab --- /dev/null +++ b/tensorflow/contrib/eager/python/examples/rnn_ptb/rnn_ptb_test.py @@ -0,0 +1,154 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for PTBModel with eager execution enabled.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gc +import time + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.eager.python import tfe +from tensorflow.contrib.eager.python.examples.rnn_ptb import rnn_ptb + + +def device(): + return "/device:GPU:0" if tfe.num_gpus() else "/device:CPU:0" + + +class PTBTest(tf.test.TestCase): + + def testTrain(self): + model = rnn_ptb.test_model(tfe.num_gpus() > 0) + sequence_length = 35 + data = np.ones([4 * sequence_length, 20], dtype=np.int64) + with tf.device(device()): + optimizer = tf.train.GradientDescentOptimizer(1.0) + # Train two epochs + rnn_ptb.train(model, optimizer, data, sequence_length, 0.25) + rnn_ptb.train(model, optimizer, data, sequence_length, 0.25) + + def testApply(self): + model = rnn_ptb.test_model(tfe.num_gpus() > 0) + with tf.device(device()): + model(tf.ones([35, 20], dtype=tf.int64), training=False) + + +def force_gpu_sync(): + if tfe.num_gpus(): + tf.constant(1).gpu().cpu() + + +class PTBBenchmark(tf.test.Benchmark): + + BATCH_SIZE = 20 + SEQ_LEN = 35 + + def _report(self, label, start, num_iters, dev, batch_size): + wall_time = (time.time() - start) / num_iters + dev = "cpu" if "cpu" in dev.lower() else "gpu" + name = "%s_%s_batch_%d" % (label, dev, batch_size) + examples_per_sec = batch_size / wall_time + self.report_benchmark( + iters=num_iters, + wall_time=wall_time, + name=name, + extras={ + "examples_per_sec": examples_per_sec + }) + + def _benchmark_apply(self, label, model): + with tf.device(device()): + sequence_batch = tf.ones( + [PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE], dtype=tf.int64) + + for _ in range(10): # Warmup + model(sequence_batch, training=False).cpu() + gc.collect() + + start = time.time() + iters = 100 + for _ in range(iters): + model(sequence_batch, training=False).cpu() + self._report(label, start, iters, device(), int(sequence_batch.shape[1])) + + def benchmark_apply_small(self): + self._benchmark_apply("eager_apply_small", rnn_ptb.small_model(False)) + + def benchmark_apply_large(self): + self._benchmark_apply("eager_apply_large", rnn_ptb.large_model(False)) + + def benchmark_cudnn_apply_small(self): + if not tfe.num_gpus(): + return + self._benchmark_apply("eager_cudnn_apply_small", rnn_ptb.small_model(True)) + + def benchmark_cudnn_apply_large(self): + if not tfe.num_gpus(): + return + self._benchmark_apply("eager_cudnn_apply_large", rnn_ptb.large_model(True)) + + def _benchmark_train(self, label, model): + with tf.device(device()): + optimizer = tf.train.GradientDescentOptimizer(1.) + + def model_loss(inputs, targets): + return rnn_ptb.loss_fn(model, inputs, targets, training=True) + + grads = tfe.implicit_gradients(model_loss) + + sequence_batch = tf.ones( + [PTBBenchmark.SEQ_LEN, PTBBenchmark.BATCH_SIZE], dtype=tf.int64) + + def step(): + optimizer.apply_gradients( + rnn_ptb.clip_gradients(grads(sequence_batch, sequence_batch), 0.25)) + + for _ in range(10): # Warmup + step() + force_gpu_sync() + gc.collect() + + start = time.time() + iters = 100 + for _ in range(iters): + step() + force_gpu_sync() + self._report(label, start, iters, device(), int(sequence_batch.shape[1])) + + def benchmark_train_small(self): + self._benchmark_train("eager_train_small", rnn_ptb.small_model(False)) + + def benchmark_train_large(self): + self._benchmark_train("eager_train_large", rnn_ptb.large_model(False)) + + def benchmark_cudnn_train_small(self): + if not tfe.num_gpus(): + return + self._benchmark_train("eager_cudnn_train_small", rnn_ptb.small_model(True)) + + def benchmark_cudnn_train_large(self): + if not tfe.num_gpus(): + return + self._benchmark_train("eager_cudnn_train_large", rnn_ptb.large_model(True)) + + +if __name__ == "__main__": + tfe.enable_eager_execution() + tf.test.main() diff --git a/tensorflow/contrib/eager/python/g3doc/guide.md b/tensorflow/contrib/eager/python/g3doc/guide.md new file mode 100644 index 0000000000000000000000000000000000000000..147b7047f42b7ccba5829b61370e82e217ce5838 --- /dev/null +++ b/tensorflow/contrib/eager/python/g3doc/guide.md @@ -0,0 +1,898 @@ +# TensorFlow Eager Execution + +## What is this? + +Eager execution is a feature that makes TensorFlow execute operations +immediately: concrete values are returned, instead of a computational graph to +be executed later. + +As a result, enabling eager execution provides: + +- A [NumPy](http://www.numpy.org/)-like library for numerical computation with + support for GPU acceleration and automatic differentiation. +- A flexible platform for machine learning research and experimentation. + +Eager execution is under active development. This guide walks through an +alpha/preview release. In particular, not all TensorFlow APIs currently work +with eager execution enabled, and some models may be slow to execute, compared +to models defined without using eager execution. + +## Installation + +Eager execution is **not** included in the latest release (version 1.4) of +TensorFlow. To use it, you will need to [build TensorFlow from +source](https://www.tensorflow.org/install/install_sources) or install the +nightly builds. + +For example, the nightly builds can be installed using `pip`: + +- `pip install tf-nightly` (for CPU-only TensorFlow) +- `pip install tf-nightly-gpu` (for GPU-enabled TensorFlow) + +Or using `docker`, with [Jupyter Notebook](http://jupyter.org/) support: + +```sh +# For CPU-only TensorFlow +docker pull tensorflow/tensorflow:nightly +docker run -it -p 8888:8888 tensorflow/tensorflow:nightly + +# For GPU-enabled TensorFlow: +# (Requires https://github.com/NVIDIA/nvidia-docker) +nvidia-docker pull tensorflow/tensorflow:nightly-gpu +nvidia-docker run -it -p 8888:8888 tensorflow/tensorflow:nightly-gpu +``` + +## Getting Started + +With TensorFlow installed, eager execution is enabled via a single call: + +```python +import tensorflow as tf + +import tensorflow.contrib.eager as tfe + +tfe.enable_eager_execution() +``` + +Enabling eager execution changes how TensorFlow functions behave (in particular, +`Tensor` objects will reference concrete values instead of being symbolic +handles to nodes in a computational graph). As a result, eager execution should +be enabled at the beginning of a program and cannot be disabled afterwards in +the same program. + +Code examples in the rest of this guide assume that eager execution has been +enabled. + +## A library for numerical computation + +A significant fraction of the [TensorFlow +API](https://www.tensorflow.org/api_docs/python/) consists of numerical +operations: +[arithmetic operations](https://www.tensorflow.org/api_guides/python/math_ops#Arithmetic_Operators), +[matrix operations](https://www.tensorflow.org/api_guides/python/math_ops#Matrix_Math_Functions), +[linear algebra operations](https://www.tensorflow.org/versions/master/api_docs/python/tf/linalg), +etc. + +With eager execution enabled, these operations consume and return +multi-dimensional arrays as `Tensor` objects, similar to NumPy +[`ndarray`s](https://docs.scipy.org/doc/numpy-1.13.0/reference/generated/numpy.ndarray.html). +For example: + +```python +# Multiply two 2x2 matrices +x = tf.matmul([[1, 2], + [3, 4]], + [[4, 5], + [6, 7]]) +# Add one to each element +# (tf.add supports broadcasting) +y = tf.add(x, 1) + +# Create a random random 5x3 matrix +z = tf.random_uniform([5, 3]) + +print(x) +print(y) +print(z) +``` + +Output: + +``` +tf.Tensor( +[[16 19] + [36 43]], shape=(2, 2), dtype=int32) +tf.Tensor( +[[17 20] + [37 44]], shape=(2, 2), dtype=int32) +tf.Tensor( +[[ 0.25058532 0.0929395 0.54113817] + [ 0.3108716 0.93350542 0.84909797] + [ 0.53081679 0.12788558 0.01767385] + [ 0.29725885 0.33540785 0.83588314] + [ 0.38877153 0.39720535 0.78914213]], shape=(5, 3), dtype=float32) +``` + +For convenience, these operations can also be triggered via operator overloading +of the `Tensor` object. For example, the `+` operator is equivalent to `tf.add`, +`-` to `tf.subtract`, `*` to `tf.multiply`, etc.: + +```python +x = (tf.ones([1], dtype=tf.float32) + 1) * 2 - 1 +print(x) +``` + +Output: + +``` +tf.Tensor([ 3.], shape=(1,), dtype=float32) +``` + +### Converting to and from NumPy + +The operations above automatically convert Python objects (like lists of +numbers) and NumPy arrays to `Tensor` objects. `Tensor` objects can also be used +as NumPy arrays by numpy operations. + +```python +import numpy as np + +x = tf.add(1, 1) # tf.Tensor with a value of 2 +y = tf.add(np.array(1), np.array(1)) # tf.Tensor with a value of 2 +z = np.multiply(x, y) # numpy.int64 with a value of 4 +``` + +Alternatively, they can be explicitly converted using +[`tf.constant`](https://www.tensorflow.org/api_docs/python/tf/constant), as +shown in the next example. + +Conversely, you can call the `numpy()` method of a `Tensor` object' to obtain +its NumPy `ndarray` value. For example: + +```python +import numpy as np + +np_x = np.array(2., dtype=np.float32) +x = tf.constant(np_x) + +py_y = 3. +y = tf.constant(py_y) + +z = x + y + 1 + +print(z) +print(z.numpy()) +``` + +Output: + +``` +tf.Tensor(6.0, shape=(), dtype=float32) +6.0 +``` + +### GPU acceleration + +Many TensorFlow operations support GPU acceleration. With eager execution +enabled, [computation is *not* automatically +offloaded](https://www.tensorflow.org/tutorials/using_gpu) to GPUs. Instead, you +must explicitly specify when GPUs should be used. + +The simplest way to do this is to enclose your computation in a `with +tf.device('/gpu:0')` block. Also of interest is the `tfe.num_gpus()` function, +which returns the number of available GPUs. + +For example, consider this snippet to measure the time to multiply two 1000x1000 +matrices on CPU: + +```python +import time + +def measure(x): + # The very first time a GPU is used by TensorFlow, it is initialized. + # So exclude the first run from timing. + tf.matmul(x, x) + + start = time.time() + for i in range(10): + tf.matmul(x, x) + end = time.time() + + return "Took %s seconds to multiply a %s matrix by itself 10 times" % (end - start, x.shape) + +# Run on CPU: +with tf.device("/cpu:0"): + print("CPU: %s" % measure(tf.random_normal([1000, 1000]))) + +# If a GPU is available, run on GPU: +if tfe.num_gpus() > 0: + with tf.device("/gpu:0"): + print("GPU: %s" % measure(tf.random_normal([1000, 1000]))) +``` + +Output (exact numbers will depend on the characteristics of the hardware): + +```python +CPU: Took 0.145531892776 seconds to multiply a (1000, 1000) matrix by itself 10 times +GPU: Took 0.000458955764771 seconds to multiply a (1000, 1000) matrix by itself 10 times +``` + +Alternatively, methods on the `Tensor` object can be used to explicitly copy the +`Tensor` to a different device. Operations are typically executed on the device +on which the inputs are placed. For example: + +```python +x = tf.random_normal([10, 10]) + +x_gpu0 = x.gpu() +x_cpu = x.cpu() + +_ = tf.matmul(x_cpu, x_cpu) # Runs on CPU +_ = tf.matmul(x_gpu0, x_gpu0) # Runs on GPU:0 + +if tfe.num_gpus() > 1: + x_gpu1 = x.gpu(1) + _ = tf.matmul(x_gpu1, x_gpu1) # Runs on GPU:1 +``` + +### Automatic Differentiation + +[Automatic +differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation) is +very useful when implementing many machine learning algorithms (e.g., +[backpropagation](https://en.wikipedia.org/wiki/Backpropagation) for training +neural networks). For this purpose, TensorFlow eager execution provides an +[autograd](https://github.com/HIPS/autograd)-style API for automatic +differentiation. Specifically, the functions: + +- `tfe.gradients_function(f)`: Returns a Python function that computes the + derivatives of the Python function `f` with respect to its arguments. `f` + must return a scalar value. When the returned function is invoked, it + returns a list of `Tensor` objects (one element for each argument of `f`). +- `tfe.value_and_gradients_function(f)`: Similar to `tfe.gradients_function`, + except that when the returned function is invoked, it returns the value of + `f` in addition to the list of derivatives of `f` with respect to its + arguments. + +These functions naturally apply to higher order differentiation as well. For +example: + +```python +def f(x): + return tf.multiply(x, x) # Or x * x +assert 9 == f(3.).numpy() + +df = tfe.gradients_function(f) +assert 6 == df(3.)[0].numpy() + +# Second order deriviative. +d2f = tfe.gradients_function(lambda x: df(x)[0]) +assert 2 == d2f(3.)[0].numpy() + +# Third order derivative. +d3f = tfe.gradients_function(lambda x : d2f(x)[0]) +assert 0 == d3f(3.)[0].numpy() +``` + +These functions can be used to train models. For example, consider the following +simple linear regression model: + +```python +def prediction(input, weight, bias): + return input * weight + bias + +# A toy dataset of points around 3 * x + 2 +NUM_EXAMPLES = 1000 +training_inputs = tf.random_normal([NUM_EXAMPLES]) +noise = tf.random_normal([NUM_EXAMPLES]) +training_outputs = training_inputs * 3 + 2 + noise + +# A loss function: Mean-squared error +def loss(weight, bias): + error = prediction(training_inputs, weight, bias) - training_outputs + return tf.reduce_mean(tf.square(error)) + +# Function that returns the the derivative of loss with respect to +# weight and bias +grad = tfe.gradients_function(loss) + +# Train for 200 steps (starting from some random choice for W and B, on the same +# batch of data). +W = 5. +B = 10. +learning_rate = 0.01 +print("Initial loss: %f" % loss(W, B).numpy()) +for i in range(200): + (dW, dB) = grad(W, B) + W -= dW * learning_rate + B -= dB * learning_rate + if i % 20 == 0: + print("Loss at step %d: %f" % (i, loss(W, B).numpy())) +print("Final loss: %f" % loss(W, B).numpy()) +print("W, B = %f, %f" % (W.numpy(), B.numpy())) +``` + +Output: (the exact numbers may vary depending on the randomness in noise) + +``` +Initial loss: 66.730003 +Loss at step 0: 64.200096 +Loss at step 20: 29.872814 +Loss at step 40: 14.233772 +Loss at step 60: 7.090570 +Loss at step 80: 3.819887 +Loss at step 100: 2.318821 +Loss at step 120: 1.628385 +Loss at step 140: 1.310142 +Loss at step 160: 1.163167 +Loss at step 180: 1.095162 +Final loss: 1.064711 +W, B = 3.094944, 2.161383 +``` + +To utilize the GPU, place the code above within a `with tf.device("/gpu:0"):` +block. (However, this particular model, with only two floating point parameters, +is unlikely to benefit from GPU acceleration.) + +### Customizing gradients + +One may want to define custom gradients for an operation, or for a function. +This may be useful for multiple reasons, including providing a more efficient +or more [numerically stable](https://en.wikipedia.org/wiki/Numerical_stability) +gradient for a sequence of operations. + +For example, consider the function `log(1 + e^x)`, which commonly occurs in the +computation of cross entropy and log likelihoods. + +```python +def log1pexp(x): +  return tf.log(1 + tf.exp(x)) +grad_log1pexp = tfe.gradients_function(log1pexp) + +# Works fine at x = 0. +assert 0.5 == float(grad_log1pexp(0.)[0]) + +# Returns a `nan` at x = 100 due to numerical instability. +import math +assert math.isnan(float(grad_log1pexp(100.)[0])) +``` + +We can define a custom gradient for the above function that analytically +simplifies the gradient expression. + +```python +@tfe.custom_gradient +def log1pexp(x): +  e = tf.exp(x) +  def grad(dy): +    return dy * (1 - 1 / (1 + e)) +  return tf.log(1 + e), grad +grad_log1pexp = tfe.gradients_function(log1pexp) + +# Works as before at x = 0. +assert 0.5 == float(grad_log1pexp(0.)[0]) + +# But now works at x = 100 as well. +assert 1.0 == float(grad_log1pexp(100.)[0]) +``` +Also notice how the gradient function implementation reuses an expression +(`tf.exp(x)`) computed during the forward pass, hence making the gradient +computation more efficient by avoiding redundant computation. + +## Building and training models + +In practice, your computation may have many parameters to be optimized (by +computing derivatives). Encapsulating them into re-usable classes/objects +makes the code easier to follow than writing a single top-level function with +many arguments. + +In fact, eager execution encourages use of the [Keras](https://keras.io)-style +"Layer" classes in the +[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers) +module. + +Furthermore, you may want to apply more sophisticated techniques to compute +parameter updates, such as those in +[`tf.train.Optimizer`](https://www.tensorflow.org/api_guides/python/train#Optimizers) +implementations. + +This next section walks through using the same `Optimizer` and `Layer` APIs used +to build trainable TensorFlow graphs in an environment where eager execution is +enabled. + +### Variables and Optimizers + +`tfe.Variable` objects store mutable `Tensor` values that can be accessed during +training, making automatic differentiation easier. In particular, parameters of +a model can be encapsulated in Python classes as variables. + +`tfe.gradients_function(f)` introduced earlier computes the derivatives of `f` +with respect to its arguments. However, it requires all parameters of interest +to be arguments of `f`, which becomes cumbersome when `f` depends on a large +number of trainable parameters. + +`tfe.implicit_gradients` is an alternative function with some useful properties: + +- It computes the derivatives of `f` with respect to all the `tfe.Variable`s + used by `f`. +- When the returned function is invoked, it returns a list of + (gradient value, Variable object) tuples. + +Representing model parameters as `Variable` objects, along with the use of +`tfe.implicit_gradients`, typically results in better encapsulation. For +example, the linear regression model described above can be written into a +class: + +```python +class Model(object): + def __init__(self): + self.W = tfe.Variable(5., name='weight') + self.B = tfe.Variable(10., name='bias') + + def predict(self, inputs): + return inputs * self.W + self.B + + +# The loss function to be optimized +def loss(model, inputs, targets): + error = model.predict(inputs) - targets + return tf.reduce_mean(tf.square(error)) + +# A toy dataset of points around 3 * x + 2 +NUM_EXAMPLES = 1000 +training_inputs = tf.random_normal([NUM_EXAMPLES]) +noise = tf.random_normal([NUM_EXAMPLES]) +training_outputs = training_inputs * 3 + 2 + noise + +# Define: +# 1. A model +# 2. Derivatives of a loss function with respect to model parameters +# 3. A strategy for updating the variables based on the derivatives +model = Model() +grad = tfe.implicit_gradients(loss) +optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) + +# The training loop +print("Initial loss: %f" % + loss(model, training_inputs, training_outputs).numpy()) +for i in range(201): + optimizer.apply_gradients(grad(model, training_inputs, training_outputs)) + if i % 20 == 0: + print("Loss at step %d: %f" % + (i, loss(model, training_inputs, training_outputs).numpy())) +print("Final loss: %f" % loss(model, training_inputs, training_outputs).numpy()) +print("W, B = %s, %s" % (model.W.numpy(), model.B.numpy())) +``` + +Output: + +``` +Initial loss: 69.693184 +Loss at step 0: 66.987854 +Loss at step 20: 30.553387 +Loss at step 40: 14.250237 +Loss at step 60: 6.955020 +Loss at step 80: 3.690550 +Loss at step 100: 2.229739 +Loss at step 120: 1.576032 +Loss at step 140: 1.283496 +Loss at step 160: 1.152584 +Loss at step 180: 1.093999 +Final loss: 1.067780 +W, B = 3.0114281, 2.0865183 +``` + +Using `implicit_gradients` avoids the need to provide all the trainable +parameters of the model as arguments to the `loss` function. + +### Using Keras and the Layers API + +[Keras](https://keras.io) is a popular API for defining model structures. The +[`tf.keras.layers`](https://www.tensorflow.org/api_docs/python/tf/keras/layers) +module provides a set of building blocks for models and is implemented using the +`tf.layers.Layer` subclasses in the +[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers) +module. We encourage the use of these same building blocks when using +TensorFlow's eager execution feature. For example, the very same linear +regression model can be built using `tf.layers.Dense`: + +```python +class Model(object): + def __init__(self): + self.layer = tf.layers.Dense(1) + + def predict(self, inputs): + return self.layer(inputs) +``` + +The `tf.layers` API makes it more convenient to define more sophisticated +models. For example, the following will train an MNIST model: + +```python +class MNISTModel(object): + def __init__(self, data_format): + # 'channels_first' is typically faster on GPUs + # while 'channels_last' is typically faster on CPUs. + # See: https://www.tensorflow.org/performance/performance_guide#data_formats + if data_format == 'channels_first': + self._input_shape = [-1, 1, 28, 28] + else: + self._input_shape = [-1, 28, 28, 1] + self.conv1 = tf.layers.Conv2D(32, 5, + padding='same', + activation=tf.nn.relu, + data_format=data_format) + self.max_pool2d = tf.layers.MaxPooling2D( + (2, 2), (2, 2), padding='same', data_format=data_format) + self.conv2 = tf.layers.Conv2D(64, 5, + padding='same', + activation=tf.nn.relu, + data_format=data_format) + self.dense1 = tf.layers.Dense(1024, activation=tf.nn.relu) + self.dropout = tf.layers.Dropout(0.5) + self.dense2 = tf.layers.Dense(10) + + def predict(self, inputs): + x = tf.reshape(inputs, self._input_shape) + x = self.max_pool2d(self.conv1(x)) + x = self.max_pool2d(self.conv2(x)) + x = tf.layers.flatten(x) + x = self.dropout(self.dense1(x)) + return self.dense2(x) + +def loss(model, inputs, targets): + return tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits( + logits=model.predict(inputs), labels=targets)) + + +# Load the training and validation data +from tensorflow.examples.tutorials.mnist import input_data +data = input_data.read_data_sets("./mnist_data", one_hot=True) + +# Train +device = "gpu:0" if tfe.num_gpus() else "cpu:0" +model = MNISTModel('channels_first' if tfe.num_gpus() else 'channels_last') +optimizer = tf.train.AdamOptimizer(learning_rate=1e-4) +grad = tfe.implicit_gradients(loss) +for i in range(20001): + with tf.device(device): + (inputs, targets) = data.train.next_batch(50) + optimizer.apply_gradients(grad(model, inputs, targets)) + if i % 100 == 0: + print("Step %d: Loss on training set : %f" % + (i, loss(model, inputs, targets).numpy())) +print("Loss on test set: %f" % loss(model, data.test.images, data.test.labels).numpy()) +``` + +For a more complete example, see +[`tensorflow/contrib/eager/python/examples/mnist.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist.py) + +### Checkpointing trained variables + +TensorFlow Variables (`tfe.Variable`) provides a way to represent shared, +persistent state of your model. The `tfe.Saver` class (which is a thin wrapper +over the +[`tf.train.Saver`](https://www.tensorflow.org/api_docs/python/tf/train/Saver) +class) provides a means to save and restore variables to and from _checkpoints_. + +For example: + +```python +# Create variables. +x = tfe.Variable(10., name='x') +y = tfe.Variable(5., name='y') + +# Create a Saver. +saver = tfe.Saver([x, y]) + +# Assign new values to the variables and save. +x.assign(2.) +saver.save('/tmp/ckpt') + +# Change the variable after saving. +x.assign(11.) +assert 16. == (x + y).numpy() # 11 + 5 + +# Restore the values in the checkpoint. +saver.restore('/tmp/ckpt') + +assert 7. == (x + y).numpy() # 2 + 5 +``` + +### `tfe.Network` + +You may often want to organize your models using classes, like the `MNISTModel` +class described above. We recommend inheriting from the `tfe.Network` class as +it provides conveniences like keeping track of all model variables and methods +to save and restore from checkpoints. + +Sub-classes of `tfe.Network` may register `Layer`s (like classes in +[`tf.layers`](https://www.tensorflow.org/api_docs/python/tf/layers), +or [Keras +layers](https://www.tensorflow.org/api_docs/python/tf/keras/layers)) +using a call to `self.track_layer()` and define the computation in an +implementation of `call()`. + +Note that `tf.layers.Layer` objects (like `tf.layers.Dense`) create variables +lazily, when the first input is encountered. + +For example, consider the following two-layer neural network: + +```python +class TwoLayerNet(tfe.Network): + def __init__(self): + super(TwoLayerNet, self).__init__() + self.layer1 = self.track_layer( + tf.layers.Dense(2, activation=tf.nn.relu, use_bias=False)) + self.layer2 = self.track_layer(tf.layers.Dense(3, use_bias=False)) + + def call(self, x): + return self.layer2(self.layer1(x)) + +net = TwoLayerNet() + +# No variables created yet +assert 0 == len(net.variables) + +# They are created on first input: +inp = tf.constant([[1.]]) + +# Since input is a 1x1 matrix, net.l1 has 2 units and net.l2 has 3 units, +# the output is the product of a 1x1 matrix with a 1x2 matrix with a 2x3 +# matrix. +assert [1, 3] == net(inp).shape.as_list() # Invoke net; get output shape. +assert 1 == len(net.layer1.variables) +assert 1 == len(net.layer2.variables) +assert 2 == len(net.variables) # weights for each layer. +assert [1, 2] == net.variables[0].shape.as_list() # weights of layer1. +assert [2, 3] == net.variables[1].shape.as_list() # weights of layer2. +``` + +The `tfe.Network` class is itself a sub-class of `tf.layers.Layer`. This allows +instances of `tfe.Network` to be embedded in other networks. For example: + +```python +class ThreeLayerNet(tfe.Network): + def __init__(self): + super(ThreeLayerNet, self).__init__() + self.a = self.track_layer(TwoLayerNet()) + self.b = self.track_layer(tf.layers.Dense(4, use_bias=False)) + + def call(self, x): + return self.b(self.a(x)) + +net = ThreeLayerNet() + +assert [1, 4] == net(inp).shape.as_list() +assert 3 == len(net.variables) +assert [1, 2] == net.variables[0].shape.as_list() +assert [2, 3] == net.variables[1].shape.as_list() +assert [3, 4] == net.variables[2].shape.as_list() +``` + +See more examples in +[`tensorflow/contrib/eager/python/examples`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples). + +`tfe.Saver` in combination with `tfe.restore_variables_on_create` provides a +convenient way to save and load checkpoints without changing the program once +the checkpoint has been created. For example, we can set an objective for the +output of our network, choose an optimizer, and a location for the checkpoint: + +```python +objective = tf.constant([[2., 3., 4., 5.]]) +optimizer = tf.train.AdamOptimizer(0.01) +checkpoint_directory = '/tmp/tfe_example' +checkpoint_prefix = os.path.join(checkpoint_directory, 'ckpt') +net = ThreeLayerNet() +``` + +Note that variables have not been created yet. We want them to be restored from +a checkpoint, if one exists, so we create them inside a +`tfe.restore_variables_on_create` context manager. Then our training loop is the +same whether starting training or resuming from a previous checkpoint: + +```python +with tfe.restore_variables_on_create( + tf.train.latest_checkpoint(checkpoint_directory)): + global_step = tf.train.get_or_create_global_step() + for _ in range(100): + loss_fn = lambda: tf.norm(net(inp) - objective) + optimizer.minimize(loss_fn, global_step=global_step) + if tf.equal(global_step % 20, 0): + print("Step %d, output %s" % (global_step.numpy(), + net(inp).numpy())) + all_variables = ( + net.variables + + optimizer.variables() + + [global_step]) + # Save the checkpoint. + tfe.Saver(all_variables).save(checkpoint_prefix, global_step=global_step) +``` + +The first time it runs, `Network` variables are initialized randomly. Then the +output is trained to match the objective we've set: + +``` +Step 20, output [[ 0.03575622 0.29863232 0.03474367 0.24735749]] +Step 40, output [[ 0.40646029 0.9856872 0.46851286 0.95358551]] +Step 60, output [[ 1.74541104 2.800704 1.79055595 2.74783421]] +Step 80, output [[ 2.14977384 3.44340849 3.96120024 5.16242075]] +Step 100, output [[ 1.99943113 3.02364397 3.93500996 4.9610076 ]] +``` + +In subsequent iterations, variables are initialized with the values read from +the latest checkpoint. Running the same code again, we continue from where we +left off: + +``` +Step 120, output [[ 1.99234128 3.0271616 3.98732996 4.96401167]] +Step 140, output [[ 2.00133467 3.01270437 4.00616646 5.00406504]] +Step 160, output [[ 1.99647415 2.9956708 3.99064088 4.99632359]] +Step 180, output [[ 2.00699997 3.00904822 4.00706148 5.01193142]] +Step 200, output [[ 1.98334622 2.98249531 3.97375059 4.97123432]] +``` + + +### Summaries, metrics and TensorBoard + +[TensorBoard](https://www.tensorflow.org/get_started/summaries_and_tensorboard) +is a popular tool for understanding, debugging and optimizing the model training +process. To benefit from the visualizations offered by TensorBoard, summary +events need to be written during the course of execution of your program. You +might find many Tensorflow programs that include the +[`tf.summary`](https://www.tensorflow.org/api_guides/python/summary) operations +during graph construction. + +`tf.summary` operations are *not* compatible with eager execution, but an +equivalent alternative exists in +[`tf.contrib.summary`](https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/summary) +that is compatible with both eager execution and graph construction. + +During model construction simply insert summary operations like +`tf.contrib.summary.scalar`. These operations do nothing by default, unless a +summary writer is currently active and a writing policy is set. + +For example, to record summaries once every 100 global steps, use: + +```python +tf.train.get_or_create_global_step() # Ensuring the global step variable exists +writer = tf.contrib.summary.create_summary_file_writer(logdir) + +for _ in range(iterations): + with writer.as_default(): + with tf.contrib.summary.record_summaries_every_n_global_steps(100): + # your model code goes here + tf.contrib.summary.scalar('loss', loss) + # ... +``` + +See the full mnist example in +[`tensorflow/contrib/eager/python/examples/mnist`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist) +for a full model using `tf.contrib.summary`. + +Similarly to summaries, the metrics in `tf.metrics` are currently not compatible +with eager execution. We instead provide object-oriented metrics in the +`tfe.metrics` package, which are compatible with graph construction as well. + +Metrics in the `tfe.metrics`, such as `tfe.metrics.Mean` and +`tfe.Metrics.Accuracy`, all implement an intuitive object-oriented +interface. Here's an example of how to use the `tfe.metrics.Mean` metric: + +```python +# Metrics are objects, which can be created and destroyed. +my_mean = tfe.metrics.Mean(name='my_mean') +# While a metric is active, you can call it as a function to accumulate into its +# internal state. +my_mean(0.0) +my_mean(10.0) +# Once you've finished updating the metric, you can get its result. In this case +# a simple average over all the calls to it. If a summary writer is active the +# metric will write the appropriate summaries using the metric name. +assert 5.0 == my_mean.result().numpy() +``` + +For a full example of a model using metrics for evaluation, see the mnist +example in +[`tensorflow/contrib/eager/python/examples/mnist`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist). + +### Input Pipelines + +The discussion above has been centered around the computation executed by your +model. The +[`tf.data`](https://www.tensorflow.org/api_docs/python/tf/data) +module provides APIs to build complex input pipelines from simple, reusable +pieces. + +If you're familiar with constructing `tf.data.Dataset` objects when building +TensorFlow graphs, the same API calls are used when eager execution is enabled. +However, the process of iterating over elements of the dataset differs between +eager execution and graph construction. When eager execution is enabled, the +discussion on iterator creation using `make_one_shot_iterator()` and +`get_next()` in the +[Programmer's Guide](https://www.tensorflow.org/programmers_guide/datasets) is +*not* applicable. Instead, a more Pythonic `Iterator` class is available. + +For example: + +```python +# Create a source Dataset from in-memory numpy arrays. +# For reading from files on disk, you may want to use other Dataset classes +# like the TextLineDataset or the TFRecordDataset. +dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4, 5, 6]) + +# Apply transformations, shuffling, batching etc. +dataset = dataset.map(tf.square).shuffle(2).batch(2) + +# Use tfe.Iterator to iterate over the dataset. +for x in tfe.Iterator(dataset): + print(x) +``` + +Output: + +``` +tf.Tensor([4 9], shape=(2,), dtype=int32) +tf.Tensor([16 25], shape=(2,), dtype=int32) +tf.Tensor([36 1], shape=(2,), dtype=int32) +``` + +## Interoperating with Graphs + +Eager execution improves the process of model development in Python; however, +because it is in its earliest stages, it does not yet support some features +available to [TensorFlow +graphs](https://www.tensorflow.org/get_started/get_started#the_computational_graph) +that are desirable when deploying models in production. In particular, eager +execution does not yet support distributed training, exporting models (to other +[programming languages](https://www.tensorflow.org/api_docs/), [TensorFlow +serving](https://www.tensorflow.org/serving/), and mobile applications), and +various memory and computation optimizations that are applied to TensorFlow's +dataflow graphs. + +That said, the APIs used to build modes are exactly the same whether executing +eagerly or constructing graphs. This means that you can iteratively develop your +model with eager execution enabled and later, if needed, use the same code to +reap the benefits of representing models as computational graphs. + +For example, +[`mnist.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist.py) +defines a model that is eagerly executed. That same code is used to construct +and execute a graph in +[`mnist_graph_test.py`](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist/mnist_graph_test.py). + +Other models in the [examples +directory](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/) +demonstrate this as well. + +Some differences worth noting: + +- There is no notion of a `tf.placeholder` or a `tf.Session` when eager + execution is enabled. +- Many properties on the `tf.Tensor` object, like `tf.Tensor.name`, + `tf.Tensor.op`, `tf.Tensor.inputs` are not meaningful when eager execution + is enabled and their use will raise an `AttributeError`. +- To use `tfe.implicit_gradients` in graph construction, variables must be + created with [`use_resource=True`] provided to + [`tf.get_variable()`](https://www.tensorflow.org/api_docs/python/tf/get_variable) + or + [`tf.variable_scope()`](https://www.tensorflow.org/api_docs/python/tf/variable_scope). +- Some API calls (such as the functional-style `tf.layers.dense`, + `tf.layers.conv2d`) are not compatible with eager execution. Use of such + methods should raise an error indicating the alternative (e.g., the + `tf.layers.Dense` and `tf.layers.Conv2D` classes). + +## What next? + +Please give eager execution a spin. This feature is in early stages and is +evolving, so we welcome your feedback via issues on GitHub (see [known +issues](https://github.com/tensorflow/tensorflow/labels/comp:eager)). + +You may want to browse through some sample code, including benchmarks for some: + +- [Linear Regression](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/linear_regression) +- [MNIST handwritten digit classifier](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/mnist) +- [ResNet50 image classification](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/resnet50) +- [RNN to generate colors](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_colorbot) +- [RNN language model](https://www.tensorflow.org/code/tensorflow/contrib/eager/python/examples/rnn_ptb) + diff --git a/tensorflow/contrib/eager/python/metrics.py b/tensorflow/contrib/eager/python/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..3e3100427376ddd480b50d967cf53e7831aaefb2 --- /dev/null +++ b/tensorflow/contrib/eager/python/metrics.py @@ -0,0 +1,26 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Metrics namespace.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint:disable=wildcard-import +from tensorflow.contrib.eager.python.metrics_impl import * +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = ['Accuracy', 'Mean', 'Metric'] +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/eager/python/metrics_impl.py b/tensorflow/contrib/eager/python/metrics_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..aa359b7a0d7d89e8788c323d1621798d1a22b658 --- /dev/null +++ b/tensorflow/contrib/eager/python/metrics_impl.py @@ -0,0 +1,313 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Metrics classes for computing the output of an evaluation.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + +from tensorflow.contrib.summary import summary_ops +from tensorflow.python.eager import context +from tensorflow.python.eager import function +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope + + +_to_replace = re.compile("[^A-Za-z0-9.]") + + +class Metric(object): + """A metric holds state for aggregating statistics over an evaluation run. + + Example use with eager execution: + + ```python + m = SomeMetric(...) + for input in ...: + m(input) + print(m.result()) + ``` + + Example use with graph execution: + + ```python + m = SomeMetric(...) + m_placeholder = tf.placeholder(...) + m_update = m(m_placeholder) + # Variables defined in first call, so get the initialization op afterwards. + m_init = m.init_variables() # or tf.global_variables_initializer() + m_result = m.result() + with tf.Session() as sess: + sess.run(m_init) + for input in ...: + sess.run(m_update, feed_dict={m_placeholder: input}) + print(sess.run(m_result)) + ``` + + Descendants will implement: + * `build()`: All variables should be created in this method, by calling + `self.add_variable()` as in: `self.var = self.add_variable(...)` + build() will be called in the first invocation of `__call__()`, with + the same arguments passed `call()`. + * `call()`: Has all updates to variables, as in: + self.var.assign_add(...) + * `result()`: Computes and returns a final value for the metric + from the variables in `self`. + + Decendants may override `aggregate()`, but usually won't need to. It + adds in the state from a list of metrics of the same type as `self`. + (Default is to sum all the variables.) Note that users should not call + `aggregate()`, it is for use by TensorFlow infrastructure. + """ + + def __init__(self, name=None): + self._built = False + self._vars = [] + self._initial_values = {} + self._updates = [] + name = name or self.__class__.__name__ + # Replace things like spaces in name to create a valid scope name. + scope_name = _to_replace.sub("_", name) + # We create the variable scope now to get the unique name that will + # be used as a variable prefix when build() calls add_variable(). + with variable_scope.variable_scope( + scope_name, use_resource=True, reuse=False) as scope: + pos = scope.name.rfind(scope_name) + self._name = name + scope.name[pos + len(scope_name):] + self._scope = scope + if context.in_graph_mode(): + # We make self.call() into a graph callable here, so that we can + # return a single op that performs all of the variable updates. + self._construction_scope = ops.get_default_graph().as_default + self.call = function.defun(self.call) + else: + self._construction_scope = context.eager_mode + + # ---- API for users ---- + def __call__(self, *args, **kwargs): + """Returns op to execute to update this metric for these inputs. + + Returns None if eager execution is enabled. + + Args: + *args: + **kwargs: A mini-batch of inputs to the Metric, passed on to `call()`. + """ + if not self._built: + with variable_scope.variable_scope( + self._scope), self._construction_scope(): + self.build(*args, **kwargs) + self._built = True + return self.call(*args, **kwargs) + + @property + def name(self): + return self._name + + @property + def variables(self): + return self._vars + + def init_variables(self): + """Initializes this Metric's variables. + + Should be called after variables are created in the first execution + of `__call__()`. If using graph execution, the return value should be + `run()` in a session before running the op returned by `__call__()`. + (See example above.) + + Returns: + If using graph execution, this returns an op to perform the + initialization. Under eager execution, the variables are reset to their + initial values as a side effect and this function returns None. + """ + if context.in_graph_mode(): + return control_flow_ops.group([v.initializer for v in self._vars]) + for v in self._vars: + v.assign(self._initial_values[v]) + + # ---- To be implemented by descendants --- + def build(self, *args, **kwargs): + """Method to create variables. + + Called by `__call__()` before `call()` for the first time. + + Args: + *args: + **kwargs: The arguments to the first invocation of `__call__()`. + `build()` may use the shape and/or dtype of these arguments + when deciding how to create variables. + """ + raise NotImplementedError("Metrics must define a build() member function") + + def call(self, *args, **kwargs): + """Accumulates statistics for the metric. Users should use __call__ instead. + + Note: This function is executed as a graph function in graph mode. + This means: + a) Operations on the same resource are executed in textual order. + This should make it easier to do things like add the updated + value of a variable to another, for example. + b) You don't need to worry about collecting the update ops to execute. + All update ops added to the graph by this function will be executed. + As a result, code should generally work the same way with graph or + eager execution. + + Args: + *args: + **kwargs: A mini-batch of inputs to the Metric, as passed to + `__call__()`. + """ + raise NotImplementedError("Metrics must define a call() member function") + + def result(self): # TODO(josh11b): Add an optional summary_writer parameter. + """Computes and returns a final value for the metric.""" + raise NotImplementedError("Metrics must define a result() member function") + + # We can support two different strategies of for doing data-parallel + # distributed metric computations: + # * Put metric variables on the first device and rely on small + # bandwidth needed to do updates. (Doesn't require any particular + # code in Metric implementations.) + # * Ask each type of metric to define an aggregation method to run + # at the end of eval to merge across devices. Note: this is good + # for the use case where they want to record the metric's state + # for each example and then later decide which examples they want + # to aggregate over. (Recommended -- not too much harder and adds + # flexibility over previous option.) + # I'm going with the second strategy since we can define a default + # implementation of aggregate() that will work for most descendants. + def aggregate(self, metrics): + """Adds in the state from a list of metrics. + + Default implementation sums all the metric variables. + + Args: + metrics: A list of metrics with the same type as `self`. + + Raises: + ValueError: If metrics contains invalid data. + """ + for m in metrics: + if type(self) != type(m): # pylint: disable=unidiomatic-typecheck + raise TypeError("All metrics must be the same type, '%s' != '%s'." % + (type(self), type(m))) + # pylint: disable=protected-access + for i in range(len(self._vars)): + if any(m._vars[i].name != self._vars[i].name for m in metrics): + raise ValueError("All metrics must have variables in the same order.") + self._vars[i].assign_add(math_ops.add_n([m._vars[i] for m in metrics])) + # pylint: enable=protected-access + + # ---- For use by descendants --- + def add_variable(self, name, shape=None, dtype=None, initializer=None): + """***Only for use by descendants of Metric***.""" + if self._built: + raise RuntimeError("Can't call add_variable() except in build().") + collections = None if context.in_eager_mode() else [ + ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES + ] + v = variable_scope.get_variable( + name, + shape, + dtype, + initializer, + trainable=False, + collections=collections, + use_resource=True) + self._vars.append(v) + if context.in_eager_mode(): + self._initial_values[v] = v.value() + return v + + +class Mean(Metric): + """Computes the (weighted) mean of the given values.""" + # TODO(josh11b): Maybe have a dtype argument that defaults to tf.float64? + # Or defaults to type of the input if it is tf.float32, else tf.float64? + + def __init__(self, name=None, dtype=dtypes.float64): + super(Mean, self).__init__(name=name) + self.dtype = dtype + + def build(self, *args, **kwargs): + # build() does not use call's arguments, by using *args, **kwargs + # we make it easier to inherit from Mean(). + del args, kwargs + self.numer = self.add_variable(name="numer", shape=(), + dtype=self.dtype, + initializer=init_ops.zeros_initializer) + self.denom = self.add_variable(name="denom", shape=(), + dtype=self.dtype, + initializer=init_ops.zeros_initializer) + + def call(self, values, weights=None): + """Accumulate statistics for computing the mean. + + For example, if values is [1, 3, 5, 7] then the mean is 4. + If the weights were specified as [1, 1, 0, 0] then the mean would be 2. + + Args: + values: Tensor with the per-example value. + weights: Optional weighting of each example. Defaults to 1. + """ + if weights is None: + self.denom.assign_add( + math_ops.cast(array_ops.identity(array_ops.size(values)), self.dtype)) + values = math_ops.reduce_sum(values) + self.numer.assign_add(math_ops.cast(values, self.dtype)) + else: + weights = math_ops.cast(weights, self.dtype) + self.denom.assign_add(math_ops.reduce_sum(weights)) + values = math_ops.cast(values, self.dtype) * weights + self.numer.assign_add(math_ops.reduce_sum(values)) + + def result(self): + t = self.numer / self.denom + summary_ops.scalar(name=self.name, tensor=t) + return t + + +class Accuracy(Mean): + """Calculates how often `predictions` matches `labels`.""" + + def __init__(self, name=None, dtype=dtypes.float64): + super(Accuracy, self).__init__(name=name, dtype=dtype) + + def call(self, labels, predictions, weights=None): + """Accumulate accuracy statistics. + + For example, if labels is [1, 2, 3, 4] and predictions is [0, 2, 3, 4] + then the accuracy is 3/4 or .75. If the weights were specified as + [1, 1, 0, 0] then the accuracy would be 1/2 or .5. + + `labels` and `predictions` should have the same shape and type. + + Args: + labels: Tensor with the true labels for each example. One example + per element of the Tensor. + predictions: Tensor with the predicted label for each example. + weights: Optional weighting of each example. Defaults to 1. + """ + matches = math_ops.equal(labels, predictions) + matches = math_ops.cast(matches, dtypes.float64) + super(Accuracy, self).call(matches, weights=weights) diff --git a/tensorflow/contrib/eager/python/metrics_test.py b/tensorflow/contrib/eager/python/metrics_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b4f5973bd11a02230d30f8cf1b2961125f154283 --- /dev/null +++ b/tensorflow/contrib/eager/python/metrics_test.py @@ -0,0 +1,168 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for Metrics.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile + +from tensorflow.contrib.eager.python import metrics +from tensorflow.contrib.summary import summary_ops +from tensorflow.contrib.summary import summary_test_util +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.training import training_util + + +class MetricsTest(test.TestCase): + + def testMean(self): + m = metrics.Mean() + m([1, 10, 100]) + m(1000) + m([10000.0, 100000.0]) + self.assertEqual(111111.0/6, m.result().numpy()) + self.assertEqual(dtypes.float64, m.dtype) + self.assertEqual(dtypes.float64, m.result().dtype) + + def testVariableCollections(self): + with context.graph_mode(), ops.Graph().as_default(): + m = metrics.Mean() + m(1000) + self.assertEqual( + set(m.variables), + set(ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES))) + self.assertEqual( + set(m.variables), + set(ops.get_collection(ops.GraphKeys.METRIC_VARIABLES))) + + def testInitVariables(self): + m = metrics.Mean() + m([1, 10, 100, 1000]) + m([10000.0, 100000.0]) + self.assertEqual(111111.0/6, m.result().numpy()) + m.init_variables() + m(7) + self.assertEqual(7.0, m.result().numpy()) + + def testWriteSummaries(self): + m = metrics.Mean() + m([1, 10, 100]) + training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + with summary_ops.create_summary_file_writer( + logdir, max_queue=0, + name="t0").as_default(), summary_ops.always_record_summaries(): + m.result() # As a side-effect will write summaries. + + events = summary_test_util.events_from_file(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].simple_value, 37.0) + + def testWeightedMean(self): + m = metrics.Mean() + m([1, 100, 100000], weights=[1, 0.2, 0.3]) + m([500000, 5000, 500]) # weights of 1 each + self.assertNear(535521/4.5, m.result().numpy(), 0.001) + + def testMeanDtype(self): + # Can override default dtype of float64. + m = metrics.Mean(dtype=dtypes.float32) + m([0, 2]) + self.assertEqual(1, m.result().numpy()) + self.assertEqual(dtypes.float32, m.dtype) + self.assertEqual(dtypes.float32, m.result().dtype) + + def testAccuracy(self): + m = metrics.Accuracy() + m([0, 1, 2, 3], [0, 0, 0, 0]) # 1 correct + m([4], [4]) # 1 correct + m([5], [0]) # 0 correct + m([6], [6]) # 1 correct + m([7], [2]) # 0 correct + self.assertEqual(3.0/8, m.result().numpy()) + self.assertEqual(dtypes.float64, m.dtype) + self.assertEqual(dtypes.float64, m.result().dtype) + + def testWeightedAccuracy(self): + m = metrics.Accuracy() + # 1 correct, total weight of 2 + m([0, 1, 2, 3], [0, 0, 0, 0], weights=[1, 1, 0, 0]) + m([4], [4], weights=[0.5]) # 1 correct with a weight of 0.5 + m([5], [0], weights=[0.5]) # 0 correct, weight 0.5 + m([6], [6]) # 1 correct, weight 1 + m([7], [2]) # 0 correct, weight 1 + self.assertEqual(2.5/5, m.result().numpy()) + + def testAccuracyDtype(self): + # Can override default dtype of float64. + m = metrics.Accuracy(dtype=dtypes.float32) + m([0, 0], [0, 1]) + self.assertEqual(0.5, m.result().numpy()) + self.assertEqual(dtypes.float32, m.dtype) + self.assertEqual(dtypes.float32, m.result().dtype) + + def testTwoMeans(self): + # Verify two metrics with the same class and name don't + # accidentally share state. + m1 = metrics.Mean() + m1(0) + m2 = metrics.Mean() + m2(2) + self.assertAllEqual(0.0, m1.result()) + self.assertAllEqual(2.0, m2.result()) + + def testNamesWithSpaces(self): + # Verify two metrics with the same class and name don't + # accidentally share state. + m1 = metrics.Mean("has space") + m1(0) + self.assertEqual(m1.name, "has space") + self.assertEqual(m1.numer.name, "has_space/numer:0") + + def testGraph(self): + with context.graph_mode(), self.test_session() as sess: + m = metrics.Mean() + p = array_ops.placeholder(dtypes.float32) + accumulate = m(p) + init_op = m.init_variables() + init_op.run() + sess.run(accumulate, feed_dict={p: [1, 10, 100]}) + sess.run(accumulate, feed_dict={p: 1000}) + sess.run(accumulate, feed_dict={p: [10000, 100000]}) + self.assertAllEqual(m.result().eval(), 111111.0/6) + # Second init resets all the variables. + init_op.run() + sess.run(accumulate, feed_dict={p: 7}) + self.assertAllEqual(m.result().eval(), 7) + + def testTwoMeansGraph(self): + # Verify two metrics with the same class and name don't + # accidentally share state. + with context.graph_mode(): + m1 = metrics.Mean() + m1(0) + with self.assertRaises(ValueError): + m2 = metrics.Mean() + m2(2) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py new file mode 100644 index 0000000000000000000000000000000000000000..97feaec30ed066503ef8ce75cbd5af04ea2ef6bf --- /dev/null +++ b/tensorflow/contrib/eager/python/network.py @@ -0,0 +1,803 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A Network is a composition of Layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import os +import weakref + +from tensorflow.python.eager import context +from tensorflow.python.estimator import util as estimator_util +from tensorflow.python.framework import ops +from tensorflow.python.layers import base +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import checkpoint_utils +from tensorflow.python.training import saver as saver_lib +from tensorflow.python.training import training_util + +# pylint: disable=protected-access +# Explanation for protected-access disable: Network has lots of same-class and +# parent-class references across different objects, and some to private +# functions in base.py which should be reused. + + +_DeferredRestoration = collections.namedtuple( + + "_DeferredRestoration", + [ + # The map_func to use (either user-specified or the default). + "map_func", + # Boolean, True if the user specified an explicit map_func, for error + # messages. + "map_func_is_user", + # A mapping from checkpoint names to initial values of not-yet-created + # variables which should be restored. These values come from parsing a + # checkpoint. + "checkpointed_variables_to_restore", + # A mapping from checkpoint name to variable objects of variables which + # have already been restored, for error checking. + "restored_variables", + # The session to restore with (if in graph mode). + "session", + # Names of the Network where the restore was requested, for error + # messages. + "network_name", + "network_scope_name" + ]) + + +def _default_naming_conflict_error_message( + mapped_name, first_variable, second_variable, + network_name, network_scope_name): + return ( + ("The default checkpoint variable name mapping strategy for Network " + "'%s' resulted in a naming conflict. We attempted to strip off the " + "variable prefix for the Network ('%s'), but this resulted in two " + "variables named '%s' (originally '%s' and '%s'). This should only " + "happen when using variable sharing (i.e. the Network contains Networks " + "or Layers which were first added to another Network, and therefore " + "have that Network's variable prefix). One solution is to pass " + "`map_func=lambda n: n` to Network.save and Network.restore to use " + "fully qualified variable names in the checkpoint, although this will " + "require that the variable prefix of the Network being restored into " + "is also '%s'. You may alternatively write an arbitrary mapping.") + % ( + network_name, network_scope_name, mapped_name, + first_variable._shared_name, + second_variable._shared_name, network_scope_name + )) + + +def _restore_custom_map_func_error_message( + mapped_name, first_variable, second_variable, + network_name, network_scope_name): + return ( + ("The map_func passed to Network.restore for the Network '%s' " + "resulted in two variables named '%s' (originally '%s' and '%s'). Since " + "this is also an error on Network.save, this Network was " + "probably not saved with this map_func. Note that map_func " + "always maps from full variable names to checkpoint names; " + "there is no need to specify an inverse mapping.\n\n" + "Try stripping less from the variable names, or renaming parts " + "of the Network. For reference, variables created by sub-Layers " + "of this Network are prefixed with '%s', but if they are " + "re-used after being added to another Network they will have " + "that Network's full variable prefix instead.") % ( + network_name, mapped_name, + first_variable._shared_name, + second_variable._shared_name, + network_scope_name)) + + +def _make_custom_getter_for_deferred_restorations(): + """Returns a custom getter which searches `deferred_restorations`. + + Returns: A tuple of (_custom_getter, deferred_restorations) + _custom_getter: The getter which should be added to variable_scopes where + variables will be created. + deferred_restorations: A list for _DeferredRestoration objects. Typically + empty when the getter is set, and expanded as deferred restorations are + requested. All new deferred restorations should be appended to the end of + the list, where they will have priority over older deferred restorations. + """ + deferred_restorations = [] + + def _custom_getter(getter, name, shape=None, dtype=None, + initializer=None, + *args, **kwargs): + """A custom getter which processes deferred restorations.""" + # Iterate over restorations, newest first (newer restorations will take + # precedence over older restorations, just like with immediate restorations + # into existing variables). + delayed_restoration = None + found_value = False + value_to_restore = None + for delayed_restoration in reversed( + deferred_restorations): + checkpoint_name = delayed_restoration.map_func(name) + if (checkpoint_name + in delayed_restoration.checkpointed_variables_to_restore): + found_value = True + value_to_restore = ( + delayed_restoration.checkpointed_variables_to_restore[ + checkpoint_name]) + if found_value: + break + # value_to_restore may be False because this variable is not in any + # checkpoint we are restoring, or None because we have explicitly set it to + # None when it was previously fetched. In either case, we don't need to + # set an initializer. + if found_value and value_to_restore is not None: + initializer = value_to_restore + shape = None + variable = getter(name, shape=shape, dtype=dtype, initializer=initializer, + *args, **kwargs) + if found_value and value_to_restore is not None: + # Mark as already restored from this checkpoint. + delayed_restoration.checkpointed_variables_to_restore[ + checkpoint_name] = None + if context.in_graph_mode(): + delayed_restoration.session.run(variable.initializer) + if found_value: + # Error checking should run even if we've already restored a value. + if delayed_restoration.restored_variables.setdefault( + checkpoint_name, variable) is not variable: + # Naming conflict. We've tried to initialize two variables with the + # same value from the checkpoint. + if delayed_restoration.map_func_is_user: + raise ValueError( + _restore_custom_map_func_error_message( + mapped_name=checkpoint_name, + first_variable=delayed_restoration.restored_variables[ + checkpoint_name], + second_variable=variable, + network_name=delayed_restoration.network_name, + network_scope_name=delayed_restoration.network_scope_name)) + else: + raise ValueError( + _default_naming_conflict_error_message( + mapped_name=checkpoint_name, + first_variable=delayed_restoration.restored_variables[ + checkpoint_name], + second_variable=variable, + network_name=delayed_restoration.network_name, + network_scope_name=delayed_restoration.network_scope_name)) + return variable + return _custom_getter, deferred_restorations + + +class Network(base.Layer): + """Represents the composition of a set of Layers. + + TODO(josh11b,ashankar): + - Should "trainable" be changeable on the Network object? + - Do we allow add_variable in Network? + - Detect layers used in __call__ that weren't registered with track_layer. + - Convert inputs to __call__ to tensors. + - Prevent variables from being created after the first __call__? + (Think about restoring from a checkpoint). + """ + + def __init__(self, name=None): + if isinstance(name, variable_scope.VariableScope): + raise ValueError("VariableScopes are not valid Network names.") + if name is not None and "/" in name: + raise ValueError( + "Forward slashes ('/') are not allowed in Network names.") + super(Network, self).__init__(name=name) + self._layers = [] + self._sub_layer_name_uids = collections.defaultdict(int) + # Initially None, but set to False for networks which are first built as + # top-level. + self._first_parent = None # A weak reference to our first parent. + self._non_network_sublayers = [] + self._owned_layers = {} + # The scope to use if we end up without a parent. + self._default_parent_variable_scope = variable_scope.get_variable_scope() + self._custom_getter, self._deferred_restorations = ( + _make_custom_getter_for_deferred_restorations()) + + def _init_set_name(self, name): + # Anonymous Networks (name=None) defer setting a final name until they are + # (1) added to another Network, or (2) built/called (where (2) is only used + # for a "top level" network). + # + # However, if we were provided an explicit name (name is not None), that + # will always be the final name of the Network; if it turns out not to be + # unique or if variable names can't be prefixed by it we will throw an + # error. + self._name = name + self._base_name = None + + def _finalize_name(self, parent_network): + if not self._name: + if not parent_network: + name_uid_map = base._get_default_graph_uid_map() + else: + name_uid_map = parent_network._sub_layer_name_uids + # Were were not passed a name explicitly (or it was blank), so this is an + # anonymous Network. We make up a unique name. + if parent_network: + avoid_names = parent_network._owned_layers + else: + avoid_names = None + self._name, self._base_name = self._make_unique_name( + name_uid_map=name_uid_map, avoid_names=avoid_names) + if self._first_parent is None or (self._first_parent # False = no parent + and self._first_parent() is None): + # Save a pointer to the parent Network so that we can later check that the + # scope name we get is correct. + if not parent_network: + self._first_parent = parent_network + else: + self._first_parent = weakref.ref(parent_network) + + def _set_scope(self, scope=None): + if self._scope is None: + if not self._first_parent: + first_parent = self._first_parent + else: + first_parent = self._first_parent() + if first_parent is None: + # If we were never added to another Network, or that Network has beed + # garbage collected before being called, then we're a top-level Network. + self._finalize_name( + # Use False to make sure the value sticks and we don't inherit a + # parent if we're added to a network later. + parent_network=False) + if scope is not None: + raise ValueError("Networks may not be created with explicit scopes.") + if first_parent: + first_parent._set_scope() + parent_scope = first_parent._scope + else: + parent_scope = self._default_parent_variable_scope + with variable_scope.variable_scope(parent_scope): + # Make sure variables with this prefix will be unique. + with variable_scope.variable_scope( + None, use_resource=True, default_name=self._name) as scope: + self._scope = scope + scope_name = scope.name + suffix_start = scope_name.rfind("/") + 1 + # rfind is -1 if there is no slash in the string, in which case the + # suffix starts at the beginning of the string (there is no prefix). + scope_suffix = scope_name[suffix_start:] + scope_prefix = scope_name[:suffix_start] + if scope_suffix != self._name: + raise ValueError( + ("A Network named '%s' already exists (or a variable_scope was " + "created with this name). Names must be unique.") % ( + self._name,)) + if (first_parent + and scope_prefix[:-1] != first_parent._scope.name): + raise ValueError( + ("Network variable names must match a nesting of sub-Network " + "names. Expected prefix '%s' from parent network, but got " + "'%s' when attempting to create a variable_scope for Network " + "'%s'. Likely an explicit variable_scope was inserted into " + "the nesting.") % ( + first_parent._scope.name, + scope_prefix[:-1], + self._name)) + elif not first_parent and scope_prefix: + # For the case when this Network is not nested inside any other + # Network, but is in a variable_scope. This is an error for now. + raise ValueError( + "Creating Networks inside named variable_scopes is currently " + "not supported (to ensure that variable names match the names " + "of Networks in which they were first created). To set " + "options, try `with tf.variable_scope(''):`. If this " + "limitation bothers you, please file a feature request.") + for non_network_sublayer in self._non_network_sublayers: + self._set_scope_for_nonnetwork_sublayer(non_network_sublayer) + + def _set_scope_for_nonnetwork_sublayer(self, sublayer): + if sublayer._scope is None: + if sublayer._first_parent is None: + constituent_first_parent = None + else: + constituent_first_parent = sublayer._first_parent() + if constituent_first_parent: + constituent_first_parent._set_scope() + parent_scope = constituent_first_parent._scope + else: + self._finalize_name(False) + raise ValueError( + ("The parent of a Layer added to Network %s was garbage collected " + "before the Layer was built. If this limitation bothers you " + "please, comment on " + "https://github.com/tensorflow/tensorflow/issues/14164.") % + (self.name,)) + with variable_scope.variable_scope(parent_scope): + # Horrid hack to make Layer variable names which are direct + # sub-layers of Networks conform to the Network variable naming + # conventions. + with variable_scope.variable_scope( + None, use_resource=True, + default_name=sublayer.name) as sub_scope: + sublayer._scope = sub_scope + + @base.Layer.name.getter + def name(self): + if self._name is None: + raise ValueError( + "The network does not yet have a final name, but a name was " + "requested for it. Networks get a name when they are added to " + "another Network via track_layer, or when they are first " + "called/built.") + return self._name + + def track_layer(self, layer): + """Track a Layer in this Network. + + `Network` requires that all `Layer`s used in `call()` be tracked so that the + `Network` can export a complete list of variables. + + Args: + layer: A `tf.layers.Layer` object. + + Returns: + The passed in `layer`. + + Raises: + RuntimeError: If __init__ has not been called. + TypeError: If `layer` is the wrong type. + ValueError: If a `Layer` with the same name has already been added. + """ + if not hasattr(self, "_layers"): + raise RuntimeError("Need to call Network.__init__ before adding layers") + if not isinstance(layer, base.Layer): + raise TypeError( + "Network.track_layer() passed type %s, not a tf.layers.Layer" % + (type(layer),)) + if isinstance(layer, Network): + layer._finalize_name(parent_network=self) + else: + # `layer` is a non-Network, so it hasn't been named to follow Network + # conventions for contained Layers (i.e. the same conventions as for + # sub-Networks). This renaming is necessary to isolate Network variable + # naming from Layers constructed outside the Network and never added to it + # (because Layers are named globally). + if not layer.built: + if not hasattr(layer, "_first_parent"): + dereferenced_layer_first_parent = None + else: + dereferenced_layer_first_parent = layer._first_parent() + if dereferenced_layer_first_parent is None: + if layer._name != layer._base_name: + # If name and base_name do not match, then this Layer used anonymous + # naming and we have to rename it. Otherwise there's an explicit + # name, and we should respect it (subject to error checking). + layer._name, layer._base_name = layer._make_unique_name( + name_uid_map=self._sub_layer_name_uids, + avoid_names=self._owned_layers) + layer._first_parent = weakref.ref(self) + self._non_network_sublayers.append(layer) + if (not layer.built + and layer._first_parent + and self is layer._first_parent()): + if layer.name in self._owned_layers: + if self._owned_layers[layer.name] is layer: + return layer + raise ValueError( + "Attempt to add two Layers with the name '%s' to the same Network." + % (layer.name)) + self._owned_layers[layer.name] = layer + self._layers.append(layer) + return layer + + def get_layer(self, name=None, index=None): + """Get a contained `tf.layers.Layer` either by name or index. + + Args: + name: String matching one of the names of a contained `Layer`. Note that + the names of `Layer`s added to `Network`s may not be unique when doing + layer sharing (i.e. adding a `Layer` to this `Network` which was already + added to another `Network`). The lowest index `Layer` with a matching + name will be returned. + index: Integer in [0, number of layers). Layers are assigned an index + by the order they are added. + + Returns: + A `tf.layers.Layer` object. + + Raises: + ValueError: If neither or both of 'index' or 'name' is specified, or the + lookup failed. + """ + if index is not None: + if name is not None: + raise ValueError("Exactly one of 'index' or 'name' must be provided") + if len(self._layers) <= index: + raise ValueError("Was asked to retrieve layer at index " + str(index) + + " but model only has " + str(len(self._layers)) + + " layers.") + else: + return self._layers[index] + else: + if not name: + raise ValueError("Provide either a layer name or layer index.") + for layer in self._layers: + if layer.name == name: + return layer + raise ValueError("No such layer: " + name) + + # The following methods are for implementing the Layer interface. + + @property + def weights(self): + # TODO(josh11b): Should this return a set or perform de-duplication of + # variables in the case of shared layers/variables that appear in + # multiple places in the Network? + weights = [] + for layer in self._layers: + weights += layer.weights + return weights + + @property + def trainable_weights(self): + weights = [] + for layer in self._layers: + weights += layer.trainable_weights + return weights + + @property + def non_trainable_weights(self): + weights = [] + for layer in self._layers: + weights += layer.non_trainable_weights + return weights + + @property + def trainable(self): + return True + + @trainable.setter + def trainable(self, value): + if not value: + # We believe it better to decide which layers & networks are trainable + # at the Trainer level than here. Otherwise you can run into trouble if a + # layer/network is shared between two models, but is trainable in one + # but not the other (like with adversarial networks). + raise AttributeError("cannot mark Network as not trainable") + + @property + def layers(self): + return self._layers + + def add_variable(self, name, shape, dtype=None, initializer=None, + regularizer=None, trainable=True, constraint=None): + raise RuntimeError( + "add_variable not supported in Network class yet. Please file an issue " + "at https://github.com/tensorflow/tensorflow/issues/new if this is " + "important to you") + + def _strip_variable_prefix(self, original_variable_name): + """The default map_func for saving or restoring variables. + + Strips the variable prefix for the Network on which save/restore was called, + and leaves other variable names fully qualified in the checkpoint. + + Args: + original_variable_name: The _shared_name of the variable (no :0 + suffix) to map. + Returns: + The checkpoint name of the variable. + """ + scope_name_with_slash = self.scope_name + "/" + if original_variable_name.startswith(scope_name_with_slash): + return original_variable_name[len(scope_name_with_slash):] + else: + return original_variable_name + + def save(self, save_path, global_step=None, map_func=None): + """Save variables from the Network to a checkpoint. + + Args: + save_path: Either a checkpoint prefix or the name of a directory to save + the checkpoint in (in which case the checkpoint will be named based on + the Network name). + global_step: The global step to use when naming the checkpoint. If None + (default), we will first try to get the default global step. If that + fails because no default global step exists, then the checkpoint is + created without a global step suffix. + map_func: A function mapping fully qualified variable names + (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By + default (if `map_func=None`), the variable prefix for the network being + restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped + and all other variable names (shared with other Networks) are left + unchanged. + Returns: + The checkpoint prefix for the saved checkpoint, which may be passed to + `Network.restore`. + Raises: + ValueError: If the Network has not yet been called, or if map_func results + in a name collision. + """ + if not self.built: + raise ValueError( + "Attempt to save the Network before it was first called. This means " + "variables have not yet been created, so there is nothing to save.") + self._set_scope() # scope_name should be available to map_funcs + if global_step is None: + global_step = training_util.get_global_step() + if os.path.isdir(save_path): + # If we were passed a directory, default to naming based on the Network + # name. + save_path = os.path.join(save_path, self.name) + user_map_func = map_func + if map_func is None: + map_func = self._strip_variable_prefix + variable_map = {} + for variable in self.variables: + mapped_name = map_func(variable._shared_name) + if variable_map.setdefault(mapped_name, variable) is not variable: + if user_map_func is None: + # Instead of erroring out, we could just re-try and silently use the + # full variable names in the checkpoint. This could be odd for deeply + # nested sub-Networks (since the full prefix from the nesting would + # get added), so for now we'll let the user deal with this case. + raise ValueError(_default_naming_conflict_error_message( + mapped_name=mapped_name, + first_variable=variable_map[mapped_name], + second_variable=variable, + network_name=self.name, + network_scope_name=self.scope_name)) + else: + # The user passed their own problematic map_func. + raise ValueError( + ("The map_func passed to Network.save for the Network '%s' " + "resulted in two variables named '%s' ('%s' and '%s'). Try " + "stripping less from the variable names, or renaming parts of " + "the Network. For reference, variables created by sub-Layers of " + "this Network are prefixed with '%s', but if they are re-used " + "after being added to another Network, they will have that " + "Network's full variable prefix instead.") % ( + self.name, mapped_name, + variable_map[mapped_name]._shared_name, + variable._shared_name, + self.scope_name)) + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + return saver_lib.Saver(variable_map).save( + sess=sess, save_path=save_path, write_meta_graph=False, + global_step=global_step) + + def _restore_existing_variables(self, save_path, map_func, user_map_func): + """Use a standard Saver to restore existing variables from a checkpoint. + + Args: + save_path: The checkpoint prefix or directory to read from. + map_func: The function to use when mapping from variable names to + checkpoint names. + user_map_func: The original map_func passed by the user, for error + checking. + Returns: + A dictionary mapping from checkpoint names to variable objects which have + been restored (for bookkeeping to avoid deferred restorations on these + variables). + Raises: + ValueError: If there is a name collision. + """ + existing_variables_by_checkpoint_name = {} + for variable in self.variables: + checkpoint_name = map_func(variable._shared_name) + if existing_variables_by_checkpoint_name.setdefault( + checkpoint_name, variable) is not variable: + if user_map_func is None: + raise ValueError(_default_naming_conflict_error_message( + mapped_name=checkpoint_name, + first_variable=existing_variables_by_checkpoint_name[ + checkpoint_name], + second_variable=variable, + network_name=self.name, + network_scope_name=self.scope_name)) + else: + raise ValueError(_restore_custom_map_func_error_message( + mapped_name=checkpoint_name, + first_variable=existing_variables_by_checkpoint_name[ + checkpoint_name], + second_variable=variable, + network_name=self.name, + network_scope_name=self.scope_name)) + if existing_variables_by_checkpoint_name: + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + saver_lib.Saver(var_list=existing_variables_by_checkpoint_name).restore( + sess=sess, save_path=save_path) + return existing_variables_by_checkpoint_name + + def _set_restore_on_create(self, save_path, map_func, user_map_func, + existing_variables_by_checkpoint_name): + """If necessary, request deferred restorations of variables.""" + checkpoint_reader = checkpoint_utils.load_checkpoint(save_path) + checkpointed_variables_to_restore = {} + for checkpoint_name, _ in checkpoint_utils.list_variables(save_path): + if checkpoint_name in existing_variables_by_checkpoint_name: + # This variable was already created and restored. + continue + # Save the variable for later restoration in a custom getter. + checkpointed_variables_to_restore[checkpoint_name] = ( + checkpoint_reader.get_tensor(checkpoint_name)) + # Only set a deferred restoration if there are checkpoint variables which + # have not been assigned to existing variables. Note that this loses out on + # some opportunity for error checking, but avoids creating + # _DeferredRestoration objects once a Network has been built (so that + # restoring in a loop does not take increasing amounts of memory). + if checkpointed_variables_to_restore: + if context.in_eager_mode(): + sess = None + else: + sess = ops.get_default_session() + # We need a name for error messages. If we haven't been added to another + # Network yet, we're top-level. + self._finalize_name(False) + self._set_scope() + # Save a record of this restoration for use in the custom getter. + deferred_restoration = _DeferredRestoration( + map_func=map_func, + map_func_is_user=(user_map_func is not None), + checkpointed_variables_to_restore=checkpointed_variables_to_restore, + restored_variables={}, + session=sess, + network_name=self.name, + network_scope_name=self.scope_name) + self._deferred_restorations.append(deferred_restoration) + # Add the deferred registration to non-Network children, and request that + # Networks propagate the request to their children. + self._add_deferred_restoration(deferred_restoration) + + def _add_deferred_restoration(self, deferred_restoration): + """Add a deferred restoration to this Network and all children. + + Restorations which are requested later have higher priority, and the highest + priority matching restoration is applied to a variable when it is created. + + Args: + deferred_restoration: A _DeferredRestoration object. + """ + # Networks don't create variables at the moment, so this append isn't + # strictly necessary. We could get by with only adding deferred restorations + # to non-Network Layers. + self._set_scope() + # We use set_custom_getter because it avoids recursively calling up the + # variable_scope tree. We've done the tree traversal ourselves and have + # added the request to each Layer which needs it. + self._scope.set_custom_getter(self._custom_getter) + self._deferred_restorations.append(deferred_restoration) + for layer in self.layers: + if isinstance(layer, Network): + # For Networks, request that they propagate this deferred restoration + # to all of their children recursively. + layer._add_deferred_restoration(deferred_restoration) + else: + # For non-Network Layers, make sure they have a deferred restoration + # queue and a custom getter, then add our request to it. + if not hasattr(layer, "_custom_getter"): + assert not hasattr(layer, "_deferred_restorations") + layer._custom_getter, layer._deferred_restorations = ( + _make_custom_getter_for_deferred_restorations()) + self._set_scope_for_nonnetwork_sublayer(layer) + layer._scope.set_custom_getter(layer._custom_getter) + layer._deferred_restorations.append(deferred_restoration) + + def restore(self, save_path, map_func=None): + """Restore the Network from a checkpoint. + + If variables have already been created (typically when some or all of the + `Network` is built), they are assigned values from the checkpoint + immediately, overwriting any existing values (in graph mode the default + session is used for the assignments). + + If there are checkpoint entries which do not correspond to any existing + variables in the `Network`, these values are saved for deferred restoration; + their initial values will be the checkpointed values once they are + created. Requests for multiple deferred restorations behave the same way as + immediate restorations, in that later requests will take priority over + earlier requests relevant to the same variable. + + If this `Network` shares `Layer`s with another network, those `Layer`s will + also have their variables restored from the checkpoint. + + Args: + save_path: The return value of `Network.save`, or a directory to search + for a checkpoint. + map_func: A function mapping fully qualified variable names + (e.g. 'my_network_1/dense_1/kernel') to names in the checkpoint. By + default (if `map_func=None`), the variable prefix for the network being + restored (`Network.scope_name + '/'`, e.g. 'my_network_1/') is stripped + and all other variable names (shared with other Networks) are left + unchanged. Note that this is the _same_ map_func as `Network.save`, not + an inverse mapping. + """ + self._finalize_name(parent_network=False) + self._set_scope() # scope_name should be available to map_funcs + if os.path.isdir(save_path): + # If we don't have a name yet, set no parent. + save_path = os.path.join(save_path, self.name) + user_map_func = map_func + if map_func is None: + map_func = self._strip_variable_prefix + # Step one is to restore any existing variables from the checkpoint. + existing_variables_by_checkpoint_name = self._restore_existing_variables( + save_path=save_path, + map_func=map_func, + user_map_func=user_map_func) + # Step two is to set a custom getter which restores variables on creation, + # for those variables which have not been added to sub-Layers yet. + self._set_restore_on_create( + save_path=save_path, + map_func=map_func, + user_map_func=user_map_func, + existing_variables_by_checkpoint_name=( + existing_variables_by_checkpoint_name)) + + # TODO(josh11b): Support other Layer methods needed for graph mode, such as for + # losses and updates + + +class Sequential(Network): + """Represents a linear sequence of Layers or functions. + + The output of each layer/function is provided as the input to the next. + The inputs passed to `__call__` are passed to the inputs of the first + Layer, and it returns the outputs of the last Layer. + + Args: + layers_funcs: An optional sequence where each element is either a + tf.layers.Layer object or a callable. + name: An optional string name to use for this Network. + """ + + def __init__(self, layers_funcs=None, name=None): + super(Sequential, self).__init__(name=name) + self._layers_funcs = [] + if layers_funcs: + for l in layers_funcs: + self.add(l) + + def add(self, layer_func): + if isinstance(layer_func, base.Layer): + args = estimator_util.fn_args(layer_func.call) + self.track_layer(layer_func) + elif callable(layer_func): + args = estimator_util.fn_args(layer_func) + else: + raise TypeError( + "Sequential.add() takes only tf.layers.Layer objects or callables; " + "not '%s' of type '%s'." % (layer_func, type(layer_func))) + self._layers_funcs.append((("training" in args), layer_func)) + + def call(self, inputs, training=None): + """Call each Layer in the order they were added.""" + # TODO(josh11b): Support "mode" and maybe other arguments + if training is None: + for _, l in self._layers_funcs: + inputs = l(inputs) + else: + for has_training_arg, l in self._layers_funcs: + if has_training_arg: + inputs = l(inputs, training) + else: + inputs = l(inputs) + return inputs diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c621f527c28306131bdba56d8427eaa787ba150b --- /dev/null +++ b/tensorflow/contrib/eager/python/network_test.py @@ -0,0 +1,1075 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gc + +from tensorflow.contrib.eager.python import network +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import errors_impl +from tensorflow.python.framework import test_util +from tensorflow.python.layers import core +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import training_util + + +# pylint: disable=not-callable +class MyNetwork(network.Network): + + def __init__(self, name=None): + super(MyNetwork, self).__init__(name=name) + self.l1 = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.l1(x) + + +class NetworkTest(test.TestCase): + + def _save_modify_load_network_built(self, net, global_step=None): + checkpoint_directory = self.get_temp_dir() + checkpoint_path = net.save( + save_path=checkpoint_directory, global_step=global_step) + input_value = constant_op.constant([[42.0]]) + original_output = self.evaluate(net(input_value)) + for var in net.variables: + self.evaluate(var.assign(var + 1.)) + self.assertGreater( + self.evaluate(net(input_value)), + original_output) + # Either the returned explicit checkpoint path or the directory should work. + net.restore(save_path=checkpoint_directory) + self.assertAllEqual( + original_output, + self.evaluate(net(input_value))) + for var in net.variables: + self.evaluate(var.assign(var + 2.)) + net.restore(save_path=checkpoint_path) + self.assertAllEqual( + original_output, + self.evaluate(net(input_value))) + + @test_util.run_in_graph_and_eager_modes() + def testTrainableAttribute(self): + net = network.Network() + self.assertTrue(net.trainable) + with self.assertRaises(AttributeError): + net.trainable = False + self.assertTrue(net.trainable) + + @test_util.run_in_graph_and_eager_modes() + def testNetworkCall(self): + net = MyNetwork(name="abcd") + net(constant_op.constant([[2.0]])) # Force variables to be created. + self.assertEqual(1, len(net.trainable_variables)) + self.evaluate(net.trainable_variables[0].assign([[17.0]])) + # TODO(josh11b): Support passing Python values to networks. + result = net(constant_op.constant([[2.0]])) + self.assertEqual(34.0, self.evaluate(result)) + + @test_util.run_in_graph_and_eager_modes() + def testNetworkSaveRestoreAlreadyBuilt(self): + net = MyNetwork(name="abcd") + with self.assertRaisesRegexp( + ValueError, "Attempt to save the Network before it was first called"): + net.save(self.get_temp_dir()) + net(constant_op.constant([[2.0]])) + self.evaluate(net.trainable_variables[0].assign([[17.0]])) + self._save_modify_load_network_built(net, global_step=None) + self._save_modify_load_network_built(net, global_step=10) + + @test_util.run_in_graph_and_eager_modes() + def testSaveRestoreDefaultGlobalStep(self): + net = MyNetwork(name="abcd") + net(constant_op.constant([[2.0]])) + self.evaluate(net.variables[0].assign([[3.]])) + default_global_step = training_util.get_or_create_global_step() + self.evaluate(default_global_step.assign(4242)) + save_path = net.save(self.get_temp_dir()) + self.assertIn("abcd-4242", save_path) + + @test_util.run_in_graph_and_eager_modes() + def testNetworkSaveAndRestoreIntoUnbuilt(self): + save_dir = self.get_temp_dir() + net1 = MyNetwork() + test_input = constant_op.constant([[2.0]]) + net1(test_input) + self.evaluate(net1.trainable_variables[0].assign([[17.0]])) + save_path = net1.save(save_dir) + # With a pre-build restore we should have the same value. + net2 = MyNetwork() + net2.restore(save_path) + self.assertAllEqual(self.evaluate(net1(test_input)), + self.evaluate(net2(test_input))) + self.assertIsNot(net1.variables[0], net2.variables[0]) + self.assertAllEqual(self.evaluate(net1.variables[0]), + self.evaluate(net2.variables[0])) + + @test_util.run_in_graph_and_eager_modes() + def testLoadIntoUnbuiltSharedLayer(self): + + class Owner(network.Network): + + def __init__(self, name=None): + super(Owner, self).__init__(name=name) + self.first = self.track_layer(core.Dense( + 1, name="first_layer", use_bias=False)) + + def call(self, x): + return self.first(x) + + first_owner = Owner() + + class User(network.Network): + + def __init__(self, use_layer, name=None): + super(User, self).__init__(name=name) + self.first = self.track_layer(use_layer) + self.second = self.track_layer(core.Dense( + 1, name="second_layer", use_bias=False)) + + def call(self, x): + return self.second(self.first(x)) + + class LikeUserButNotSharing(network.Network): + + def __init__(self, name=None): + super(LikeUserButNotSharing, self).__init__(name=name) + self.first = self.track_layer(core.Dense( + 1, name="first_layer", use_bias=False)) + self.second = self.track_layer(core.Dense( + 1, name="second_layer", use_bias=False)) + + def call(self, x): + return self.second(self.first(x)) + + checkpoint_creator = LikeUserButNotSharing(name="checkpoint_creator") + one = constant_op.constant([[1.0]]) + checkpoint_creator(one) + self.assertEqual(2, len(checkpoint_creator.variables)) + self.evaluate(checkpoint_creator.variables[0].assign([[5.]])) + self.evaluate(checkpoint_creator.variables[1].assign([[6.]])) + # Re-map the variable names so that with default restore mapping we'll + # attempt to restore into the unbuilt Layer. + name_mapping = { + "checkpoint_creator/first_layer/kernel": "owner_1/first_layer/kernel", + "checkpoint_creator/second_layer/kernel": "second_layer/kernel", + } + save_path = checkpoint_creator.save( + self.get_temp_dir(), + map_func=lambda full_name: name_mapping[full_name]) + load_into = User(use_layer=first_owner.first) + load_into.restore(save_path) + self.assertEqual(0, len(first_owner.variables)) + self.assertAllEqual(self.evaluate(checkpoint_creator(one)), + self.evaluate(load_into(one))) + self.assertEqual(1, len(first_owner.variables)) + self.assertAllEqual([[5.]], self.evaluate(load_into.variables[0])) + self.assertAllEqual([[6.]], self.evaluate(load_into.variables[1])) + first_owner(one) + self.assertAllEqual([[5.]], self.evaluate(first_owner.variables[0])) + + # Try again with a garbage collected parent. + first_owner = Owner() + load_into = User(use_layer=first_owner.first) + del first_owner + gc.collect() + def _restore_map_func(original_name): + if original_name.startswith("owner_1"): + return original_name.replace("owner_1", "owner_2") + else: + return "user_2/" + original_name + with self.assertRaisesRegexp(ValueError, "garbage collected"): + load_into.restore(save_path, map_func=_restore_map_func) + + @test_util.run_in_graph_and_eager_modes() + def testRestoreIntoSubNetwork(self): + + class Parent(network.Network): + + def __init__(self, name=None): + super(Parent, self).__init__(name=name) + self.first = self.track_layer(MyNetwork()) + self.second = self.track_layer(MyNetwork()) + + def call(self, x): + return self.first(self.second(x)) + + one = constant_op.constant([[3.]]) + whole_model_saver = Parent() + whole_model_saver(one) + self.evaluate(whole_model_saver.variables[0].assign([[15.]])) + self.evaluate(whole_model_saver.variables[1].assign([[16.]])) + whole_model_checkpoint = whole_model_saver.save(self.get_temp_dir()) + + save_from = MyNetwork() + save_from(one) + self.evaluate(save_from.variables[0].assign([[5.]])) + checkpoint = save_from.save(self.get_temp_dir()) + save_into_parent = Parent() + save_into_parent.restore(whole_model_checkpoint) + save_into_parent.first.restore(checkpoint) + save_into_parent.first.restore(checkpoint) # deferred loading multiple + # times is fine + save_into_parent(one) # deferred loading + self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[0])) + self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1])) + + # Try again with the opposite ordering, and we should get different results + # (deferred restoration should happen the same way non-deferred happens, + # with later restorations overwriting older ones). + save_into_parent = Parent() + save_into_parent.first.restore(checkpoint) # deferred loading multiple + # times is fine + save_into_parent.restore(whole_model_checkpoint) + save_into_parent(one) # deferred loading + # We've overwritten the sub-Network restore. + self.assertAllEqual([[15.]], self.evaluate(save_into_parent.variables[0])) + self.assertAllEqual([[16.]], self.evaluate(save_into_parent.variables[1])) + + self.evaluate(save_into_parent.variables[0].assign([[3.]])) + self.evaluate(save_into_parent.variables[1].assign([[4.]])) + save_into_parent.second.restore(checkpoint) + self.assertAllEqual([[5.]], self.evaluate(save_into_parent.variables[1])) + with self.assertRaisesRegexp(errors_impl.NotFoundError, + "not found in checkpoint"): + # The checkpoint is incompatible. + save_into_parent.restore(checkpoint) + + @test_util.run_in_graph_and_eager_modes() + def testCustomMapCollisionErrors(self): + + class Parent(network.Network): + + def __init__(self, name=None): + super(Parent, self).__init__(name=name) + self.first = self.track_layer(MyNetwork()) + self.second = self.track_layer(MyNetwork()) + + def call(self, x): + return self.first(self.second(x)) + + make_checkpoint = Parent() + one = constant_op.constant([[1.]]) + make_checkpoint(one) + self.evaluate(make_checkpoint.variables[0].assign([[2.]])) + self.evaluate(make_checkpoint.variables[1].assign([[3.]])) + with self.assertRaisesRegexp( + ValueError, + "The map_func passed to Network.save for the Network 'parent_1' " + "resulted in two variables named 'foo'"): + make_checkpoint.save(self.get_temp_dir(), map_func=lambda n: "foo") + checkpoint = make_checkpoint.first.save( + self.get_temp_dir(), map_func=lambda n: "foo") + loader = Parent() + loader.restore(checkpoint, map_func=lambda n: "foo") + with self.assertRaisesRegexp( + ValueError, + ("The map_func passed to Network.restore for the Network" + " 'parent_2' resulted in two variables named 'foo'")): + loader(one) + loader = Parent() + loader(one) + with self.assertRaisesRegexp( + ValueError, + ("The map_func passed to Network.restore for the Network" + " 'parent_3' resulted in two variables named 'foo'")): + loader.restore(checkpoint, map_func=lambda n: "foo") + + @test_util.run_in_graph_and_eager_modes() + def testDefaultMapCollisionErrors(self): + + one = constant_op.constant([[1.]]) + first = core.Dense(1, name="dense_1", use_bias=False) + first(one) + + class Parent(network.Network): + + def __init__(self, name=None): + super(Parent, self).__init__(name=name) + self.first = self.track_layer(first) + self.second = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.first(self.second(x)) + + make_checkpoint = Parent() + one = constant_op.constant([[1.]]) + make_checkpoint(one) + self.evaluate(make_checkpoint.variables[0].assign([[2.]])) + self.evaluate(make_checkpoint.variables[1].assign([[3.]])) + with self.assertRaisesRegexp( + ValueError, + ("The default checkpoint variable name mapping strategy for Network " + "'parent_1' resulted in a naming conflict.")): + make_checkpoint.save(self.get_temp_dir()) + + class Compatible(network.Network): + + def __init__(self, name=None): + super(Compatible, self).__init__(name=name) + self.first = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.first(x) + + successful_checkpoint = Compatible() + successful_checkpoint(one) + self.evaluate(successful_checkpoint.variables[0].assign([[-1.]])) + checkpoint_path = successful_checkpoint.save(self.get_temp_dir()) + load_checkpoint = Parent() + load_checkpoint(one) + with self.assertRaisesRegexp( + ValueError, + ("The default checkpoint variable name mapping strategy for Network " + "'parent_2' resulted in a naming conflict.")): + load_checkpoint.restore(checkpoint_path) + + def testNoReferenceCyclesAfterCall(self): + + class ChildNetwork(network.Network): + + def __init__(self, name=None): + super(ChildNetwork, self).__init__(name=name) + + def call(self, x): + return x * 2. + + class ParentNetwork(network.Network): + + def __init__(self, name=None): + super(ParentNetwork, self).__init__(name=name) + self.l1 = self.track_layer(ChildNetwork()) + + def call(self, x): + return self.l1(x) + + one = constant_op.constant([[1.0]]) + gc.disable() + gc.collect() + previous_gc_debug_flags = gc.get_debug() + gc.set_debug(gc.DEBUG_SAVEALL) + preexisting = len(gc.garbage) + net = ParentNetwork() + net(one) + del net + gc.collect() + # There should be no additional garbage requiring collection. + self.assertEqual(preexisting, len(gc.garbage)) + gc.set_debug(previous_gc_debug_flags) + gc.enable() + + @test_util.run_in_graph_and_eager_modes() + def testAnonymousNoNameInitially(self): + net = MyNetwork() + with self.assertRaisesRegexp(ValueError, "does not yet have a final name"): + net.name # pylint: disable=pointless-statement + + @test_util.run_in_graph_and_eager_modes() + def testExplicitHasNameInitially(self): + net = MyNetwork(name="abcd") + self.assertEqual("abcd", net.name) + + @test_util.run_in_graph_and_eager_modes() + def testUsingResourceVariables(self): + net = MyNetwork() + net(constant_op.constant([[0.]])) + self.assertIsInstance(net.trainable_weights[0], + resource_variable_ops.ResourceVariable) + + @test_util.run_in_graph_and_eager_modes() + def testDuplicateNameError(self): + one = constant_op.constant([[1.]]) + net = MyNetwork(name="foo") + net(one) + with self.assertRaisesRegexp( + ValueError, "named 'foo' already exists"): + net1 = MyNetwork(name="foo") + net1(one) + + @test_util.run_in_graph_and_eager_modes() + def testWrappingInVariableScope(self): + with variable_scope.variable_scope("outside_scope"): + net = MyNetwork() + one = constant_op.constant([[1.]]) + with self.assertRaisesRegexp( + ValueError, + ("Creating Networks inside named variable_scopes is currently not " + "supported")): + net(one) + # Alternatively, we could re-name the Network to match the variable_scope: + # self.assertEqual("outside_scope/my_network_1", net.name) + # self.assertStartsWith( + # expected_start="outside_scope/my_network_1/dense/", + # actual=net.trainable_weights[0].name) + + @test_util.run_in_graph_and_eager_modes() + def testLayerNamesRespected(self): + class ParentNetwork(network.Network): + + def __init__(self): + super(ParentNetwork, self).__init__() + self.first = self.track_layer( + core.Dense(1, use_bias=False, name="explicit_name")) + + def call(self, x): + return self.first(x) + + one = constant_op.constant([[1.]]) + net = ParentNetwork() + net(one) + self.assertStartsWith(expected_start="parent_network_1/explicit_name/", + actual=net.trainable_weights[0].name) + self.assertEqual("explicit_name", net.first.name) + + @test_util.run_in_graph_and_eager_modes() + def testWrappingInAnonymousVariableScope(self): + # Named outside variable_scopes are not supported at the moment. However, + # blank-named top level variable scopes do not change variable names, and so + # can be used to set the properties of Network variables. + was_called = [False] + def _custom_getter(getter, *args, **kwargs): + was_called[0] = True + return getter(*args, **kwargs) + with variable_scope.variable_scope("", custom_getter=_custom_getter): + net = MyNetwork() + one = constant_op.constant([[1.]]) + net(one) + self.assertTrue(was_called[0]) + + @test_util.run_in_graph_and_eager_modes() + def testReasonableSlashError(self): + with self.assertRaisesRegexp( + ValueError, "not allowed in Network names"): + MyNetwork(name="slash/slash") + + @test_util.run_in_graph_and_eager_modes() + def testNoVariableScopeNames(self): + with self.assertRaisesRegexp( + ValueError, "VariableScopes are not valid Network names"): + with variable_scope.variable_scope("some_scope") as vs: + MyNetwork(name=vs) + + @test_util.run_in_graph_and_eager_modes() + def testVariableScopeNameCollision(self): + with variable_scope.variable_scope("abcd"): + pass + with self.assertRaisesRegexp( + ValueError, "or a variable_scope was created with this name"): + net = MyNetwork(name="abcd") + one = constant_op.constant([[1.]]) + net(one) + + @test_util.run_in_graph_and_eager_modes() + def testNetworkVariablesDoNotInterfere(self): + core.Dense(1, use_bias=True) # Should not interfere with naming. + net1 = MyNetwork() + net2 = MyNetwork() + one = constant_op.constant([[1.]]) + net1(one) + net2(one) + # Layer names typically are globally unique rather than being unique within + # the scope of their first use. However, within a Network they must be named + # locally so that previous Layer consutrciton does not interfere with + # variable naming (e.g. add a Layer construction before the Network, + # suddenly your previously saved checkpoint is incompatible). + self.assertEqual("dense_1", net1.l1.name) + self.assertEqual("dense_1", net2.l1.name) + self.evaluate(net1.trainable_weights[0].assign([[1.]])) + self.evaluate(net2.trainable_weights[0].assign([[2.]])) + self.assertEqual(2., self.evaluate(net2.trainable_weights[0])) + self.assertEqual(1., self.evaluate(net1.trainable_weights[0])) + self.assertStartsWith(expected_start="my_network_1/dense_1/", + actual=net1.trainable_weights[0].name) + self.assertStartsWith(expected_start="my_network_2/dense_1/", + actual=net2.trainable_weights[0].name) + + @test_util.run_in_graph_and_eager_modes() + def testNestableAnonymous(self): + + # The case where no explicit names are specified. We make up unique names, + # and these should match the variable names. + class ParentNetwork(network.Network): + + def __init__(self): + super(ParentNetwork, self).__init__() + self.first = self.track_layer(MyNetwork()) + self.second = self.track_layer(MyNetwork()) + + def call(self, x): + return self.second(self.first(x)) + + one = constant_op.constant([[1.]]) + net = ParentNetwork() + net(one) + self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", + actual=net.trainable_weights[0].name) + self.assertStartsWith(expected_start="parent_network_1/my_network_1/dense", + actual=net.first.trainable_weights[0].name) + self.assertStartsWith(expected_start="parent_network_1/my_network_2/dense", + actual=net.trainable_weights[1].name) + self.assertStartsWith(expected_start="parent_network_1/my_network_2/dense", + actual=net.second.trainable_weights[0].name) + self.assertEqual("parent_network_1", net.name) + self.assertEqual("my_network_1", net.first.name) + self.assertEqual("my_network_2", net.second.name) + + net2 = ParentNetwork() + net2(one) + self.assertStartsWith(expected_start="parent_network_2/my_network_1/dense", + actual=net2.trainable_weights[0].name) + self.assertStartsWith(expected_start="parent_network_2/my_network_1/dense", + actual=net2.first.trainable_weights[0].name) + self.assertStartsWith(expected_start="parent_network_2/my_network_2/dense", + actual=net2.trainable_weights[1].name) + self.assertStartsWith(expected_start="parent_network_2/my_network_2/dense", + actual=net2.second.trainable_weights[0].name) + self.assertEqual("parent_network_2", net2.name) + self.assertEqual("my_network_1", net2.first.name) + self.assertEqual("my_network_2", net2.second.name) + + @test_util.run_in_graph_and_eager_modes() + def testNestableExplicit(self): + + # We have explicit network names and everything is globally unique. + class ParentNetwork(network.Network): + + def __init__(self): + super(ParentNetwork, self).__init__(name="unique_parent_name") + self.first = self.track_layer( + MyNetwork(name="first_unique_child_name")) + self.second = self.track_layer( + MyNetwork(name="second_unique_child_name")) + + def call(self, x): + return self.second(self.first(x)) + + one = constant_op.constant([[1.]]) + net = ParentNetwork() + net(one) + self.assertStartsWith( + expected_start="unique_parent_name/first_unique_child_name/dense", + actual=net.trainable_weights[0].name) + self.assertStartsWith( + expected_start="unique_parent_name/second_unique_child_name/dense", + actual=net.trainable_weights[1].name) + self.assertEqual("unique_parent_name", net.name) + self.assertEqual("first_unique_child_name", net.first.name) + self.assertEqual("second_unique_child_name", net.second.name) + + @test_util.run_in_graph_and_eager_modes() + def testLayerNetworkNameInteractions(self): + + # Same base name as core.Dense; Networks and non-Network Layers with the + # same base name should use the same numbering system. + class Dense(network.Network): + + def __init__(self): + super(Dense, self).__init__() + self.first = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.first(x) + + class MixedLayerNetwork(network.Network): + + def __init__(self): + super(MixedLayerNetwork, self).__init__() + self.first = self.track_layer(core.Dense(1, use_bias=False)) + self.second = self.track_layer(core.Dense(1, use_bias=False)) + self.third = self.track_layer(Dense()) + self.fourth = self.track_layer(core.Dense(1, use_bias=False)) + self.fifth = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.fifth(self.fourth(self.third(self.second(self.first(x))))) + + one = constant_op.constant([[1.]]) + net = MixedLayerNetwork() + net(one) + self.assertEqual("dense_1", net.first.name) + self.assertEqual("dense_2", net.second.name) + self.assertEqual("dense_3", net.third.name) + self.assertEqual("dense_4", net.fourth.name) + self.assertEqual("dense_5", net.fifth.name) + # Note that this is _not_ the default naming behavior for Layers. Layers + # which are added to Networks follow Network variable naming conventions + # (i.e. variable names = network name unless variable sharing). Nested + # Layers revert to Layer behavior. + self.assertStartsWith(expected_start="mixed_layer_network_1/dense_1/", + actual=net.trainable_weights[0].name) + self.assertStartsWith(expected_start="mixed_layer_network_1/dense_2/", + actual=net.trainable_weights[1].name) + self.assertStartsWith(expected_start="mixed_layer_network_1/dense_3/", + actual=net.trainable_weights[2].name) + self.assertStartsWith(expected_start="mixed_layer_network_1/dense_4/", + actual=net.trainable_weights[3].name) + self.assertStartsWith(expected_start="mixed_layer_network_1/dense_5/", + actual=net.trainable_weights[4].name) + self.assertEqual("mixed_layer_network_1", net.name) + + @test_util.run_in_graph_and_eager_modes() + def testNestableExplicitCollisions(self): + + # We have explicit network names and they are unique within the layer + # they're added to. + class ParentNetwork(network.Network): + + def __init__(self): + super(ParentNetwork, self).__init__(name="nonunique_name") + self.first = self.track_layer( + MyNetwork(name="nonunique_name")) + self.second = self.track_layer( + MyNetwork(name="second_unique_child_name")) + + def call(self, x): + return self.second(self.first(x)) + + one = constant_op.constant([[1.]]) + net = ParentNetwork() + net(one) + self.assertStartsWith( + expected_start="nonunique_name/nonunique_name/dense", + actual=net.trainable_weights[0].name) + self.assertStartsWith( + expected_start="nonunique_name/second_unique_child_name/dense", + actual=net.trainable_weights[1].name) + self.assertEqual("nonunique_name", net.name) + self.assertEqual("nonunique_name", net.first.name) + self.assertEqual("second_unique_child_name", net.second.name) + + @test_util.run_in_graph_and_eager_modes() + def testNestableExplicitWithAnonymousParent(self): + + # A parent network is instantiated multiple times with explicitly named + # children. We shouldn't throw any name errors. + class ParentNetwork(network.Network): + + def __init__(self): + super(ParentNetwork, self).__init__() + self.first = self.track_layer( + MyNetwork(name="first_unique_child_name")) + self.second = self.track_layer( + MyNetwork(name="second_unique_child_name")) + + def call(self, x): + return self.second(self.first(x)) + + one = constant_op.constant([[1.]]) + net = ParentNetwork() + net(one) + self.assertStartsWith( + expected_start="parent_network_1/first_unique_child_name/dense_1/", + actual=net.trainable_weights[0].name) + self.assertStartsWith( + expected_start="parent_network_1/second_unique_child_name/dense_1/", + actual=net.trainable_weights[1].name) + self.assertEqual("parent_network_1", net.name) + self.assertEqual("first_unique_child_name", net.first.name) + self.assertEqual("second_unique_child_name", net.second.name) + + net2 = ParentNetwork() + net2(one) + self.assertStartsWith( + expected_start="parent_network_2/first_unique_child_name/dense", + actual=net2.trainable_weights[0].name) + self.assertStartsWith( + expected_start="parent_network_2/second_unique_child_name/dense", + actual=net2.trainable_weights[1].name) + self.assertEqual("parent_network_2", net2.name) + self.assertEqual("first_unique_child_name", net2.first.name) + self.assertEqual("second_unique_child_name", net2.second.name) + + @test_util.run_in_graph_and_eager_modes() + def testNestableExplicitSameLayerCollisions(self): + + # We have explicit network names and they are _not_ unique within the layer + # they're added to. Error. + class ParentNetwork(network.Network): + + def __init__(self): + super(ParentNetwork, self).__init__(name="unique_parent_name") + self.first = self.track_layer(MyNetwork(name="nonunique_name")) + self.second = self.track_layer(MyNetwork(name="nonunique_name")) + + def call(self, x): + return self.second(self.first(x)) + + with self.assertRaisesRegexp(ValueError, "nonunique_name"): + ParentNetwork() + + @test_util.run_in_graph_and_eager_modes() + def testAnonymousVariableSharing(self): + + # Two "owned" Networks + class FirstParentNetwork(network.Network): + + def __init__(self): + super(FirstParentNetwork, self).__init__() + self.first = self.track_layer(MyNetwork()) + self.second = self.track_layer(MyNetwork()) + + def call(self, x): + return self.second(self.first(x)) + + one = constant_op.constant([[1.]]) + net = FirstParentNetwork() + net(one) + + # One Network shared with FirstParentNetwork, one owned Network. Same name, + # but this is OK because only one is owned. This name collision is + # avoidable; we could have looked at the base_name of the non-owned Network + # and incremented our naming based on that. + class SecondParentNetwork(network.Network): + + def __init__(self): + super(SecondParentNetwork, self).__init__() + self.first = self.track_layer(net.first) + self.second = self.track_layer(MyNetwork()) + + def call(self, x): + return self.second(self.first(x)) + + net2 = SecondParentNetwork() + net2(one) + + self.assertStartsWith( + expected_start="first_parent_network_1/my_network_1/dense_1/", + actual=net2.trainable_weights[0].name) + self.assertStartsWith( + expected_start="second_parent_network_1/my_network_1/dense_1/", + actual=net2.trainable_weights[1].name) + self.assertEqual("second_parent_network_1", net2.name) + self.assertTrue(net2.first is net.first) + self.assertEqual("my_network_1", net2.first.name) + self.assertEqual("my_network_1", net2.second.name) + + # No name collision; the owned Network is added first and has a different + # name than the shared Network. + class ThirdParentNetwork(network.Network): + + def __init__(self): + super(ThirdParentNetwork, self).__init__() + self.first = self.track_layer(MyNetwork()) + self.second = self.track_layer(net.second) + + def call(self, x): + return self.second(self.first(x)) + + net3 = ThirdParentNetwork() + net3(one) + + self.assertStartsWith( + expected_start="third_parent_network_1/my_network_1/dense", + actual=net3.trainable_weights[0].name) + self.assertStartsWith( + expected_start="first_parent_network_1/my_network_2/dense", + actual=net3.trainable_weights[1].name) + self.assertEqual("third_parent_network_1", net3.name) + self.assertTrue(net3.second is net.second) + self.assertEqual("my_network_1", net3.first.name) + self.assertEqual("my_network_2", net3.second.name) + + # "Unavoidable" same-name Layer. The owned name is added first (fixed), then + # a shared Network is added with the same name. + class FourthParentNetwork(network.Network): + + def __init__(self): + super(FourthParentNetwork, self).__init__() + self.first = self.track_layer(MyNetwork()) + self.second = self.track_layer(net.first) + + def call(self, x): + return self.second(self.first(x)) + + net4 = FourthParentNetwork() + net4(one) + + self.assertStartsWith( + expected_start="fourth_parent_network_1/my_network_1/dense_1/", + actual=net4.trainable_weights[0].name) + self.assertStartsWith( + expected_start="first_parent_network_1/my_network_1/dense_1/", + actual=net4.trainable_weights[1].name) + self.assertEqual("fourth_parent_network_1", net4.name) + self.assertTrue(net4.second is net.first) + self.assertEqual("my_network_1", net4.first.name) + self.assertEqual("my_network_1", net4.second.name) + + @test_util.run_in_graph_and_eager_modes() + def testRecursiveLayerRenaming(self): + core.Dense(1) # Under default Layer naming, would change subsequent names. + + class NetworkWithLayerChildren(network.Network): + + def __init__(self): + super(NetworkWithLayerChildren, self).__init__() + self.first = self.track_layer(core.Dense(1, use_bias=False)) + self.second = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.second(self.first(x)) + + class ParentNetwork(network.Network): + + def __init__(self): + super(ParentNetwork, self).__init__() + self.first = self.track_layer(NetworkWithLayerChildren()) + self.second = self.track_layer(NetworkWithLayerChildren()) + + def call(self, x): + return self.second(self.first(x)) + + net = ParentNetwork() + one = constant_op.constant([[1.]]) + net(one) + + self.assertStartsWith( + expected_start=("parent_network_1/network_with_layer_children_1/" + "dense_1/"), + actual=net.trainable_weights[0].name) + self.assertStartsWith( + expected_start=("parent_network_1/network_with_layer_children_1/" + "dense_2/"), + actual=net.trainable_weights[1].name) + self.assertStartsWith( + expected_start=("parent_network_1/network_with_layer_children_2/" + "dense_1/"), + actual=net.trainable_weights[2].name) + self.assertStartsWith( + expected_start=("parent_network_1/network_with_layer_children_2/" + "dense_2/"), + actual=net.trainable_weights[3].name) + self.assertEqual("parent_network_1", net.name) + self.assertEqual("network_with_layer_children_1", net.first.name) + self.assertEqual("network_with_layer_children_2", net.second.name) + self.assertEqual("dense_1", net.first.first.name) + self.assertEqual("dense_2", net.first.second.name) + self.assertEqual("dense_1", net.second.first.name) + self.assertEqual("dense_2", net.second.second.name) + + @test_util.run_in_graph_and_eager_modes() + def testCallInDifferentOrderThanConstruct(self): + shared_network = MyNetwork() + + class FirstNetwork(network.Network): + + def __init__(self): + super(FirstNetwork, self).__init__() + self.first = self.track_layer(shared_network) + self.second = self.track_layer(MyNetwork()) + + def call(self, x): + return self.second(self.first(x)) + + class SecondNetwork(network.Network): + + def __init__(self): + super(SecondNetwork, self).__init__() + self.first = self.track_layer(shared_network) + self.second = self.track_layer(MyNetwork()) + + def call(self, x): + return self.second(self.first(x)) + + net1 = FirstNetwork() + net2 = SecondNetwork() + + one = constant_op.constant([[1.]]) + net2(one) + net1(one) + + self.assertStartsWith( + expected_start="first_network_1/my_network_1/dense_1/", + actual=net1.trainable_weights[0].name) + self.assertStartsWith( + expected_start="first_network_1/my_network_2/dense_1/", + actual=net1.trainable_weights[1].name) + self.assertStartsWith( + expected_start="first_network_1/my_network_1/dense_1/", + actual=net2.trainable_weights[0].name) + self.assertStartsWith( + expected_start="second_network_1/my_network_1/dense_1/", + actual=net2.trainable_weights[1].name) + self.assertTrue(net1.trainable_weights[0] is net2.trainable_weights[0]) + self.assertEqual("first_network_1", net1.name) + self.assertEqual("my_network_1", net1.first.name) + self.assertEqual("my_network_2", net1.second.name) + self.assertTrue(net2.first is net1.first) + self.assertEqual("my_network_1", net2.second.name) + + @test_util.run_in_graph_and_eager_modes() + def testLayerCallInDifferentOrderThanConstruct(self): + # Same idea as testCallInDifferentOrderThanConstruct, but this time with a + # non-Network Layer shared between two Networks rather than a + # Network. Naming should follow the same rules. + shared_layer = core.Dense(1, use_bias=False) + + class FirstNetwork(network.Network): + + def __init__(self): + super(FirstNetwork, self).__init__() + self.first = self.track_layer(shared_layer) + self.second = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.second(self.first(x)) + + class SecondNetwork(network.Network): + + def __init__(self): + super(SecondNetwork, self).__init__() + self.first = self.track_layer(shared_layer) + self.second = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.second(self.first(x)) + + net1 = FirstNetwork() + net2 = SecondNetwork() + + one = constant_op.constant([[1.]]) + net2(one) + net1(one) + + self.assertStartsWith( + expected_start="first_network_1/dense_1/", + actual=net1.trainable_weights[0].name) + self.assertStartsWith( + expected_start="first_network_1/dense_2/", + actual=net1.trainable_weights[1].name) + self.assertStartsWith( + expected_start="first_network_1/dense_1/", + actual=net2.trainable_weights[0].name) + self.assertStartsWith( + expected_start="second_network_1/dense_1/", + actual=net2.trainable_weights[1].name) + self.assertTrue(net1.trainable_weights[0] is net2.trainable_weights[0]) + self.assertEqual("first_network_1", net1.name) + self.assertEqual("dense_1", net1.first.name) + self.assertEqual("dense_2", net1.second.name) + self.assertTrue(net2.first is net1.first) + self.assertEqual("dense_1", net2.second.name) + + @test_util.run_in_graph_and_eager_modes() + def testLayerAlreadyBuilt(self): + one = constant_op.constant([[1.]]) + core.Dense(1, use_bias=False) # pre-built layers use global naming + one = constant_op.constant([[1.]]) + core.Dense(1, use_bias=False)(one) + shared_layer = core.Dense(1, use_bias=False) + shared_layer(one) + + class FirstNetwork(network.Network): + + def __init__(self): + super(FirstNetwork, self).__init__() + self.first = self.track_layer(shared_layer) + self.second = self.track_layer(core.Dense(1, use_bias=False)) + + def call(self, x): + return self.second(self.first(x)) + + net = FirstNetwork() + net(one) + + self.assertStartsWith( + expected_start="dense_1/", # Pre-built layers have variable names which + # do not match their layer names. + actual=net.trainable_weights[0].name) + self.assertStartsWith( + expected_start="first_network_1/dense_1/", + actual=net.trainable_weights[1].name) + self.assertTrue( + net.trainable_weights[0] is shared_layer.trainable_weights[0]) + self.assertEqual("first_network_1", net.name) + self.assertEqual("dense_3", net.first.name) + self.assertEqual("dense_1", net.second.name) + + +class SequentialTest(test.TestCase): + + def testTwoLayers(self): + # Create a sequential network with one layer. + net = network.Sequential([core.Dense(1, use_bias=False)]) + + # Set that layer's weights so it multiplies by 3 + l1 = net.get_layer(index=0) + net(constant_op.constant([[2.0]])) # Create l1's variables + self.assertEqual(1, len(l1.trainable_variables)) + l1.trainable_variables[0].assign([[3.0]]) + self.assertEqual(21.0, net(constant_op.constant([[7.0]])).numpy()) + + # Add a second layer to the network. + l2 = core.Dense(1, use_bias=False) + net.add(l2) + + # Set the second layer's weights so it multiplies by 11 + net(constant_op.constant([[2.0]])) # Create l2's variables + self.assertEqual(1, len(l2.trainable_variables)) + l2.trainable_variables[0].assign([[11.0]]) + self.assertEqual(231.0, net(constant_op.constant([[7.0]])).numpy()) + + def testFunctions(self): + # Create a sequential network with one function. + net = network.Sequential([nn_ops.relu]) + two = constant_op.constant(2.0) + self.assertEqual(2.0, net(two).numpy()) + self.assertEqual(0.0, net(-two).numpy()) + # Add a second function. + net.add(math_ops.negative) + self.assertEqual(-2.0, net(two).numpy()) + + def testTrainingLayer(self): + net = network.Sequential([core.Dropout(0.99999)]) + two = constant_op.constant(2.0) + self.assertEqual(2.0, net(two).numpy()) + self.assertEqual(2.0, net(two, training=False).numpy()) + for _ in range(20): + with_dropout = net(two, training=True).numpy() + self.assertIn(with_dropout, [0.0, 2.0]) + if with_dropout == 0.0: + return + # Should only fail spuriously 1 in 10^100 runs. + self.fail("Didn't see dropout happen after 20 tries.") + + def testTrainingFunction(self): + # Output depends on value of "training". + def add_training(input_value, training=None): + if training is None: + return input_value + elif training: + return input_value + 1 + return input_value - 1 + + # Passing a "training" argument to double would cause an error. + def double(input_value): + return 2 * input_value + + net = network.Sequential([add_training, double]) + two = constant_op.constant(2) + self.assertEqual(4, net(two).numpy()) + self.assertEqual(2, net(two, training=False).numpy()) + self.assertEqual(6, net(two, training=True).numpy()) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/saver.py b/tensorflow/contrib/eager/python/saver.py index 8edd4b816397cf4ba5f7d43b78f6e50ee6619da1..57b070ec6eeac00c77f199a846639d64c4957cd8 100644 --- a/tensorflow/contrib/eager/python/saver.py +++ b/tensorflow/contrib/eager/python/saver.py @@ -19,7 +19,9 @@ from __future__ import print_function import contextlib +from tensorflow.python.eager import context from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.training import checkpoint_utils from tensorflow.python.training import saver as _saver @@ -27,97 +29,153 @@ from tensorflow.python.training import saver as _saver def _init_from_checkpoint(self, *args, **kwargs): """Overrides default init by loading value from checkpoint.""" - self.old_init(*args, **kwargs) # pylint: disable=protected-access - if self._shared_name not in self.ckpt_var_cache: + self._old_init(*args, **kwargs) + ckpt_name = self._map_func(self._shared_name) + if ckpt_name not in self._ckpt_var_cache: raise errors.NotFoundError(None, None, - "%s not found in checkpoint" % self._shared_name) + "%s not found in checkpoint" % ckpt_name) - val = self.ckpt_var_cache[self._shared_name] + val = self._ckpt_var_cache.get(ckpt_name, None) if val is not None: - self.assign(self.ckpt_var_cache[self._shared_name]) + self.assign(val) # Avoid assigning for the second time. - self.ckpt_var_cache[self._shared_name] = None + self._ckpt_var_cache[ckpt_name] = None # pylint: enable=protected-access -class Saver(object): - """A simple tf.train.Saver adapter for eager mode. +@contextlib.contextmanager +def restore_variables_on_create(save_path, map_func=None): + """ContextManager that restores variables on creation. - save and restore API are similar to the tf.train.Saver, except that - session is not needed. + When save_path is None (e.g. No checkpoint), does nothing. + Otherwise, it preloads all values from checkpoint. When the + corresponding variable is first created, it assigns the checkpoint + value to the variable. - restore_on_create is eager mode's way to reload checkpoint value during - the execution. (unlike graph mode's reload before run). + ```python + with restore_variables_on_create( + tf.train.latest_checkpoint(checkpoint_dir)): + ``` Args: - var_list: See tf.train.Saver. Works the same for save/restore. Ignored - by restore_on_create. + save_path: The checkpoint file prefix. + map_func: A function that given the variable name as argument + and returns a variable name in checkpoint for restore. If + None, use the variable with the same name in checkpoint to restore. + It's an error that the mapped variable name doesn't exist in + checkpoint. + + Yields: + Nothing. + + Raises: + NotFoundError: If the variable is not found in checkpoint. + ValueError: If not used in eager mode or map_func is not callable. """ + if context.in_graph_mode(): + raise ValueError( + "Currently, restore_variables_on_create can only be used with " + "eager execution enabled.") + if save_path: + if map_func is None: + map_func_wrapper = lambda self, x: x + else: + if not callable(map_func): + raise ValueError("map_func must be callaled.") + map_func_wrapper = lambda self, x: map_func(x) + + ckpt_var_cache = dict() + reader = checkpoint_utils.load_checkpoint(save_path) + for k, _ in checkpoint_utils.list_variables(save_path): + ckpt_var_cache[k] = reader.get_tensor(k) + + old_init = getattr(resource_variable_ops.ResourceVariable, + "_init_from_args", None) + assert old_init, "ResourceVariable misses _init_from_args method." + setattr(resource_variable_ops.ResourceVariable, "_init_from_args", + _init_from_checkpoint) + setattr(resource_variable_ops.ResourceVariable, "_old_init", old_init) + setattr(resource_variable_ops.ResourceVariable, "_map_func", + map_func_wrapper) + setattr(resource_variable_ops.ResourceVariable, "_ckpt_var_cache", + ckpt_var_cache) + try: + yield + except Exception as e: + raise e + finally: + if save_path: + setattr(resource_variable_ops.ResourceVariable, "_init_from_args", + old_init) + setattr(resource_variable_ops.ResourceVariable, "_old_init", None) + setattr(resource_variable_ops.ResourceVariable, "_map_func", None) + setattr(resource_variable_ops.ResourceVariable, "_ckpt_var_cache", None) + + +class Saver(object): + """A tf.train.Saver adapter for use when eager execution is enabled. + """ + + def __init__(self, var_list): + """A tf.train.Saver adapter for use when eager execution is enabled. + + The API, and on-disk format, mimic tf.train.Saver except that no + Session is needed. + + Args: + var_list: The list of variables that will be saved and restored. Either a + list of `tfe.Variable` objects, or a dictionary mapping names to + `tfe.Variable` objects. - def __init__(self, var_list=None): + Raises: + RuntimeError: if invoked when eager execution has not been enabled. + """ + if context.in_graph_mode(): + raise RuntimeError("tfe.Saver can only be used when eager " + "execution is enabled. Use tf.train.Saver when " + "building graphs.") self._saver = _saver.Saver(var_list=var_list) - def save(self, save_path, global_step=None): + def save(self, file_prefix, global_step=None): """Saves variables. Args: - save_path: See save method in tf.train.Saver. - global_step: See save method in tf.train.Saver. + file_prefix: Path prefix of files created for the checkpoint. + global_step: If provided the global step number is appended to file_prefix + to create the checkpoint filename. The optional argument can be a + Tensor, a Variable, or an integer. Returns: - See save method in tf.train.Saver. + A string: prefix of filenames created for the checkpoint. This may be + an extension of file_prefix that is suitable to pass as an argument + to a subsequent call to `restore()`. """ - return self._saver.save(None, save_path, write_meta_graph=False, - global_step=global_step) + with ops.device("/device:CPU:0"): + return self._saver.save( + None, file_prefix, write_meta_graph=False, global_step=global_step) - def restore(self, save_path): + def restore(self, file_prefix): """Restores previously saved variables. Args: - save_path: See restore method in tf.train.Saver. + file_prefix: Path prefix where parameters were previously saved. + Typically obtained from a previous `save()` call, or from + @{tf.train.latest_checkpoint}. """ - self._saver.restore(None, save_path) + with ops.device("/device:CPU:0"): + self._saver.restore(None, file_prefix) - @contextlib.contextmanager - def maybe_restore_on_create(self, save_path): - """ContextManager that restores variables on creation. - When save_path is None (e.g. No checkpoint), does nothing. - Otherwise, it preloads all values from checkpoint. When the - corresponding variable is first created, it assigns the checkpoint - value to the variable. +def get_optimizer_variables(optimizer): + """Returns a list of variables for the given `tf.train.Optimizer`. - Args: - save_path: Same as save_path of retore. If None, do not restore. + Equivalent to `optimizer.variables()`. - Yields: - Nothing. - - Raises: - NotFoundError: If the variable is not found in checkpoint. - """ - if save_path: - ckpt_var_cache = dict() - reader = checkpoint_utils.load_checkpoint(save_path) - for k, _ in checkpoint_utils.list_variables(save_path): - ckpt_var_cache[k] = reader.get_tensor(k) - - old_init = getattr( - resource_variable_ops.ResourceVariable, "_init_from_args", None) - assert old_init, "ResourceVariable misses _init_from_args method." - setattr(resource_variable_ops.ResourceVariable, "_init_from_args", - _init_from_checkpoint) - setattr(resource_variable_ops.ResourceVariable, "old_init", old_init) - setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache", - ckpt_var_cache) - try: - yield - except Exception as e: - raise e - finally: - if save_path: - setattr(resource_variable_ops.ResourceVariable, "_init_from_args", - old_init) - setattr(resource_variable_ops.ResourceVariable, "old_init", None) - setattr(resource_variable_ops.ResourceVariable, "ckpt_var_cache", None) + Args: + optimizer: An instance of `tf.train.Optimizer` which has created variables + (typically after a call to `Optimizer.minimize`). + Returns: + A list of variables which have been created by the `Optimizer`. + """ + return optimizer.variables() diff --git a/tensorflow/contrib/eager/python/saver_test.py b/tensorflow/contrib/eager/python/saver_test.py index 9c8294e3bacc2c6fe2689d81cdf6efa7f8ddbc4b..abc7e3690c76c4446bce6b945325f1ca15ef1c8b 100644 --- a/tensorflow/contrib/eager/python/saver_test.py +++ b/tensorflow/contrib/eager/python/saver_test.py @@ -21,17 +21,28 @@ import os from tensorflow.contrib.eager.python import saver as _saver from tensorflow.python.eager import context +from tensorflow.python.eager import graph_callable +from tensorflow.python.eager import test +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops from tensorflow.python.ops import resource_variable_ops -from tensorflow.python.platform import test +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import adam +from tensorflow.python.training import gradient_descent +from tensorflow.python.training import momentum +from tensorflow.python.training import rmsprop class SaverTest(test.TestCase): + def _dev(self): + return '/device:GPU:0' if context.num_gpus() else '/device:CPU:0' + def testBasics(self): - with context.eager_mode(): + with ops.device(self._dev()): v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') def model(): return array_ops.constant(2.0) * v1 @@ -47,8 +58,76 @@ class SaverTest(test.TestCase): saver.restore(ckpt_prefix) self.assertEqual(v1.read_value().numpy(), 1.0) + def testSameNameNoClobbering(self): + with ops.device(self._dev()): + # Note that this test purposefully uses Graphs rather than + # IsolateTest. Users are more likely to accidentally create the same + # variable name this way. + first_graph = ops.Graph() + with first_graph.as_default(): + v1_first_graph = resource_variable_ops.ResourceVariable(1.0, name='v1') + with ops.Graph().as_default(): + v1_second_graph = resource_variable_ops.ResourceVariable(2.0, name='v1') + saver = _saver.Saver([v1_first_graph, v1_second_graph]) + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + with self.assertRaisesRegexp(ValueError, 'v1'): + saver.save(ckpt_prefix) + + def testDifferentGraphError(self): + with ops.device(self._dev()): + with ops.Graph().as_default(): + v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') + with ops.Graph().as_default(): + saver = _saver.Saver([v1]) + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + with self.assertRaisesRegexp(ValueError, 'Graph'): + saver.save(ckpt_prefix) + + def testSameObjectOK(self): + with ops.device(self._dev()): + v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') + # While different objects with the same shared_name are not good, passing + # in the same object multiple times is fine. + saver = _saver.Saver([v1, v1]) + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + saver.save(ckpt_prefix) + + def testSaveByDict(self): + with ops.device(self._dev()): + v1 = resource_variable_ops.ResourceVariable(1.0, name='v1') + v2 = resource_variable_ops.ResourceVariable(1.0, name='v2') + def model(): + return array_ops.constant(2.0) * v1 * v2 + + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + + # Save the variables under different names. + _ = model() + saver = _saver.Saver({'ckpt/v1': v1, 'ckpt/v2': v2}) + saver.save(ckpt_prefix) + v1.assign(2.0) + v2.assign(2.0) + self.assertEqual(v1.read_value().numpy(), 2.0) + self.assertEqual(v2.read_value().numpy(), 2.0) + # Can still restore it. + saver.restore(ckpt_prefix) + self.assertEqual(v1.read_value().numpy(), 1.0) + self.assertEqual(v1.read_value().numpy(), 1.0) + # However, cannot restore it with default name. + with self.assertRaisesOpError('not found in checkpoint'): + saver = _saver.Saver([v1, v2]).restore(ckpt_prefix) + + # Can specify which variable in ckpt to restore to which variable. + def map_func(x): + return {'v3': 'ckpt/v1', 'v4': 'ckpt/v2'}.get(x, x) + with _saver.restore_variables_on_create(ckpt_prefix, map_func): + v3 = resource_variable_ops.ResourceVariable(2.0, name='v3') + v4 = resource_variable_ops.ResourceVariable(2.0, name='v4') + self.assertEqual(v3.read_value().numpy(), 1.0) + self.assertEqual(v4.read_value().numpy(), 1.0) + def testRestoreOnCreate(self): - with context.eager_mode(): + with ops.device(self._dev()): def model(init_val): v1 = resource_variable_ops.ResourceVariable(init_val, name='v1') return array_ops.constant(1.0) * v1, v1 @@ -60,16 +139,13 @@ class SaverTest(test.TestCase): with ops.Graph().as_default(): saver = _saver.Saver([v1]) - with saver.maybe_restore_on_create(ckpt_prefix): + with _saver.restore_variables_on_create(ckpt_prefix): # Value is from checkpoint, but not from argument. ret, _ = model(2.0) self.assertEqual(ret.numpy(), 1.0) - # Create it a second time won't re-assign the checkpoint value. - v1_2 = resource_variable_ops.ResourceVariable(3.0, name='v1') - self.assertEqual(v1_2.read_value().numpy(), 3.0) def testRestoreNotFound(self): - with context.eager_mode(): + with ops.device(self._dev()): def model(v): return array_ops.constant(1.0) * v @@ -81,9 +157,93 @@ class SaverTest(test.TestCase): with self.assertRaisesRegexp(errors.NotFoundError, 'v2 not found in checkpoint'): - with saver.maybe_restore_on_create(ckpt_prefix): + with _saver.restore_variables_on_create(ckpt_prefix): _ = model(resource_variable_ops.ResourceVariable(1.0, name='v2')) + def testSaveRestoreGraphCallable(self): + with ops.device(self._dev()): + @graph_callable.graph_callable( + [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) + def model(x): + v = variable_scope.get_variable( + 'v', initializer=init_ops.zeros_initializer(), shape=()) + return v + x + + # Default 2 + 0 = 2 + self.assertEqual( + 2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + + # Save the variable value 0. + ckpt_prefix = os.path.join(test.get_temp_dir(), 'ckpt') + _saver.Saver(model.variables).save(ckpt_prefix) + + # update variable to 1, so that 2 + 1 = 3 + model.variables[0].assign(1.) + self.assertEqual( + 3, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + + # load the variable value 0, so that 2 + 0 = 2 + _saver.Saver(model.variables).restore(ckpt_prefix) + self.assertEqual( + 2, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + + # update checkpoint variable to 1 and memory value to 2. + model.variables[0].assign(1.) + _saver.Saver(model.variables).save(ckpt_prefix) + model.variables[0].assign(2.) + self.assertEqual( + 4, model(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + + # reset the graph and reload on create, so that 1 + 2 = 3 + with ops.Graph().as_default(): + with _saver.restore_variables_on_create(ckpt_prefix): + @graph_callable.graph_callable( + [graph_callable.ShapeAndDtype(shape=(), dtype=dtypes.float32)]) + def model2(x): + v = variable_scope.get_variable( + 'v', initializer=init_ops.zeros_initializer(), shape=()) + return v + x + + self.assertEqual( + 3, model2(array_ops.constant(2, dtype=dtypes.float32)).numpy()) + + +class GetOptimizerTests(test.TestCase): + + def _optimizer_test_template(self, optimizer): + """Checks save and restore. Returns the optimizer variables.""" + v = resource_variable_ops.ResourceVariable([[2., 3.]], name='v') + loss_fn = lambda: v[0, 0] ** 2 + v[0, 1] ** 2 + optimizer.minimize(loss_fn) + optimizer_variables = _saver.get_optimizer_variables(optimizer) + saver = _saver.Saver(optimizer_variables + [v]) + checkpoint_path = saver.save(self.get_temp_dir()) + optimizer.minimize(loss_fn) + after_first_minimize = v.numpy() + # After we restore, the next step should be exactly the same as the one we + # just did. + saver.restore(checkpoint_path) + optimizer.minimize(loss_fn) + self.assertAllEqual(after_first_minimize, v.numpy()) + return optimizer_variables + + def testAdam(self): + optimizer = adam.AdamOptimizer(0.1) + self._optimizer_test_template(optimizer) + + def testGradientDescent(self): + optimizer = gradient_descent.GradientDescentOptimizer(0.02) + self.assertEqual(0, len(self._optimizer_test_template(optimizer))) + + def testMomentum(self): + optimizer = momentum.MomentumOptimizer( + learning_rate=0.03, + momentum=0.5) + self._optimizer_test_template(optimizer) + + def testRMSProp(self): + optimizer = rmsprop.RMSPropOptimizer(0.01) + self._optimizer_test_template(optimizer) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/eager/python/summary_writer.py b/tensorflow/contrib/eager/python/summary_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..5d8c41b545b3c9fd03af85f302ba05a394f085a4 --- /dev/null +++ b/tensorflow/contrib/eager/python/summary_writer.py @@ -0,0 +1,242 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TensorBoard Summary Writer for TensorFlow Eager Execution.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import uuid + +from tensorflow.contrib.summary import gen_summary_ops +from tensorflow.python.eager import context +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import summary_op_util +from tensorflow.python.ops import variable_scope + + +def _maybe_cpu(v): + if isinstance(v, (ops.EagerTensor, ops.Tensor)): + return v.cpu() + else: + return v + + +def _summary_writer_function(name, tensor, function, family=None): + def record(): + with summary_op_util.summary_scope( + name, family, values=[tensor]) as (tag, scope): + function(tag, scope) + return True + return record + + +class SummaryWriter(object): + """Writes summaries for TensorBoard, compatible with eager execution. + + This class is the supported way of writing TensorBoard summaries under + eager execution. + """ + + _CPU_DEVICE = "cpu:0" + + def __init__(self, + logdir, + max_queue=10, + flush_secs=120, + filename_suffix=""): + """Summary writer for TensorBoard, compatible with eager execution. + + If necessary, multiple instances of `SummaryWriter` can be created, with + distinct `logdir`s and `name`s. Each `SummaryWriter` instance will retain + its independent `global_step` counter and data writing destination. + + Example: + ```python + writer = tfe.SummaryWriter("my_model") + + # ... Code that sets up the model and data batches ... + + for _ in xrange(train_iters): + loss = model.train_batch(batch) + writer.scalar("loss", loss) + writer.step() + ``` + + Args: + logdir: Directory in which summary files will be written. + max_queue: Number of summary items to buffer before flushing to + filesystem. If 0, summaries will be flushed immediately. + flush_secs: Number of secondsbetween forced commits to disk. + filename_suffix: Suffix of the event protobuf files in which the summary + data are stored. + + Raises: + ValueError: If this constructor is called not under eager execution. + """ + # TODO(apassos, ashankar): Make this class and the underlying + # contrib.summary_ops compatible with graph model and remove this check. + if not context.in_eager_mode(): + raise ValueError( + "Use of SummaryWriter is currently supported only with eager " + "execution enabled. File an issue at " + "https://github.com/tensorflow/tensorflow/issues/new to express " + "interest in fixing this.") + + # TODO(cais): Consider adding name keyword argument, which if None or empty, + # will register the global global_step that training_util.get_global_step() + # can find. + with context.device(self._CPU_DEVICE): + self._name = uuid.uuid4().hex + self._global_step = 0 + self._global_step_tensor = variable_scope.get_variable( + "global_step/summary_writer/" + self._name, + shape=[], dtype=dtypes.int64, + initializer=init_ops.zeros_initializer()) + self._global_step_dirty = False + self._resource = gen_summary_ops.summary_writer(shared_name=self._name) + gen_summary_ops.create_summary_file_writer( + self._resource, logdir, max_queue, flush_secs, filename_suffix) + # Delete the resource when this object is deleted + self._resource_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._resource, handle_device=self._CPU_DEVICE) + + def step(self): + """Increment the global step counter of this SummaryWriter instance.""" + self._global_step += 1 + self._global_step_dirty = True + + @property + def global_step(self): + """Obtain the current global_step value of this SummaryWriter instance. + + Returns: + An `int` representing the current value of the global_step of this + `SummaryWriter` instance. + """ + return self._global_step + + def _update_global_step_tensor(self): + with context.device(self._CPU_DEVICE): + if self._global_step_dirty: + self._global_step_dirty = False + return state_ops.assign(self._global_step_tensor, self._global_step) + else: + return self._global_step_tensor + + def generic(self, name, tensor, metadata, family=None): + """Write a generic-type summary. + + Args: + name: A name for the generated node. Will also serve as the series name in + TensorBoard. + tensor: A `Tensor` or compatible value type containing the value of the + summary. + metadata: Metadata about the summary. + family: Optional; if provided, used as the prefix of the summary tag name, + which controls the tab name used for display on Tensorboard. + """ + with context.device(self._CPU_DEVICE): + with summary_op_util.summary_scope( + name, family, values=[tensor]) as (tag, scope): + gen_summary_ops.write_summary( + self._resource, + self._update_global_step_tensor(), + _maybe_cpu(tensor), + tag, + _maybe_cpu(metadata), + name=scope) + + def scalar(self, name, tensor, family=None): + """Write a scalar summary. + + Args: + name: A name for the generated node. Will also serve as the series name in + TensorBoard. + tensor: A real numeric `Tensor` or compatible value type containing a + single value. + family: Optional; if provided, used as the prefix of the summary tag name, + which controls the tab name used for display on Tensorboard. + + Returns: + A summary writer function for scalars. + """ + with context.device(self._CPU_DEVICE): + with summary_op_util.summary_scope( + name, family, values=[tensor]) as (tag, scope): + gen_summary_ops.write_scalar_summary( + self._resource, self._update_global_step_tensor(), + tag, _maybe_cpu(tensor), name=scope) + + def histogram(self, name, tensor, family=None): + """Write a histogram summary. + + Args: + name: A name for the generated node. Will also serve as a series name in + TensorBoard. + tensor: A real numeric `Tensor` or compatible value type. Any shape. + Values to use to build the histogram. + family: Optional; if provided, used as the prefix of the summary tag name, + which controls the tab name used for display on Tensorboard. + """ + with context.device(self._CPU_DEVICE): + with summary_op_util.summary_scope( + name, family, values=[tensor]) as (tag, scope): + gen_summary_ops.write_histogram_summary( + self._resource, self._update_global_step_tensor(), + tag, _maybe_cpu(tensor), name=scope) + + def image(self, name, tensor, bad_color=None, max_images=3, family=None): + """Write an image summary.""" + with context.device(self._CPU_DEVICE): + if bad_color is None: + bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) + with summary_op_util.summary_scope( + name, family, values=[tensor]) as (tag, scope): + gen_summary_ops.write_image_summary( + self._resource, self._update_global_step_tensor(), + tag, _maybe_cpu(tensor), bad_color_, max_images, + name=scope) + + def audio(self, name, tensor, sample_rate, max_outputs, family=None): + """Write an audio summary. + + Args: + name: A name for the generated node. Will also serve as a series name in + TensorBoard. + tensor: A 3-D `float32` `Tensor` of shape `[batch_size, frames, channels]` + or a 2-D `float32` `Tensor` of shape `[batch_size, frames]`, or + compatible value type. + sample_rate: A Scalar `float32` `Tensor` indicating the sample rate of the + signal in hertz. + max_outputs: Max number of batch elements to generate audio for. + family: Optional; if provided, used as the prefix of the summary tag name, + which controls the tab name used for display on Tensorboard. + """ + with context.device(self._CPU_DEVICE): + with summary_op_util.summary_scope( + name, family, values=[tensor]) as (tag, scope): + gen_summary_ops.write_audio_summary( + self._resource, self._update_global_step_tensor(), + tag, + _maybe_cpu(tensor), + sample_rate=_maybe_cpu(sample_rate), + max_outputs=max_outputs, + name=scope) diff --git a/tensorflow/contrib/eager/python/summary_writer_test.py b/tensorflow/contrib/eager/python/summary_writer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5ebb36d04fcba8f4558fa1c09716314af42f559f --- /dev/null +++ b/tensorflow/contrib/eager/python/summary_writer_test.py @@ -0,0 +1,150 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for eager execution SummaryWriter.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import shutil +import tempfile + +import numpy as np + +from tensorflow.contrib.eager.python import summary_writer +from tensorflow.core.util import event_pb2 +from tensorflow.python.eager import context +from tensorflow.python.eager import test +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.lib.io import tf_record +from tensorflow.python.platform import gfile + + +class SummaryWriterTest(test.TestCase): + + def setUp(self): + super(SummaryWriterTest, self).setUp() + self._test_device = "gpu:0" if context.num_gpus() else "cpu:0" + self._tmp_logdir = tempfile.mkdtemp() + with context.device(self._test_device): + # Use max_queue=0 so that summaries are immediately flushed to filesystem, + # making testing easier. + self._writer = summary_writer.SummaryWriter(self._tmp_logdir, max_queue=0) + + def tearDown(self): + if os.path.isdir(self._tmp_logdir): + shutil.rmtree(self._tmp_logdir) + super(SummaryWriterTest, self).tearDown() + + def _readLastEvent(self, logdir=None): + if not logdir: + logdir = self._tmp_logdir + files = [f for f in gfile.ListDirectory(logdir) + if not gfile.IsDirectory(os.path.join(logdir, f))] + file_path = os.path.join(logdir, files[0]) + records = list(tf_record.tf_record_iterator(file_path)) + event = event_pb2.Event() + event.ParseFromString(records[-1]) + return event + + def testGlobalStep(self): + with context.device(self._test_device): + orig_step = self._writer.global_step + self._writer.step() + self.assertEqual(orig_step + 1, self._writer.global_step) + self.assertEqual(orig_step + 1, self._writer.global_step) + self._writer.step() + self._writer.step() + self.assertEqual(orig_step + 3, self._writer.global_step) + + def testGenericSummary(self): + with context.device(self._test_device): + x = constant_op.constant(1337.0) + with context.device("cpu:0"): + metadata = constant_op.constant("foo") + self._writer.generic("x", x, metadata) + event = self._readLastEvent() + self.assertEqual("x", event.summary.value[0].tag) + + def testScalarSummary(self): + with context.device(self._test_device): + x = constant_op.constant(1337.0) + self._writer.scalar("x", x) + event = self._readLastEvent() + self.assertTrue("x", event.summary.value[0].tag) + self.assertEqual(1337.0, event.summary.value[0].simple_value) + + def testHistogramSummary(self): + with context.device(self._test_device): + y = constant_op.constant([1.0, 3.0, 3.0, 7.0]) + self._writer.histogram("y", y) + event = self._readLastEvent() + self.assertEqual("y", event.summary.value[0].tag) + self.assertTrue(event.summary.value[0].histo) + + def testImageSummary(self): + with context.device(self._test_device): + a = constant_op.constant([[10.0, 20.0], [-20.0, -10.0]]) + self._writer.histogram("image1", a) + event = self._readLastEvent() + self.assertEqual("image1", event.summary.value[0].tag) + self.assertTrue(event.summary.value[0].image) + + def testAudioSummary(self): + with context.device(self._test_device): + w = constant_op.constant(np.random.rand(3, 10, 2), dtype=dtypes.float32) + fs = constant_op.constant(44100.0, dtype=dtypes.float32) + max_outputs = 1 + self._writer.audio("audio1", w, fs, max_outputs) + event = self._readLastEvent() + self.assertTrue(event.summary.value[0].audio) + + def testTwoSummaryWritersGlobalStepsWorkWithoutCrosstalk(self): + tmp_logdir2 = os.path.join(self._tmp_logdir, "_writer2_") + writer2 = summary_writer.SummaryWriter(tmp_logdir2, max_queue=0) + + self.assertEqual(0, writer2.global_step) + self._writer.step() + self.assertEqual(0, writer2.global_step) + writer2.step() + writer2.step() + writer2.step() + self.assertEqual(3, writer2.global_step) + + x = constant_op.constant(1337.0) + writer_orig_step = self._writer.global_step + self._writer.step() + self._writer.scalar("x", x) + + event = self._readLastEvent() + self.assertEqual(writer_orig_step + 1, event.step) + + writer2.scalar("x", x) + event = self._readLastEvent(tmp_logdir2) + self.assertEqual(3, event.step) + + self._writer.step() + self._writer.scalar("x", x) + + event = self._readLastEvent() + self.assertEqual(writer_orig_step + 2, event.step) + + +# TODO(cais): Add performance benchmark for SummaryWriter. + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/eager/python/tfe.py b/tensorflow/contrib/eager/python/tfe.py index 6bf9aa1a3b5638937a68de416b0009f493bc253c..b6c687c82946ec62ccb90165791587dc335f13c7 100644 --- a/tensorflow/contrib/eager/python/tfe.py +++ b/tensorflow/contrib/eager/python/tfe.py @@ -18,7 +18,8 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice. To use, at program startup, call `tfe.enable_eager_execution()`. -@@device +@@metrics + @@list_devices @@num_gpus @@ -27,6 +28,7 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@implicit_value_and_gradients @@gradients_function @@value_and_gradients_function +@@GradientTape @@enable_tracing @@flush_trace @@ -44,8 +46,22 @@ To use, at program startup, call `tfe.enable_eager_execution()`. @@seterr @@Iterator +@@Network @@Saver +@@restore_variables_on_create @@Variable +@@get_optimizer_variables +@@EagerVariableStore + +@@in_eager_mode +@@in_graph_mode + +@@IsolateTest +@@run_test_in_graph_and_eager_modes + +@@DEVICE_PLACEMENT_EXPLICIT +@@DEVICE_PLACEMENT_WARN +@@DEVICE_PLACEMENT_SILENT """ from __future__ import absolute_import @@ -55,30 +71,42 @@ from __future__ import print_function # pylint:disable=g-bad-import-order,g-import-not-at-top,unused-import # +from tensorflow.contrib.eager.python import metrics from tensorflow.contrib.eager.python.datasets import Iterator +from tensorflow.contrib.eager.python.network import Network +from tensorflow.contrib.eager.python.saver import get_optimizer_variables +from tensorflow.contrib.eager.python.saver import restore_variables_on_create from tensorflow.contrib.eager.python.saver import Saver -from tensorflow.python.util.all_util import remove_undocumented from tensorflow.python.eager import backprop -from tensorflow.python.eager.custom_gradient import custom_gradient from tensorflow.python.eager import function -from tensorflow.python.eager.context import device -from tensorflow.python.eager.context import enable_eager_execution +from tensorflow.python.eager.context import DEVICE_PLACEMENT_EXPLICIT +from tensorflow.python.eager.context import DEVICE_PLACEMENT_WARN +from tensorflow.python.eager.context import DEVICE_PLACEMENT_SILENT +from tensorflow.python.eager.context import in_eager_mode +from tensorflow.python.eager.context import in_graph_mode from tensorflow.python.eager.context import list_devices from tensorflow.python.eager.context import num_gpus -from tensorflow.python.eager.context import run from tensorflow.python.eager.core import enable_tracing +from tensorflow.python.eager.custom_gradient import custom_gradient from tensorflow.python.eager.execution_callbacks import add_execution_callback from tensorflow.python.eager.execution_callbacks import clear_execution_callbacks from tensorflow.python.eager.execution_callbacks import inf_callback from tensorflow.python.eager.execution_callbacks import inf_nan_callback from tensorflow.python.eager.execution_callbacks import nan_callback from tensorflow.python.eager.execution_callbacks import seterr +from tensorflow.python.framework.ops import enable_eager_execution +from tensorflow.python.framework.ops import eager_run as run +from tensorflow.python.framework.test_util import IsolateTest +from tensorflow.python.framework.test_util import run_in_graph_and_eager_modes as run_test_in_graph_and_eager_modes from tensorflow.python.ops.resource_variable_ops import ResourceVariable as Variable +from tensorflow.python.ops.variable_scope import EagerVariableStore +from tensorflow.python.util.all_util import remove_undocumented defun = function.defun implicit_gradients = backprop.implicit_grad implicit_value_and_gradients = backprop.implicit_val_and_grad gradients_function = backprop.gradients_function value_and_gradients_function = backprop.val_and_grad_function +GradientTape = backprop.GradientTape # pylint: disable=invalid-name remove_undocumented(__name__) diff --git a/tensorflow/contrib/eager/python/tfe_test.py b/tensorflow/contrib/eager/python/tfe_test.py index 1adce2048b27813cb20c0d4f95ce3b08350bc956..0dedb2fd7c0905801cd87c239ff2ee09eecb6080 100644 --- a/tensorflow/contrib/eager/python/tfe_test.py +++ b/tensorflow/contrib/eager/python/tfe_test.py @@ -17,13 +17,20 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import tempfile + from tensorflow.contrib.eager.python import tfe from tensorflow.python.framework import constant_op from tensorflow.python.framework import errors +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import numerics +from tensorflow.python.ops import variables from tensorflow.python.platform import test +from tensorflow.python.summary import summary +from tensorflow.python.summary.writer import writer class TFETest(test_util.TensorFlowTestCase): @@ -38,6 +45,11 @@ class TFETest(test_util.TensorFlowTestCase): r'indices = 7 is not in \[0, 3\)'): array_ops.gather([0, 1, 2], 7) + def testVariableError(self): + with self.assertRaisesRegexp( + RuntimeError, r'Variable not supported in Eager mode'): + variables.Variable(initial_value=1.0) + def testGradients(self): def square(x): @@ -66,8 +78,7 @@ class TFETest(test_util.TensorFlowTestCase): return y, grad_fn - # TODO(ashankar): This [0] should ideally not be needed. - grad = tfe.gradients_function(f, [0]) + grad = tfe.gradients_function(f) self.assertEquals([12], [x.numpy() for x in grad(3)]) def testGPU(self): @@ -75,17 +86,17 @@ class TFETest(test_util.TensorFlowTestCase): self.skipTest('No GPUs available') # tf.Tensor.as_gpu_device() moves a tensor to GPU. - x = constant_op.constant([[1., 2.], [3., 4.]]).as_gpu_tensor() - # Alternatively, tfe.device() as a context manager places tensors and + x = constant_op.constant([[1., 2.], [3., 4.]]).gpu() + # Alternatively, tf.device() as a context manager places tensors and # operations. - with tfe.device('gpu:0'): + with ops.device('gpu:0'): x += 1. # Without a device context, heuristics are used to place ops. # In this case, ops.reduce_mean runs on the GPU. reduction_indices = range(x.shape.ndims) m = math_ops.reduce_mean(x, reduction_indices) # m is on GPU, bring it back to CPU and compare. - self.assertEqual(3.5, m.as_cpu_tensor().numpy()) + self.assertEqual(3.5, m.cpu().numpy()) def testListDevices(self): # Expect at least one device. @@ -95,12 +106,33 @@ class TFETest(test_util.TensorFlowTestCase): devices = tfe.list_devices() self.assertEqual(len(devices) - 1, tfe.num_gpus()) - def testCallingEnableEagerExecutionMoreThanOnce(self): - # Note that eager.test.main() has already invoked enable_eager_exceution(). + def testAddCheckNumericsOpsRaisesError(self): + with self.assertRaisesRegexp( + RuntimeError, + r'add_check_numerics_ops\(\) is not compatible with eager execution'): + numerics.add_check_numerics_ops() + + def testClassicSummaryOpsErrorOut(self): + x = constant_op.constant(42) + x_summary = summary.scalar('x', x) + y = constant_op.constant([1, 3, 3, 7]) + y_summary = summary.histogram('hist', y) + + with self.assertRaisesRegexp( + RuntimeError, + r'Merging tf\.summary\.\* ops is not compatible with eager execution'): + summary.merge([x_summary, y_summary]) + + with self.assertRaisesRegexp( + RuntimeError, + r'Merging tf\.summary\.\* ops is not compatible with eager execution'): + summary.merge_all() + + def testClassicSummaryFileWriterErrorsOut(self): with self.assertRaisesRegexp( - ValueError, r'Do not call tfe\.%s more than once in the same process' % - tfe.enable_eager_execution.__name__): - tfe.enable_eager_execution() + RuntimeError, + r'tf\.summary\.FileWriter is not compatible with eager execution'): + writer.FileWriter(tempfile.mkdtemp()) if __name__ == '__main__': diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD index dbfd4655c2f370613d99436cbf3571833fded200..6eb2cfdaca7840c4a5dd8cffc9620aaf3f96a1de 100644 --- a/tensorflow/contrib/estimator/BUILD +++ b/tensorflow/contrib/estimator/BUILD @@ -7,6 +7,7 @@ package( licenses(["notice"]) # Apache 2.0 load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") filegroup( name = "all_files", @@ -30,6 +31,7 @@ py_library( ":head", ":logit_fns", ":multi_head", + ":replicate_model_fn", "//tensorflow/python:util", ], ) @@ -50,7 +52,10 @@ py_test( size = "small", srcs = ["python/estimator/dnn_test.py"], srcs_version = "PY2AND3", - tags = ["no_pip"], + tags = [ + "no_pip", + "notsan", + ], deps = [ ":dnn", ":head", @@ -76,17 +81,19 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/python:clip_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", - "//tensorflow/python:util", "//tensorflow/python/estimator", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:util", + "@six_archive//:six", ], ) py_test( name = "extenders_test", - size = "small", + size = "medium", srcs = ["python/estimator/extenders_test.py"], srcs_version = "PY2AND3", deps = [ @@ -96,10 +103,11 @@ py_test( "//tensorflow/python:constant_op", "//tensorflow/python:framework_ops", "//tensorflow/python:metrics", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//tensorflow/python:variables", + "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/estimator:linear", - "//tensorflow/python/estimator:run_config", "//tensorflow/python/feature_column", "//third_party/py/numpy", ], @@ -127,7 +135,9 @@ py_library( "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/estimator:util", "//tensorflow/python/ops/losses", + "//tensorflow/python/saved_model:signature_constants", ], ) @@ -142,9 +152,11 @@ py_test( "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:errors", "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:training", @@ -179,7 +191,8 @@ py_test( deps = [ ":logit_fns", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:session", "//tensorflow/python/estimator:model_fn", ], ) @@ -216,9 +229,69 @@ py_test( "//tensorflow/python:string_ops", "//tensorflow/python/estimator:metric_keys", "//tensorflow/python/estimator:model_fn", - "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/ops/losses", "//tensorflow/python/saved_model:signature_constants", "//third_party/py/numpy", "@six_archive//:six", ], ) + +py_library( + name = "replicate_model_fn", + srcs = [ + "python/estimator/replicate_model_fn.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:device", + "//tensorflow/python:device_lib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:util", + "@six_archive//:six", + ], +) + +cuda_py_test( + name = "replicate_model_fn_test", + size = "small", + srcs = ["python/estimator/replicate_model_fn_test.py"], + additional_deps = [ + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:dnn", + "//tensorflow/python/estimator:export_export", + "//tensorflow/python/estimator:export_output", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:numpy_io", + "//tensorflow/python/estimator:optimizers", + "//tensorflow/python/estimator:prediction_keys", + "//tensorflow/python/feature_column", + "//tensorflow/python/ops/losses", + "//tensorflow/python/saved_model:signature_constants", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", + "//tensorflow/python:platform", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + ":replicate_model_fn", + ], + tags = ["requires-gpu-sm35"], +) diff --git a/tensorflow/contrib/estimator/__init__.py b/tensorflow/contrib/estimator/__init__.py index cd8bdcc12b65ce94836bc449ac101c7233b6340d..cf727264cd5116915f6bd7f285e470cbc2e2742a 100644 --- a/tensorflow/contrib/estimator/__init__.py +++ b/tensorflow/contrib/estimator/__init__.py @@ -32,6 +32,7 @@ _allowed_symbols = [ 'add_metrics', 'binary_classification_head', 'clip_gradients_by_norm', + 'forward_features', 'multi_class_head', 'multi_head', 'multi_label_head', diff --git a/tensorflow/contrib/estimator/python/estimator/extenders.py b/tensorflow/contrib/estimator/python/estimator/extenders.py index e5304f1fae37aa2e4187115f1963b4822a06ba88..29c3c7358534f6e8ebbd31cbfcd7e34086d9b506 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders.py @@ -18,12 +18,16 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import six + from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import util as estimator_util +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib from tensorflow.python.ops import clip_ops from tensorflow.python.training import optimizer as optimizer_lib -from tensorflow.python.util import tf_inspect + _VALID_METRIC_FN_ARGS = set(['features', 'labels', 'predictions', 'config']) @@ -132,6 +136,111 @@ def clip_gradients_by_norm(optimizer, clip_norm): name='ClipByNorm' + optimizer.get_name()) +def forward_features(estimator, keys=None): + """Forward features to predictions dictionary. + + In some cases, user wants to see some of the features in estimators prediction + output. As an example, consider a batch prediction service: The service simply + runs inference on the users graph and returns the results. Keys are essential + because there is no order guarantee on the outputs so they need to be rejoined + to the inputs via keys or transclusion of the inputs in the outputs. + + Example: + + ```python + def input_fn(): + features, labels = ... + features['unique_example_id'] = ... + features, labels + + estimator = tf.estimator.LinearClassifier(...) + estimator = tf.contrib.estimator.forward_features( + estimator, 'unique_example_id') + estimator.train(...) + assert 'unique_example_id' in estimator.predict(...) + ``` + + Args: + estimator: A ${tf.estimator.Estimator} object. + keys: a `string` or a `list` of `string`. If it is `None`, all of the + `features` in `dict` is forwarded to the `predictions`. If it is a + `string`, only given key is forwarded. If it is a `list` of strings, all + the given `keys` are forwarded. + + Returns: + A new ${tf.estimator.Estimator} which forwards features to predictions. + + Raises: + ValueError: + * if `keys` is already part of `predictions`. We don't allow + override. + * if 'keys' does not exist in `features`. + * if feature key refers to a `SparseTensor`, since we don't support + `SparseTensor` in `predictions`. `SparseTensor` is common in `features`. + TypeError: if `keys` type is not one of `string` or list/tuple of `string`. + """ + + def verify_key_types(keys): # pylint: disable=missing-docstring + if keys is None: + return keys + if isinstance(keys, six.string_types): + return [keys] + if not isinstance(keys, (list, tuple)): + raise TypeError('keys should be either a string or a list of strings. ' + 'Given: {}'.format(type(keys))) + for key in keys: + if not isinstance(key, six.string_types): + raise TypeError('All items in the given keys list should be a string. ' + 'There exist an item with type: {}'.format(type(key))) + return keys + + def get_keys(features): + if keys is None: + return features.keys() + return keys + + def verify_keys_and_predictions(features, predictions): + if not isinstance(predictions, dict): + raise ValueError( + 'Predictions should be a dict to be able to forward features. ' + 'Given: {}'.format(type(predictions))) + for key in get_keys(features): + if key not in features: + raise ValueError( + 'keys should be exist in features. Key "{}" is not in features ' + 'dict. features dict has following keys: {}. Please check ' + 'arguments of forward_features.'.format(key, features.keys())) + if key in predictions: + raise ValueError( + 'Cannot forward feature key ({}). Since it does exist in ' + 'predictions. Existing prediction keys: {}. Please check arguments ' + 'of forward_features.'.format(key, predictions.keys())) + + keys = verify_key_types(keys) + + def new_model_fn(features, labels, mode, config): # pylint: disable=missing-docstring + spec = estimator.model_fn(features, labels, mode, config) + predictions = spec.predictions + if predictions is None: + return spec + verify_keys_and_predictions(features, predictions) + for key in get_keys(features): + feature = sparse_tensor_lib.convert_to_tensor_or_sparse_tensor( + features[key]) + if not isinstance(feature, ops.Tensor): + raise ValueError( + 'Forwarded feature ({}) should be a Tensor. Please use keys ' + 'argument of forward_features to filter unwanted features. Type of ' + 'features[{}] is {}.'.format(key, key, type(feature))) + predictions[key] = feature + return spec._replace(predictions=predictions) + + return estimator_lib.Estimator( + model_fn=new_model_fn, + model_dir=estimator.model_dir, + config=estimator.config) + + class _TransformGradients(optimizer_lib.Optimizer): """Add given gradient transformation to the optimizer.""" @@ -208,9 +317,6 @@ class _TransformGradients(optimizer_lib.Optimizer): def _verify_metric_fn_args(metric_fn): args = set(estimator_util.fn_args(metric_fn)) - if tf_inspect.ismethod(metric_fn): - if 'self' in args: - args.remove('self') invalid_args = list(args - _VALID_METRIC_FN_ARGS) if invalid_args: raise ValueError('metric_fn (%s) has following not expected args: %s' % diff --git a/tensorflow/contrib/estimator/python/estimator/extenders_test.py b/tensorflow/contrib/estimator/python/estimator/extenders_test.py index d58a0a12943e9afaa4f0e9089df4dfadd667b1ed..5f4a3cc902c9cc07c0688ad41dab7391a641c133 100644 --- a/tensorflow/contrib/estimator/python/estimator/extenders_test.py +++ b/tensorflow/contrib/estimator/python/estimator/extenders_test.py @@ -22,11 +22,12 @@ import numpy as np from tensorflow.contrib.data.python.ops import dataset_ops from tensorflow.contrib.estimator.python.estimator import extenders -from tensorflow.python.estimator import run_config +from tensorflow.python.estimator import estimator_lib from tensorflow.python.estimator.canned import linear from tensorflow.python.feature_column import feature_column as fc from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -82,7 +83,7 @@ class AddMetricsTest(test.TestCase): self.assertIn('x', features) self.assertIsNotNone(labels) self.assertIn('logistic', predictions) - self.assertTrue(isinstance(config, run_config.RunConfig)) + self.assertTrue(isinstance(config, estimator_lib.RunConfig)) return {} estimator = extenders.add_metrics(estimator, metric_fn) @@ -98,7 +99,7 @@ class AddMetricsTest(test.TestCase): self.assertIn('x', features) self.assertIsNotNone(labels) self.assertIn('logistic', predictions) - self.assertTrue(isinstance(config, run_config.RunConfig)) + self.assertTrue(isinstance(config, estimator_lib.RunConfig)) return {} estimator = extenders.add_metrics(estimator, metric_fn) @@ -159,5 +160,141 @@ class ClipGradientsByNormTest(test.TestCase): self.assertEqual('ClipByNormGradientDescent', optimizer.get_name()) +class ForwardFeaturesTest(test.TestCase): + """Tests forward_features.""" + + def test_forward_single_key(self): + + def input_fn(): + return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]] + + estimator = linear.LinearRegressor([fc.numeric_column('x')]) + estimator.train(input_fn=input_fn, steps=1) + + self.assertNotIn('id', next(estimator.predict(input_fn=input_fn))) + estimator = extenders.forward_features(estimator, 'id') + predictions = next(estimator.predict(input_fn=input_fn)) + self.assertIn('id', predictions) + self.assertEqual(101, predictions['id']) + + def test_forward_list(self): + + def input_fn(): + return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]] + + estimator = linear.LinearRegressor([fc.numeric_column('x')]) + estimator.train(input_fn=input_fn, steps=1) + + self.assertNotIn('id', next(estimator.predict(input_fn=input_fn))) + estimator = extenders.forward_features(estimator, ['x', 'id']) + predictions = next(estimator.predict(input_fn=input_fn)) + self.assertIn('id', predictions) + self.assertIn('x', predictions) + self.assertEqual(101, predictions['id']) + self.assertEqual(3., predictions['x']) + + def test_forward_all(self): + + def input_fn(): + return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]] + + estimator = linear.LinearRegressor([fc.numeric_column('x')]) + estimator.train(input_fn=input_fn, steps=1) + + self.assertNotIn('id', next(estimator.predict(input_fn=input_fn))) + self.assertNotIn('x', next(estimator.predict(input_fn=input_fn))) + estimator = extenders.forward_features(estimator) + predictions = next(estimator.predict(input_fn=input_fn)) + self.assertIn('id', predictions) + self.assertIn('x', predictions) + self.assertEqual(101, predictions['id']) + self.assertEqual(3., predictions['x']) + + def test_key_should_be_string(self): + estimator = linear.LinearRegressor([fc.numeric_column('x')]) + with self.assertRaisesRegexp(TypeError, 'keys should be either a string'): + extenders.forward_features(estimator, estimator) + + def test_key_should_be_list_of_string(self): + estimator = linear.LinearRegressor([fc.numeric_column('x')]) + with self.assertRaisesRegexp(TypeError, 'should be a string'): + extenders.forward_features(estimator, ['x', estimator]) + + def test_key_should_be_in_features(self): + + def input_fn(): + return {'x': [[3.], [5.]], 'id': [[101], [102]]}, [[1.], [2.]] + + estimator = linear.LinearRegressor([fc.numeric_column('x')]) + estimator.train(input_fn=input_fn, steps=1) + + estimator = extenders.forward_features(estimator, 'y') + with self.assertRaisesRegexp(ValueError, + 'keys should be exist in features'): + next(estimator.predict(input_fn=input_fn)) + + def test_forwarded_feature_should_not_be_a_sparse_tensor(self): + + def input_fn(): + return { + 'x': [[3.], [5.]], + 'id': + sparse_tensor.SparseTensor( + values=['1', '2'], + indices=[[0, 0], [1, 0]], + dense_shape=[2, 1]) + }, [[1.], [2.]] + + estimator = linear.LinearRegressor([fc.numeric_column('x')]) + estimator.train(input_fn=input_fn, steps=1) + + estimator = extenders.forward_features(estimator) + with self.assertRaisesRegexp(ValueError, + 'Forwarded feature.* should be a Tensor.'): + next(estimator.predict(input_fn=input_fn)) + + def test_predictions_should_be_dict(self): + + def input_fn(): + return {'x': [[3.], [5.]], 'id': [[101], [102]]} + + def model_fn(features, mode): + del features + global_step = training.get_global_step() + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant([5.]), + predictions=constant_op.constant([5.]), + train_op=global_step.assign_add(1)) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + estimator.train(input_fn=input_fn, steps=1) + + estimator = extenders.forward_features(estimator) + with self.assertRaisesRegexp(ValueError, 'Predictions should be a dict'): + next(estimator.predict(input_fn=input_fn)) + + def test_should_not_conflict_with_existing_predictions(self): + + def input_fn(): + return {'x': [[3.], [5.]], 'id': [[101], [102]]} + + def model_fn(features, mode): + del features + global_step = training.get_global_step() + return estimator_lib.EstimatorSpec( + mode, + loss=constant_op.constant([5.]), + predictions={'x': constant_op.constant([5.])}, + train_op=global_step.assign_add(1)) + + estimator = estimator_lib.Estimator(model_fn=model_fn) + estimator.train(input_fn=input_fn, steps=1) + + estimator = extenders.forward_features(estimator) + with self.assertRaisesRegexp(ValueError, 'Cannot forward feature key'): + next(estimator.predict(input_fn=input_fn)) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/head.py b/tensorflow/contrib/estimator/python/estimator/head.py index 9b14622ff6436efcf66dae311f773c8375b2cafa..7c992c99ed3fb05d5f2c306304b7084584c201e4 100644 --- a/tensorflow/contrib/estimator/python/estimator/head.py +++ b/tensorflow/contrib/estimator/python/estimator/head.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function from tensorflow.python.estimator import model_fn +from tensorflow.python.estimator import util from tensorflow.python.estimator.canned import head as head_lib from tensorflow.python.estimator.canned import metric_keys from tensorflow.python.estimator.canned import prediction_keys @@ -33,8 +34,11 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib from tensorflow.python.ops import sparse_ops from tensorflow.python.ops.losses import losses +from tensorflow.python.saved_model import signature_constants from tensorflow.python.summary import summary +_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + def multi_class_head(n_classes, weight_column=None, @@ -59,7 +63,7 @@ def multi_class_head(n_classes, `label_vocabulary`. Also there will be errors if vocabulary is not provided and labels are string. name: name of the head. If provided, summary and metrics keys will be - suffixed by `"/" + name`. + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. Returns: An instance of `_Head` for multi class classification. @@ -98,7 +102,7 @@ def binary_classification_head( `label_vocabulary`. Also there will be errors if vocabulary is not provided and labels are string. name: name of the head. If provided, summary and metrics keys will be - suffixed by `"/" + name`. + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. Returns: An instance of `_Head` for binary classification. @@ -129,7 +133,7 @@ def regression_head(weight_column=None, of the last dimension of the labels `Tensor` (typically, this has shape `[batch_size, label_dimension]`). name: name of the head. If provided, summary and metrics keys will be - suffixed by `"/" + name`. + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. Returns: An instance of `_Head` for linear regression. @@ -144,6 +148,7 @@ def multi_label_head(n_classes, weight_column=None, thresholds=None, label_vocabulary=None, + loss_fn=None, name=None): """Creates a `_Head` for multi-label classification. @@ -155,6 +160,12 @@ def multi_label_head(n_classes, multi-hot tensor of shape `[batch_size, n_classes]`, or as an integer `SparseTensor` of class indices. + Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or + `(labels, logits, features)` as arguments and returns unreduced loss with + shape `[batch_size, 1]`. `loss_fn` must support indicator `labels` with shape + `[batch_size, n_classes]`. Namely, the head applies `label_vocabulary` to the + input labels before passing them to `loss_fn`. + Args: n_classes: Number of classes, must be greater than 1 (for 1 class, use `binary_classification_head`). @@ -171,8 +182,9 @@ def multi_label_head(n_classes, [0, n_classes) or multi-hot Tensor. If given, labels must be SparseTensor string type and have any value in `label_vocabulary`. Also there will be errors if vocabulary is not provided and labels are string. + loss_fn: Optional loss function. name: name of the head. If provided, summary and metrics keys will be - suffixed by `"/" + name`. + suffixed by `"/" + name`. Also used as `name_scope` when creating ops. Returns: An instance of `_Head` for multi-label classification. @@ -198,9 +210,11 @@ def multi_label_head(n_classes, raise ValueError( 'Length of label_vocabulary must be n_classes ({}). ' 'Given: {}'.format(n_classes, len(label_vocabulary))) + if loss_fn: + _validate_loss_fn_args(loss_fn) return _MultiLabelHead( n_classes=n_classes, weight_column=weight_column, thresholds=thresholds, - label_vocabulary=label_vocabulary, name=name) + label_vocabulary=label_vocabulary, loss_fn=loss_fn, name=name) class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access @@ -211,11 +225,13 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access weight_column=None, thresholds=None, label_vocabulary=None, + loss_fn=None, name=None): self._n_classes = n_classes self._weight_column = weight_column self._thresholds = thresholds self._label_vocabulary = label_vocabulary + self._loss_fn = loss_fn self._name = name @property @@ -227,6 +243,12 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access return self._n_classes def _process_labels(self, labels): + if labels is None: + raise ValueError( + 'You must provide a labels Tensor. Given: None. ' + 'Suggested troubleshooting steps: Check that your data contain ' + 'your label feature. Check that your input_fn properly parses and ' + 'returns labels.') if isinstance(labels, sparse_tensor.SparseTensor): if labels.dtype == dtypes.string: label_ids_values = lookup_ops.index_table_from_tensor( @@ -254,19 +276,34 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access def create_loss(self, features, mode, logits, labels): """See `Head`.""" - del mode, features # Unused for this head. + del mode # Unused for this head. processed_labels = self._process_labels(labels) - unweighted_loss = losses.sigmoid_cross_entropy( - multi_class_labels=processed_labels, logits=logits, - reduction=losses.Reduction.NONE) - return head_lib.LossAndLabels( - unweighted_loss=unweighted_loss, + if self._loss_fn: + unweighted_loss = _call_loss_fn( + loss_fn=self._loss_fn, labels=processed_labels, logits=logits, + features=features) + else: + unweighted_loss = losses.sigmoid_cross_entropy( + multi_class_labels=processed_labels, logits=logits, + reduction=losses.Reduction.NONE) + # Averages loss over classes. + unweighted_loss = math_ops.reduce_mean( + unweighted_loss, axis=-1, keep_dims=True) + weights = head_lib._weights(features, self._weight_column) # pylint:disable=protected-access, + weighted_sum_loss = losses.compute_weighted_loss( + unweighted_loss, weights=weights, reduction=losses.Reduction.SUM) + # _weights() can return 1. + example_weight_sum = math_ops.reduce_sum( + weights * array_ops.ones_like(unweighted_loss)) + return head_lib.LossSpec( + weighted_sum_loss=weighted_sum_loss, + example_weight_sum=example_weight_sum, processed_labels=processed_labels) def create_estimator_spec( self, features, mode, logits, labels=None, train_op_fn=None): """See `Head`.""" - with ops.name_scope('head'): + with ops.name_scope(self._name, 'head'): logits = head_lib._check_logits(logits, self.logits_dimension) # pylint:disable=protected-access # Predict. @@ -278,32 +315,35 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access pred_keys.PROBABILITIES: probabilities, } if mode == model_fn.ModeKeys.PREDICT: + classifier_output = head_lib._classification_output( # pylint:disable=protected-access + scores=probabilities, n_classes=self._n_classes, + label_vocabulary=self._label_vocabulary) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.PREDICT, predictions=predictions, export_outputs={ - '': export_output.ClassificationOutput(scores=probabilities) + _DEFAULT_SERVING_KEY: classifier_output, + head_lib._CLASSIFY_SERVING_KEY: classifier_output, # pylint:disable=protected-access + head_lib._PREDICT_SERVING_KEY: ( # pylint:disable=protected-access + export_output.PredictOutput(predictions)) }) + (weighted_sum_loss, example_weight_sum, + processed_labels) = self.create_loss( + features=features, mode=mode, logits=logits, labels=labels) + # Eval. - unweighted_loss, processed_labels = self.create_loss( - features=features, mode=mode, logits=logits, labels=labels) - # Averages loss over classes. - per_example_loss = math_ops.reduce_mean( - unweighted_loss, axis=-1, keep_dims=True) - weights = head_lib._weights(features, self._weight_column) # pylint:disable=protected-access - training_loss = losses.compute_weighted_loss( - per_example_loss, weights=weights, reduction=losses.Reduction.SUM) if mode == model_fn.ModeKeys.EVAL: return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.EVAL, predictions=predictions, - loss=training_loss, + loss=weighted_sum_loss, eval_metric_ops=self._eval_metric_ops( labels=processed_labels, probabilities=probabilities, - weights=weights, - per_example_loss=per_example_loss)) + weights=head_lib._weights(features, self._weight_column), # pylint:disable=protected-access, + weighted_sum_loss=weighted_sum_loss, + example_weight_sum=example_weight_sum)) # Train. if train_op_fn is None: @@ -311,37 +351,43 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access with ops.name_scope(''): summary.scalar( head_lib._summary_key(self._name, metric_keys.MetricKeys.LOSS), # pylint:disable=protected-access - training_loss) + weighted_sum_loss) summary.scalar( head_lib._summary_key( # pylint:disable=protected-access self._name, metric_keys.MetricKeys.LOSS_MEAN), - losses.compute_weighted_loss( - unweighted_loss, weights=weights, - reduction=losses.Reduction.MEAN)) + weighted_sum_loss / example_weight_sum) return model_fn.EstimatorSpec( mode=model_fn.ModeKeys.TRAIN, predictions=predictions, - loss=training_loss, - train_op=train_op_fn(training_loss)) + loss=weighted_sum_loss, + train_op=train_op_fn(weighted_sum_loss)) - def _eval_metric_ops(self, labels, probabilities, weights, per_example_loss): + def _eval_metric_ops(self, labels, probabilities, weights, weighted_sum_loss, + example_weight_sum): """Returns a dict of metrics for eval_metric_ops.""" with ops.name_scope( - None, 'metrics', [labels, probabilities, weights, per_example_loss]): + None, 'metrics', + [labels, probabilities, weights, weighted_sum_loss, example_weight_sum + ]): keys = metric_keys.MetricKeys metric_ops = { # Estimator already adds a metric for loss. head_lib._summary_key(self._name, keys.LOSS_MEAN): # pylint:disable=protected-access metrics_lib.mean( - per_example_loss, weights=weights, name=keys.LOSS_MEAN), + # Both values and weights here are reduced, scalar Tensors. + # values is the actual mean we want, but we pass the scalar + # example_weight_sum in order to return the correct update_op + # alongside the value_op for streaming metrics. + values=(weighted_sum_loss / example_weight_sum), + weights=example_weight_sum, + name=keys.LOSS_MEAN), head_lib._summary_key(self._name, keys.AUC): # pylint:disable=protected-access - metrics_lib.auc( - labels=labels, predictions=probabilities, weights=weights, - name=keys.AUC), + metrics_lib.auc(labels=labels, predictions=probabilities, + weights=weights, name=keys.AUC), head_lib._summary_key(self._name, keys.AUC_PR): # pylint:disable=protected-access - metrics_lib.auc( - labels=labels, predictions=probabilities, weights=weights, - curve='PR', name=keys.AUC_PR), + metrics_lib.auc(labels=labels, predictions=probabilities, + weights=weights, curve='PR', + name=keys.AUC_PR), } for threshold in self._thresholds: accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold @@ -371,3 +417,52 @@ class _MultiLabelHead(head_lib._Head): # pylint:disable=protected-access threshold=threshold, name=recall_key)) return metric_ops + + +def _validate_loss_fn_args(loss_fn): + """Validates loss_fn arguments. + + Required arguments: labels, logits. + Optional arguments: features. + + Args: + loss_fn: The loss function. + Raises: + ValueError: If the signature is unexpected. + """ + loss_fn_args = util.fn_args(loss_fn) + for required_arg in ['labels', 'logits']: + if required_arg not in loss_fn_args: + raise ValueError( + 'loss_fn must contain argument: {}. ' + 'Given arguments: {}'.format(required_arg, loss_fn_args)) + invalid_args = list(set(loss_fn_args) - set(['labels', 'logits', 'features'])) + if invalid_args: + raise ValueError('loss_fn has unexpected args: {}'.format(invalid_args)) + + +def _call_loss_fn(loss_fn, labels, logits, features): + """Calls loss_fn and checks the returned shape. + + Args: + loss_fn: The loss function. + labels: Processed labels Tensor. + logits: Logits Tensor of shape [batch_size, logits_dimension]. + features: Features dict. + Returns: + Loss Tensor with shape [batch_size, 1]. + """ + loss_fn_args = util.fn_args(loss_fn) + kwargs = {} + if 'features' in loss_fn_args: + kwargs['features'] = features + unweighted_loss = loss_fn(labels=labels, logits=logits, **kwargs) + batch_size = array_ops.shape(logits)[0] + loss_shape = array_ops.shape(unweighted_loss) + check_shape_op = control_flow_ops.Assert( + math_ops.reduce_all(math_ops.equal(loss_shape, [batch_size, 1])), + data=[ + 'loss_fn must return Tensor of shape [batch_size, 1]. Given: ', + loss_shape]) + with ops.control_dependencies([check_shape_op]): + return array_ops.identity(unweighted_loss) diff --git a/tensorflow/contrib/estimator/python/estimator/head_test.py b/tensorflow/contrib/estimator/python/estimator/head_test.py index 9dd9e433277304b320ac17d6478383531f114806..972ce6163d5b0f580b08888bd69dff0d40fefa34 100644 --- a/tensorflow/contrib/estimator/python/estimator/head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/head_test.py @@ -32,6 +32,8 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops from tensorflow.python.ops import string_ops from tensorflow.python.platform import test from tensorflow.python.saved_model import signature_constants @@ -79,9 +81,13 @@ def _sigmoid(logits): def _sigmoid_cross_entropy(labels, logits): + """Returns sigmoid cross entropy averaged over classes.""" sigmoid_logits = _sigmoid(logits) - return (-labels * np.log(sigmoid_logits) - -(1 - labels) * np.log(1 - sigmoid_logits)) + unreduced_result = ( + -labels * np.log(sigmoid_logits) + -(1 - labels) * np.log(1 - sigmoid_logits)) + # Mean over classes + return np.mean(unreduced_result, axis=-1, keepdims=True) class MultiLabelHead(test.TestCase): @@ -126,6 +132,37 @@ class MultiLabelHead(test.TestCase): r'Length of label_vocabulary must be n_classes \(3\). Given: 2'): head_lib.multi_label_head(n_classes=3, label_vocabulary=['foo', 'bar']) + def test_loss_fn_arg_labels_missing(self): + def _loss_fn(logits): + del logits # Unused + with self.assertRaisesRegexp( + ValueError, + r'loss_fn must contain argument: labels\. ' + r'Given arguments: \(\'logits\',\)'): + head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) + + def test_loss_fn_arg_logits_missing(self): + def _loss_fn(labels): + del labels # unused + with self.assertRaisesRegexp( + ValueError, + r'loss_fn must contain argument: logits\. ' + r'Given arguments: \(\'labels\',\)'): + head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) + + def test_loss_fn_arg_features_ok(self): + def _loss_fn(labels, logits, features): + del labels, logits, features # Unused + head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) + + def test_loss_fn_arg_invalid(self): + def _loss_fn(labels, logits, name=None): + del labels, logits, name # Unused + with self.assertRaisesRegexp( + ValueError, + r'loss_fn has unexpected args: \[\'name\'\]'): + head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) + def test_name(self): head = head_lib.multi_label_head(n_classes=4, name='foo') self.assertEqual('foo', head.name) @@ -138,6 +175,7 @@ class MultiLabelHead(test.TestCase): logits = np.array( [[0., 1., 2., -1.], [-1., -2., -3., 1.]], dtype=np.float32) expected_probabilities = _sigmoid(logits) + expected_export_classes = [[b'0', b'1', b'2', b'3']] * 2 spec = head.create_estimator_spec( features={'x': np.array(((42,),), dtype=np.int32)}, @@ -145,7 +183,8 @@ class MultiLabelHead(test.TestCase): logits=logits) self.assertItemsEqual( - ('', _DEFAULT_SERVING_KEY), spec.export_outputs.keys()) + (_DEFAULT_SERVING_KEY, 'predict', 'classification'), + spec.export_outputs.keys()) # Assert predictions and export_outputs. with self.test_session() as sess: @@ -161,6 +200,29 @@ class MultiLabelHead(test.TestCase): self.assertAllClose( expected_probabilities, sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores)) + self.assertAllEqual( + expected_export_classes, + sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].classes)) + + def test_predict_with_label_vocabulary(self): + n_classes = 4 + head = head_lib.multi_label_head( + n_classes, label_vocabulary=['foo', 'bar', 'foobar', 'barfoo']) + + logits = np.array( + [[0., 1., 2., -1.], [-1., -2., -3., 1.]], dtype=np.float32) + expected_export_classes = [[b'foo', b'bar', b'foobar', b'barfoo']] * 2 + + spec = head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.PREDICT, + logits=logits) + + with self.test_session() as sess: + _initialize_variables(self, spec.scaffold) + self.assertAllEqual( + expected_export_classes, + sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].classes)) def test_weight_should_not_impact_prediction(self): n_classes = 4 @@ -200,17 +262,17 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - expected_unweighted_loss = _sigmoid_cross_entropy( - labels=labels, logits=logits) - actual_unweighted_loss, _ = head.create_loss( + expected_weighted_sum_loss = np.sum( + _sigmoid_cross_entropy(labels=labels, logits=logits)) + actual_weighted_sum_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits, - labels=labels) + labels=labels)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) - self.assertAllClose( - expected_unweighted_loss, actual_unweighted_loss.eval()) + self.assertAllClose(expected_weighted_sum_loss, + actual_weighted_sum_loss.eval()) def test_eval_create_loss_large_logits(self): """Tests head.create_loss for eval mode and large logits.""" @@ -224,17 +286,19 @@ class MultiLabelHead(test.TestCase): # For large logits, this is approximated as: # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits - expected_unweighted_loss = np.array( - [[10., 10.], [15., 0.]], dtype=np.float32) - actual_unweighted_loss, _ = head.create_loss( + expected_weighted_sum_loss = np.sum( + np.array([[(10. + 10.) / 2.], [(15. + 0.) / 2.]], dtype=np.float32)) + actual_weighted_sum_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits, - labels=labels) + labels=labels)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( - expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4) + expected_weighted_sum_loss, + actual_weighted_sum_loss.eval(), + atol=1e-4) def test_eval_create_loss_labels_wrong_shape(self): """Tests head.create_loss for eval mode when labels has the wrong shape.""" @@ -243,23 +307,85 @@ class MultiLabelHead(test.TestCase): logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32) labels_placeholder = array_ops.placeholder(dtype=dtypes.int64) - actual_unweighted_loss, _ = head.create_loss( + actual_weighted_sum_loss = head.create_loss( features={'x': np.array(((42,),), dtype=np.int32)}, mode=model_fn.ModeKeys.EVAL, logits=logits, - labels=labels_placeholder) + labels=labels_placeholder)[0] with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) with self.assertRaisesRegexp( errors.InvalidArgumentError, r'labels shape must be \[batch_size, 2\]\. Given: \] \[2 1\]'): - actual_unweighted_loss.eval( - {labels_placeholder: np.array([[1], [1]], dtype=np.int64)}) + actual_weighted_sum_loss.eval({ + labels_placeholder: np.array([[1], [1]], dtype=np.int64) + }) with self.assertRaisesRegexp( errors.InvalidArgumentError, r'labels shape must be \[batch_size, 2\]\. Given: \] \[2\]'): - actual_unweighted_loss.eval( - {labels_placeholder: np.array([1, 1], dtype=np.int64)}) + actual_weighted_sum_loss.eval({ + labels_placeholder: np.array([1, 1], dtype=np.int64) + }) + + def test_eval_create_loss_loss_fn(self): + """Tests head.create_loss for eval mode and custom loss_fn.""" + loss = np.array([[1.], [2.]], dtype=np.float32) + logits_input = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + labels_input = np.array([[1, 0], [1, 1]], dtype=np.int64) + def _loss_fn(labels, logits): + check_labels = control_flow_ops.Assert( + math_ops.reduce_all(math_ops.equal(labels, labels_input)), + data=[labels]) + check_logits = control_flow_ops.Assert( + math_ops.reduce_all(math_ops.equal(logits, logits_input)), + data=[logits]) + with ops.control_dependencies([check_labels, check_logits]): + return constant_op.constant(loss) + head = head_lib.multi_label_head(n_classes=2, loss_fn=_loss_fn) + + actual_weighted_sum_loss = head.create_loss( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.EVAL, + logits=logits_input, + labels=labels_input)[0] + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + self.assertAllClose(np.sum(loss), actual_weighted_sum_loss.eval()) + + def test_eval_create_loss_loss_fn_wrong_shape(self): + """Tests custom loss_fn that returns Tensor of unexpected shape.""" + loss = np.array([1., 2.], dtype=np.float32) + def _loss_fn(labels, logits): + del labels, logits # Unused + return constant_op.constant(loss) + head = head_lib.multi_label_head(n_classes=2, loss_fn=_loss_fn) + + logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) + labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + actual_weighted_sum_loss = head.create_loss( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.EVAL, + logits=logits, + labels=labels)[0] + with self.test_session(): + _initialize_variables(self, monitored_session.Scaffold()) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + r'loss_fn must return Tensor of shape \[batch_size, 1\]\. ' + r'Given: \] \[2\]'): + actual_weighted_sum_loss.eval() + + def test_eval_labels_none(self): + """Tests that error is raised when labels is None.""" + head = head_lib.multi_label_head(n_classes=2) + + with self.assertRaisesRegexp( + ValueError, r'You must provide a labels Tensor\. Given: None\.'): + head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.EVAL, + logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), + labels=None) def _test_eval(self, head, logits, labels, expected_loss, expected_metrics): spec = head.create_estimator_spec( @@ -298,10 +424,8 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Average over classes, and sum over examples. - expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes - ) + # Sum over examples. + expected_loss = np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) keys = metric_keys.MetricKeys expected_metrics = { # Average loss over examples. @@ -330,10 +454,9 @@ class MultiLabelHead(test.TestCase): labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Average over classes, and sum over examples. + # Sum over examples. expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) / - n_classes + np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) ) keys = metric_keys.MetricKeys expected_metrics = { @@ -364,10 +487,9 @@ class MultiLabelHead(test.TestCase): labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Average over classes, and sum over examples. + # Sum over examples. expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) / - n_classes + np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) ) keys = metric_keys.MetricKeys expected_metrics = { @@ -394,9 +516,9 @@ class MultiLabelHead(test.TestCase): labels = np.array([[1, 0], [1, 1]], dtype=np.int64) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) - # Average over classes, and sum over examples. + # Sum over examples. expected_loss = ( - np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) / n_classes + np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) ) keys = metric_keys.MetricKeys @@ -483,26 +605,55 @@ class MultiLabelHead(test.TestCase): def test_train_create_loss_large_logits(self): """Tests head.create_loss for train mode and large logits.""" n_classes = 2 - head = head_lib.multi_label_head(n_classes) + head = head_lib.multi_label_head(n_classes, weight_column='label_weights') logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) labels = np.array([[1, 0], [1, 1]], dtype=np.int64) + weights = np.array([[1.], [2.]], dtype=np.float32) # loss = labels * -log(sigmoid(logits)) + # (1 - labels) * -log(1 - sigmoid(logits)) # For large logits, this is approximated as: # loss = labels * (logits < 0) * (-logits) + # (1 - labels) * (logits > 0) * logits - expected_unweighted_loss = np.array( - [[10., 10.], [15., 0.]], dtype=np.float32) - actual_unweighted_loss, _ = head.create_loss( - features={'x': np.array(((42,),), dtype=np.int32)}, + expected_weighted_sum_loss = np.sum( + np.array( + [[1. * (10. + 10.) / 2.], [2. * (15. + 0.) / 2.]], + dtype=np.float32)) + expected_example_weight_sum = 1. + 2. + actual_weighted_sum_loss, actual_example_weight_sum, _ = head.create_loss( + features={ + 'x': np.array(((42,),), dtype=np.int32), + 'label_weights': weights + }, mode=model_fn.ModeKeys.TRAIN, logits=logits, labels=labels) with self.test_session(): _initialize_variables(self, monitored_session.Scaffold()) self.assertAllClose( - expected_unweighted_loss, actual_unweighted_loss.eval(), atol=1e-4) + expected_weighted_sum_loss, + actual_weighted_sum_loss.eval(), + atol=1e-4) + self.assertAllClose( + expected_example_weight_sum, + actual_example_weight_sum.eval(), + atol=1e-4) + + def test_train_labels_none(self): + """Tests that error is raised when labels is None.""" + head = head_lib.multi_label_head(n_classes=2) + def _no_op_train_fn(loss): + del loss + return control_flow_ops.no_op() + + with self.assertRaisesRegexp( + ValueError, r'You must provide a labels Tensor\. Given: None\.'): + head.create_estimator_spec( + features={'x': np.array(((42,),), dtype=np.int32)}, + mode=model_fn.ModeKeys.TRAIN, + logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), + labels=None, + train_op_fn=_no_op_train_fn) def _test_train(self, head, logits, labels, expected_loss): expected_train_result = 'my_train_op' diff --git a/tensorflow/contrib/estimator/python/estimator/logit_fns.py b/tensorflow/contrib/estimator/python/estimator/logit_fns.py index 110ea0302e703fd3eecdfafea928d7ba04f07d8e..fc5efa4d7b98123ae968f98d4a54900e2d63570d 100644 --- a/tensorflow/contrib/estimator/python/estimator/logit_fns.py +++ b/tensorflow/contrib/estimator/python/estimator/logit_fns.py @@ -39,6 +39,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import six + from tensorflow.python.estimator import util from tensorflow.python.estimator.canned import dnn as dnn_core from tensorflow.python.estimator.canned import linear as linear_core @@ -67,7 +69,8 @@ def call_logit_fn(logit_fn, features, mode, params, config): A logit Tensor, the output of logit_fn. Raises: - ValueError: if logit_fn does not return a Tensor. + ValueError: if logit_fn does not return a Tensor or a dictionary mapping + strings to Tensors. """ logit_fn_args = util.fn_args(logit_fn) kwargs = {} @@ -79,7 +82,15 @@ def call_logit_fn(logit_fn, features, mode, params, config): kwargs['config'] = config logit_fn_results = logit_fn(features=features, **kwargs) - if not isinstance(logit_fn_results, ops.Tensor): - raise ValueError('model_fn should return a Tensor.') + result_is_valid_dictionary = ( + isinstance(logit_fn_results, dict) and + all([(isinstance(k, str) and isinstance(v, ops.Tensor)) + for k, v in six.iteritems(logit_fn_results)])) + result_is_tensor = isinstance(logit_fn_results, ops.Tensor) + + if not (result_is_valid_dictionary or result_is_tensor): + raise ValueError('logit_fn should return a Tensor or a dictionary mapping ' + 'strings to Tensors. logit_fn returned: %s' % + logit_fn_results) return logit_fn_results diff --git a/tensorflow/contrib/estimator/python/estimator/logit_fns_test.py b/tensorflow/contrib/estimator/python/estimator/logit_fns_test.py index d75eada798dcdf929e4094258ecdc6ce394f847c..3279e920018bae8ca9520a6372f6b71971da7b52 100644 --- a/tensorflow/contrib/estimator/python/estimator/logit_fns_test.py +++ b/tensorflow/contrib/estimator/python/estimator/logit_fns_test.py @@ -43,22 +43,53 @@ class LogitFnTest(test.TestCase): with session.Session(): self.assertAllClose([[4., 5.]], logit_fn_result.eval()) - def test_should_return_tensor(self): + def test_simple_call_multi_logit_fn(self): + + def dummy_logit_fn(features): + return {'head1': features['f1'], 'head2': features['f2']} + + features = { + 'f1': constant_op.constant([[2., 3.]]), + 'f2': constant_op.constant([[4., 5.]]) + } + logit_fn_result = logit_fns.call_logit_fn(dummy_logit_fn, features, + model_fn.ModeKeys.TRAIN, + 'fake_params', 'fake_config') + with session.Session(): + self.assertAllClose([[2., 3.]], logit_fn_result['head1'].eval()) + self.assertAllClose([[4., 5.]], logit_fn_result['head2'].eval()) + + def test_invalid_logit_fn_results(self): def invalid_logit_fn(features, params): - return { - 'tensor1': features['f1'] * params['input_multiplier'], - 'tensor2': features['f2'] * params['input_multiplier'] - } + return [ + features['f1'] * params['input_multiplier'], + features['f2'] * params['input_multiplier'] + ] + features = { 'f1': constant_op.constant([[2., 3.]]), 'f2': constant_op.constant([[4., 5.]]) } params = {'learning_rate': 0.001, 'input_multiplier': 2.0} - with self.assertRaisesRegexp(ValueError, 'model_fn should return a Tensor'): + with self.assertRaisesRegexp( + ValueError, 'logit_fn should return a Tensor or a dictionary mapping ' + 'strings to Tensors'): logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode', params, 'fake_config') + def test_invalid_logit_fn_results_dict(self): + + def invalid_logit_fn(features): + return {'head1': features['f1'], 'head2': features['f2']} + + features = {'f1': constant_op.constant([[2., 3.]]), 'f2': 'some string'} + with self.assertRaisesRegexp( + ValueError, 'logit_fn should return a Tensor or a dictionary mapping ' + 'strings to Tensors'): + logit_fns.call_logit_fn(invalid_logit_fn, features, 'fake_mode', + 'fake_params', 'fake_config') + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head.py b/tensorflow/contrib/estimator/python/estimator/multi_head.py index e6340424f741cd0278dbdef41dd4395e98f23246..64b2a9dee83801b5d6d852a3485fc0cc81417ff0 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head.py @@ -236,7 +236,10 @@ class _MultiHead(head_lib._Head): # pylint:disable=protected-access for head, spec in zip(self._heads, all_estimator_spec): head_name = head.name for k, v in six.iteritems(spec.export_outputs): - key = '%s/%s' % (k, head_name) if k else head_name + if k == _DEFAULT_SERVING_KEY: + key = head_name + else: + key = '%s/%s' % (k, head_name) export_outputs[key] = v for k, v in six.iteritems(spec.predictions): predictions[(head_name, k)] = v diff --git a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py index e86cb2b96fe1c10352337367616a0ea2ff9132cc..48027035cecffc3ce8aacf8ae917f5eb9e9b2473 100644 --- a/tensorflow/contrib/estimator/python/estimator/multi_head_test.py +++ b/tensorflow/contrib/estimator/python/estimator/multi_head_test.py @@ -126,8 +126,8 @@ class MultiHeadTest(test.TestCase): logits=logits) self.assertItemsEqual( - (_DEFAULT_SERVING_KEY, _DEFAULT_SERVING_KEY + '/head1', 'head1', - _DEFAULT_SERVING_KEY + '/head2', 'head2'), + (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1', + 'head2', 'classification/head2', 'predict/head2'), spec.export_outputs.keys()) # Assert predictions and export_outputs. diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py new file mode 100644 index 0000000000000000000000000000000000000000..7005a647db599dfa386f34406911febe1d9d5651 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py @@ -0,0 +1,470 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities to replicate model_fn's over local GPUs. + +This file contains util that allow to replicate `Estimator.model_fn` over +GPUs. Replicated version of a `model_fn` is returned that can subsequently +be used with `Estimator`. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy + +import six + +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.client import device_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator import util +from tensorflow.python.estimator.export import export_output as export_output_lib +from tensorflow.python.framework import device as framework_device +from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients as gradients_lib +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.platform import tf_logging +from tensorflow.python.training import training_util + + +def replicate_model_fn(model_fn, optimizer_fn, devices=None): + """Replicate `Estimator.model_fn` over GPUs within a single host. + + The given `model_fn` specifies a single forward pass of a model. To replicate + such a model over GPUs, each GPU gets its own instance of the forward pass + (a.k.a. a tower). The input features and labels get sharded into the chunks + that correspond to the number of GPUs. Each tower computes its own loss based + on its input. For each such loss, gradients are computed. After that, the + available losses are summed to form aggregated loss. The available + gradients are summed too. Then, they update weights using the specified + optimizer. + + If `devices` are `None`, then all available GPUs are going to be used for + replication. If no GPUs are available, then the model is going to be + placed on the CPU. + + Two modes of local replication over available GPUs are supported: + 1) If exactly 1 GPU is detected, then variables and operations are placed + onto GPU. + 2) If more than 1 GPU is detected, then variables are going to be placed on + the CPU. Replicas of operations are placed on each individual GPU. + + Here is an example of how one might use their `model_fn` to run over GPUs: + ```python + def optimizer_fn(): + return tf.train.GradientDescentOptimizer(learning_rate=0.001) + ... + def model_fn(...): # See `model_fn` in `Estimator`. + loss = ... + if mode == tf.estimator.ModeKeys.TRAIN: + # See the section below on `EstimatorSpec.train_op`. + return EstimatorSpec(mode=mode, loss=loss, train_op=tf.noop()) + + # No change for `ModeKeys.EVAL` or `ModeKeys.PREDICT`. + return EstimatorSpec(...) + ... + classifier = tf.estimator.Estimator( + model_fn=replicate_model_fn.replicate_model_fn(model_fn, optimizer_fn)) + ``` + + On `EstimatorSpec.train_op`: + `model_fn` returns `EstimatorSpec.train_op` for + `tf.estimator.GraphKeys.TRAIN`. It is typically derived using an optimizer. + `replicate_model_fn` ignores the returned `EstimatorSpec.train_op`, so there + is no need to use an optimizer inside the user's `model_fn`. The + `EstimatorSpec.loss` subgraph is going to be executed, while + `EstimatorSpec.train_op` isn't going to be executed. One could pass + `train_op=tf.noop()` to `EstimatorSpec`. + + On sharding input features and labels: + Input features and labels are split for consumption by each tower. They are + split across the dimension 0. Features and labels need to be batch major. + + On reduction algorithms: + Certain algorithms were chosen for aggregating results of computations on + multiple towers: + - Losses from all towers are reduced using sum. + - Gradients are reduced using sum for each trainable variable. + - `eval_metrics_ops` are reduced per metric using `reduce_mean`. + - `EstimatorSpec.predictions` and `EstimatorSpec.export_outputs` are + reduced using concatenation. + - For all other fields of `EstimatorSpec` the values of the first tower + are taken. + + On replication of variables: + Variables are not duplicated between towers. Instead, they are placed on a + single device as defined above and shared across towers. + + Other current limitations: + - `predictions` are not supported for `ModeKeys.EVAL`. That is required for + `tf.contrib.estimator.add_metrics`. + + Args: + model_fn: `model_fn` as defined in `Estimator`. See the section above about + the train_op argument of `EstimatorSpec`. + optimizer_fn: a function that returns an optimizer instance. The function + may accept one `params` argument. This is the `params` argument as + defined by `Estimator`. See the `Estimator` documentation for details. + devices: Optional list of devices to replicate the model across. This + argument can be used to replice only on the subset of available GPUs. + If `None`, then all available GPUs are going to be used for replication. + If no GPUs are available, then the model is going to be placed on the CPU. + + Returns: + A replicated version of the supplied `model_fn`. Returned function that + conforms to the requirements of `Estimator`'s `model_fn` and can be used + instead of the supplied `model_fn`. + """ + if not devices: + devices = _get_local_devices('GPU') or _get_local_devices('CPU') + + is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0] + local_ps_device = '/{}:0'.format('GPU' if is_a_single_gpu_case else 'CPU') + + tf_logging.info('Replicating the `model_fn` across {}. Local parameter ' + 'server device is going to be {}.'.format( + devices, local_ps_device)) + + def replicated_model_fn(mode, features, labels, params=None, config=None): + """Replicated version of `model_fn` to be used instead.""" + feature_shards, label_shards = _split_batch( + features, labels, len(devices), device=local_ps_device) + tower_specs = _get_loss_towers( + model_fn=model_fn, + mode=mode, + features=feature_shards, + labels=label_shards, + params=params, + config=config, + devices=devices, + local_ps_device=local_ps_device) + + if mode == model_fn_lib.ModeKeys.TRAIN: + train_op = _minimize_towers(tower_specs, + _call_optimizer_fn(optimizer_fn, params)) + return _train_spec( + tower_specs, train_op, aggregation_device=local_ps_device) + elif mode == model_fn_lib.ModeKeys.EVAL: + return _eval_spec(tower_specs, aggregation_device=local_ps_device) + elif mode == model_fn_lib.ModeKeys.PREDICT: + return _predict_spec(tower_specs, aggregation_device=local_ps_device) + + return replicated_model_fn + + +def _get_local_devices(device_type): + local_device_protos = device_lib.list_local_devices() + return [ + device.name + for device in local_device_protos + if device.device_type == device_type + ] + + +def _split_batch(features, labels, number_of_shards, device): + """Split input features and labes into batches.""" + + def split_dictionary(dictionary): + shards = [{} for _ in range(number_of_shards)] + for name, tensor in six.iteritems(dictionary): + for i, shard in enumerate(array_ops.split(tensor, number_of_shards)): + shards[i][name] = shard + return shards + + with ops_lib.name_scope('split_inputs'): + with ops_lib.device(device): + if isinstance(features, dict): + feature_shards = split_dictionary(features) + else: + feature_shards = array_ops.split(features, number_of_shards) + + if labels is None: + label_shards = None + elif isinstance(labels, dict): + label_shards = split_dictionary(labels) + else: + label_shards = array_ops.split(labels, number_of_shards) + return feature_shards, label_shards + + +_DEFAULT_NAME_SCOPE_PATTERN = 'tower_{}' + + +def _get_loss_towers(model_fn, + mode, + features, + labels, + params, + config, + devices, + local_ps_device, + name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN): + """Replicate the loss computation across devices.""" + tower_specs = [] + + model_fn_args = util.fn_args(model_fn) + optional_params = {} + if 'params' in model_fn_args: + optional_params['params'] = copy.deepcopy(params) + if 'config' in model_fn_args: + optional_params['config'] = copy.deepcopy(config) + + for i, device in enumerate(devices): + is_the_first_tower = (i == 0) + + device_setter = _local_device_setter( + worker_device=device, ps_device=local_ps_device) + + # We would like to preserve the names of the variables and ops that a user + # might be relying on. Names with prefix are going to resolve to variables + # and ops of the first tower. + name_scope = name_scope_pattern + if is_the_first_tower: + name_scope = '' + + with variable_scope.variable_scope('', reuse=not is_the_first_tower): + with ops_lib.name_scope(name_scope.format(i)): + with ops_lib.device(device_setter): + labels_shard = None + if labels: + labels_shard = labels[i] + + tower_specs.append( + model_fn( + mode=mode, + features=features[i], + labels=labels_shard, + **optional_params)) + return tower_specs + + +def _local_device_setter(ps_device, worker_device): + """A device setter that puts distributes Var/Ops to PS/workers.""" + ps_ops = ['Variable', 'VariableV2', 'VarHandleOp'] + + def local_device_chooser(op): + current_device = framework_device.DeviceSpec.from_string(op.device or '') + + node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def + if node_def.op in ps_ops: + ps_device_spec = framework_device.DeviceSpec.from_string( + '{}'.format(ps_device)) + + ps_device_spec.merge_from(current_device) + return ps_device_spec.to_string() + else: + worker_device_spec = framework_device.DeviceSpec.from_string( + worker_device or '') + worker_device_spec.merge_from(current_device) + return worker_device_spec.to_string() + + return local_device_chooser + + +def _minimize_towers(tower_specs, optimizer): + """Aggregate and apply gradients for computed losses.""" + grad_lists = {} + for tower_spec in tower_specs: + with ops_lib.device(tower_spec.loss.device): + variables = variables_lib.trainable_variables() + gradients = gradients_lib.gradients(tower_spec.loss, variables) + + for var, grad in zip(variables, gradients): + if grad is not None: + grad_lists.setdefault(var, []).append(grad) + + aggregated_grads = [] + with ops_lib.name_scope('gradient_aggregating'): + for var, grads in six.iteritems(grad_lists): + grad = _compute_sum_on_device(grads, var.device) + aggregated_grads.append((grad, var)) + + train_op = optimizer.apply_gradients( + aggregated_grads, global_step=training_util.get_global_step()) + + return train_op + + +def _call_optimizer_fn(optimizer_fn, params): + arguments = {} + optimizer_fn_arguments = util.fn_args(optimizer_fn) + if 'params' in optimizer_fn_arguments: + arguments['params'] = params + return optimizer_fn(**arguments) + + +def _compute_sum_on_device(values, device, name=None): + with ops_lib.device(device): + return math_ops.add_n(values, name=name) + + +def _train_spec(tower_specs, + train_op, + aggregation_device, + aggregated_loss_name='loss'): + """Populate replicated EstimatorSpec for `GraphKeys.TRAIN`.""" + estimator_spec = tower_specs[0]._asdict() + estimator_spec['mode'] = model_fn_lib.ModeKeys.TRAIN + estimator_spec['train_op'] = train_op + estimator_spec['loss'] = _compute_sum_on_device( + [spec.loss for spec in tower_specs], aggregation_device, + aggregated_loss_name) + return model_fn_lib.EstimatorSpec(**estimator_spec) + + +def _eval_spec(tower_specs, aggregation_device, aggregated_loss_name='loss'): + """Populate replicated EstimatorSpec for `GraphKeys.EVAL`.""" + estimator_spec = tower_specs[0]._asdict() + estimator_spec['mode'] = model_fn_lib.ModeKeys.EVAL + estimator_spec['loss'] = _compute_sum_on_device( + [spec.loss for spec in tower_specs], aggregation_device, + aggregated_loss_name) + + eval_metric_ops_lists = {} + for tower_spec in tower_specs: + metrics = tower_spec.eval_metric_ops or {} + for name, (_, update_op) in six.iteritems(metrics): + update_ops = eval_metric_ops_lists.setdefault(name, ([])) + update_ops.append(update_op) + + eval_metric_ops = {} + for name, (metric_tensor, _) in six.iteritems(tower_specs[0].eval_metric_ops): + with ops_lib.control_dependencies(eval_metric_ops_lists[name]): + # This operation reduces local variables across all metrics, yet is + # called for every metric. This is redundant and it's done because + # it is hard to know what local variables correspond to what metric. + # Estimator is going to execute all `reduced_update_op`s as part of + # a group inside a single `Session.run()` call, which will avoid duplicate + # computation. + reduced_update_op = _reduce_metric_variables(len(tower_specs)) + eval_metric_ops[name] = (metric_tensor, reduced_update_op) + + estimator_spec['eval_metric_ops'] = eval_metric_ops + return model_fn_lib.EstimatorSpec(**estimator_spec) + + +def _reduce_metric_variables(number_of_towers): + """Aggregate local variables used in metrics into the first tower.""" + if number_of_towers == 1: + return control_flow_ops.no_op() + + metric_variables = ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES) + variables_per_tower = len(metric_variables) // number_of_towers + + if len(metric_variables) % number_of_towers != 0: + raise ValueError( + 'Different `EstimatorSpec.eval_metric_ops` across `model_fn()` calls.' + ' Expected {} local variables, but got {} instead.'.format( + variables_per_tower * number_of_towers, len(metric_variables))) + + # `metric_variables` has the size of `variables_per_tower` x + # number_of_towers. Each tower is produced by calling the same model_fn. + # First `variables_per_tower` correspond to the first tower. Each such + # variable has an replica at the `(variables_per_tower * i)` position, where + # `i` is `[1.. number_of_towers]`. We are going to add values from replicas + # to each variable of the first tower. We then zero out replica values, so + # that `_reduce_metric_variables` operation is idempotent. If a metric + # is then computed based on local variables from the first tower, then the + # resulting metric is an estimate for all `number_of_towers` towers. + ops = [] + for i in range(0, variables_per_tower): + next_replica_id = i + variables_per_tower + replicas = [ + metric_variables[replica_id] + for replica_id in range(next_replica_id, len(metric_variables), + variables_per_tower) + ] # `replicas` doesn't contain the first-tower variable. + + reduce_op = state_ops.assign_add(metric_variables[i], + math_ops.add_n(replicas)) + + with ops_lib.control_dependencies([reduce_op]): + for replica in replicas: + zeros_for_replica = array_ops.zeros( + array_ops.shape(replica), dtype=replica.dtype) + zero_out_replica_op = state_ops.assign(replica, zeros_for_replica) + ops.append(zero_out_replica_op) + + return control_flow_ops.group(*ops) + + +def _predict_spec(tower_specs, aggregation_device): + """Populate replicated EstimatorSpec for `GraphKeys.PREDICT`.""" + estimator_spec = tower_specs[0]._asdict() + estimator_spec['mode'] = model_fn_lib.ModeKeys.PREDICT + + with ops_lib.device(aggregation_device): + estimator_spec['predictions'] = _concat_tensor_dicts( + *[tower_spec.predictions for tower_spec in tower_specs]) + + export_outputs_dict = _dict_concat( + *[tower_spec.export_outputs for tower_spec in tower_specs]) + + export_outputs = {} + for name, export_output_list in six.iteritems(export_outputs_dict): + if isinstance(export_output_list[0], export_output_lib.PredictOutput): + export_outputs[name] = export_output_lib.PredictOutput( + outputs=_concat_tensor_dicts(*[ + export_output.outputs for export_output in export_output_list + ])) + elif isinstance(export_output_list[0], + export_output_lib.RegressionOutput): + export_outputs[name] = export_output_lib.RegressionOutput( + value=array_ops.concat( + [export_output.value for export_output in export_output_list], + axis=0)) + elif isinstance(export_output_list[0], + export_output_lib.ClassificationOutput): + scores = None + if export_output_list[0].scores is not None: + scores = array_ops.concat( + [export_output.scores for export_output in export_output_list], + axis=0) + + classes = None + if export_output_list[0].classes is not None: + classes = array_ops.stack( + [export_output.classes for export_output in export_output_list], + axis=0) + + export_outputs[name] = export_output_lib.ClassificationOutput( + scores=scores, classes=classes) + + estimator_spec['export_outputs'] = export_outputs + return model_fn_lib.EstimatorSpec(**estimator_spec) + + +def _concat_tensor_dicts(*tensor_dicts): + return { + name: array_ops.concat(tensors, axis=0, name=name) + for name, tensors in six.iteritems(_dict_concat(*tensor_dicts)) + } + + +def _dict_concat(*dicts): + list_dict = {} + for d in dicts: + if d is None: + continue + + for k, v in six.iteritems(d): + list_dict.setdefault(k, []).append(v) + return list_dict diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py new file mode 100644 index 0000000000000000000000000000000000000000..10b47fba5af0f2a036df637a4f4f996d388270c6 --- /dev/null +++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py @@ -0,0 +1,901 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for utilities that replicate `Estimator.model_fn` over GPUs.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +import shutil +import tempfile +import numpy as np +import six + +from tensorflow.contrib.estimator.python.estimator import replicate_model_fn +from tensorflow.python.estimator import estimator as estimator_lib +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.canned import dnn +from tensorflow.python.estimator.canned import optimizers +from tensorflow.python.estimator.canned import prediction_keys +from tensorflow.python.estimator.export import export +from tensorflow.python.estimator.export import export_output +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.feature_column import feature_column +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops as ops_lib +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics as metrics_lib +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.ops.losses import losses +from tensorflow.python.platform import gfile +from tensorflow.python.platform import test +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import gradient_descent + + +class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def test_complete_flow(self): + n_classes = 3 + input_dimension = 2 + batch_size = 12 + + data = np.linspace( + 0., n_classes - 1., batch_size * input_dimension, dtype=np.float32) + x_data = data.reshape(batch_size, input_dimension) + y_data = np.reshape(self._as_label(data[:batch_size]), (batch_size, 1)) + train_input_fn = numpy_io.numpy_input_fn( + x={'x': x_data}, + y=y_data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': x_data}, y=y_data, batch_size=batch_size, shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': x_data}, batch_size=batch_size, shuffle=False) + + feature_columns = [ + feature_column.numeric_column('x', shape=(input_dimension,)) + ] + + estimator = dnn.DNNClassifier( + hidden_units=(2, 2), + feature_columns=feature_columns, + n_classes=n_classes, + model_dir=self._model_dir) + + def optimizer_fn(): + return optimizers.get_optimizer_instance('Adagrad', learning_rate=0.05) + + # TODO(isaprykin): Switch Estimator to use allow_soft_placement=True + # during export_savedmodel and then switch this test to replicate over + # GPUs instead of CPUs. + estimator = estimator_lib.Estimator( + model_fn=replicate_model_fn.replicate_model_fn( + estimator.model_fn, + optimizer_fn, + devices=['/cpu:0', '/cpu:0', '/cpu:0']), + model_dir=estimator.model_dir, + config=estimator.config, + params=estimator.params) + + num_steps = 10 + estimator.train(train_input_fn, steps=num_steps) + + scores = estimator.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops_lib.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + predicted_proba = np.array([ + x[prediction_keys.PredictionKeys.PROBABILITIES] + for x in estimator.predict(predict_input_fn) + ]) + self.assertAllEqual((batch_size, n_classes), predicted_proba.shape) + + feature_spec = feature_column.make_parse_example_spec(feature_columns) + serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn( + feature_spec) + export_dir = estimator.export_savedmodel(tempfile.mkdtemp(), + serving_input_receiver_fn) + self.assertTrue(gfile.Exists(export_dir)) + + def _as_label(self, data_in_float): + return np.rint(data_in_float).astype(np.int64) + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + +class ReplicateModelTest(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.multiply(features, c) + + loss = None + if mode is not model_fn_lib.ModeKeys.PREDICT: + loss = losses.absolute_difference( + labels=labels, + predictions=predictions, + reduction=losses.Reduction.SUM) + loss = math_ops.reduce_sum(loss) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=loss, + eval_metric_ops=metrics, + predictions={'probabilities': predictions}, + train_op=control_flow_ops.no_op()) # This train_op isn't actually used. + + def optimizer_fn(self, params): + return gradient_descent.GradientDescentOptimizer(params['learning_rate']) + + @property + def params(self): + params = {} + params['learning_rate'] = 1.0 + return params + + def test_train(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, + features, labels, self.params) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # loss' of c is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(7.0, session.run(c)) + + def test_train_spec_with_optimizer_without_params(self): + + def optimizer_fn_without_params(): + return gradient_descent.GradientDescentOptimizer(learning_rate=1.0) + + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: # pylint: disable=unused-variable + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, + optimizer_fn_without_params, + devices=['/gpu:0', '/gpu:1']) + # This call is going to fail if `replicated_model_fn` is still passing + # `params` inside `optimizer_fn`, even though the latter doesn't take any: + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, + features, labels, self.params) + del estimator_spec + + def test_eval(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features, + labels, self.params) + session.run(variables.local_variables_initializer()) + session.run(variables.global_variables_initializer()) + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + # Accuracy is 0.0 (no match) in the first tower. + # Accuracy is 1.0 (match) in the second tower, since the feature + # times weight "c" happened to be equal to the label. + total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) + + self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01) + + def test_predict(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, + features, labels, self.params) + session.run(variables.global_variables_initializer()) + + self.assertAllClose({ + 'probabilities': np.array([[0.1], [0.02]]) + }, session.run(estimator_spec.predictions)) + + def test_train_single_tower(self): + features = np.array([[1.0], [2.0]]) + labels = np.array([[1.0], [2.0]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.TRAIN, + features, labels, self.params) + session.run(variables.global_variables_initializer()) + + # loss = feature * c - label + total_loss = (1.0 * 10 - 1.0) + (2.0 * 10 - 2.0) + self.assertEqual(total_loss, session.run(estimator_spec.loss)) + + # loss' of c is 3. + # new value of c = 10 - learning rate * 3 = 7.0. + session.run(estimator_spec.train_op) + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(7.0, session.run(c)) + + def test_eval_single_tower(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.EVAL, features, + labels, self.params) + session.run(variables.local_variables_initializer()) + session.run(variables.global_variables_initializer()) + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + # Accuracy is 0.0 (no match) in the first tower. + # Accuracy is 1.0 (match) in the second tower, since the feature + # times weight "c" happened to be equal to the label. + total_loss = ((0.01 * 10 - 0.01) + (0.002 * 10 - 0.02)) + + self.assertNear((0.0 + 1.0) / 2.0, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertNear(total_loss, session.run(estimator_spec.loss), 0.01) + + def test_predict_single_tower(self): + features = np.array([[0.01], [0.002]]) + labels = np.array([[0.01], [0.02]]) + + with self.test_session() as session: + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, + features, labels, self.params) + session.run(variables.global_variables_initializer()) + + self.assertAllClose({ + 'probabilities': np.array([[0.1], [0.02]]) + }, session.run(estimator_spec.predictions)) + + +class GetLossTowersTest(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(0.25, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c) + labels = np.array([0.1, 0.2, 0.3, labels[0]]) + + loss = losses.absolute_difference( + labels=labels, predictions=predictions, reduction=losses.Reduction.SUM) + + return model_fn_lib.EstimatorSpec(mode=mode, loss=math_ops.reduce_sum(loss)) + + def test_gradients_are_computed(self): + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + self.model_fn, + mode=None, + features=[[0.6], [1.6]], + labels=[[0.6], [0.6]], + params=None, + config=None, + devices=['/gpu:0', '/gpu:1'], + local_ps_device='/gpu:0', + name_scope_pattern='test_tower_{}') + session.run(variables.global_variables_initializer()) + + self.assertEqual(len(tower_specs), 2) + + self.assertEqual('/device:GPU:0', tower_specs[0].loss.device) + self.assertEqual('Sum:0', tower_specs[0].loss.name) + self.assertEqual(1.0, session.run(tower_specs[0].loss)) + + self.assertEqual('/device:GPU:1', tower_specs[1].loss.device) + self.assertEqual('test_tower_1/Sum:0', tower_specs[1].loss.name) + # The input batch for the second tower had a loss that is 1.0 + # bigger: 0.6 vs 1.6. + self.assertEqual(2.0, session.run(tower_specs[1].loss)) + + self.assertEqual(1, len(variables.global_variables())) + self.assertEqual(1, len(variables.trainable_variables())) + + with variable_scope.variable_scope('', reuse=True): + c = variable_scope.get_variable('c', dtype=dtypes.float64) + self.assertEqual(0.25, session.run(c)) + + +class SplitBatchTest(test_util.TensorFlowTestCase): + + def evaluate_shards(self, first_list, second_list): + evaluate_items = lambda x: x.eval() + return list(map(evaluate_items, first_list)), list( + map(evaluate_items, second_list)) + + def test_simple_half_split(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = [0.0, 1.0, 2.0, 3.0] + labels = [10.0, 11.0, 12.0, 13.0] + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + feature_shards, label_shards = self.evaluate_shards( + feature_shards, label_shards) + + self.assertAllEqual([[0.0, 1.0], [2.0, 3.0]], feature_shards) + self.assertAllEqual([[10.0, 11.0], [12.0, 13.0]], label_shards) + + def test_to_each_their_own(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = [0.0, 1.0, 2.0, 3.0] + labels = [10.0, 11.0, 12.0, 13.0] + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 4, device='/gpu:0') + + feature_shards, label_shards = self.evaluate_shards( + feature_shards, label_shards) + + self.assertAllEqual([[0.0], [1.0], [2.0], [3.0]], feature_shards) + self.assertAllEqual([[10.0], [11.0], [12.0], [13.0]], label_shards) + + def test_one_batch(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = [0.0, 1.0, 2.0, 3.0] + labels = [10.0, 11.0, 12.0, 13.0] + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 1, device='/gpu:0') + + feature_shards, label_shards = self.evaluate_shards( + feature_shards, label_shards) + + self.assertAllEqual([[0.0, 1.0, 2.0, 3.0]], feature_shards) + self.assertAllEqual([[10.0, 11.0, 12.0, 13.0]], label_shards) + + def test_half_split_in_dictionary(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} + labels = [10.0, 11.0, 12.0, 13.0] + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval()) + self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval()) + self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval()) + self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval()) + self.assertAllEqual([10.0, 11.0], label_shards[0].eval()) + self.assertAllEqual([12.0, 13.0], label_shards[1].eval()) + + def test_one_batch_in_dictionary(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} + labels = [10.0, 11.0, 12.0, 13.0] + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 1, device='/gpu:0') + + self.assertAllEqual([0.0, 1.0, 2.0, 3.0], + feature_shards[0]['first'].eval()) + self.assertAllEqual([4.0, 5.0, 6.0, 7.0], + feature_shards[0]['second'].eval()) + self.assertAllEqual([10.0, 11.0, 12.0, 13.0], label_shards[0].eval()) + + def test_feature_and_label_dictionaries(self): + with self.test_session() as session: # pylint: disable=unused-variable + features = {'first': [0.0, 1.0, 2.0, 3.0], 'second': [4.0, 5.0, 6.0, 7.0]} + labels = {'first': [10.0, 11.0], 'second': [12.0, 13.0]} + + feature_shards, label_shards = replicate_model_fn._split_batch( + features, labels, 2, device='/gpu:0') + + self.assertAllEqual([0.0, 1.0], feature_shards[0]['first'].eval()) + self.assertAllEqual([4.0, 5.0], feature_shards[0]['second'].eval()) + self.assertAllEqual([2.0, 3.0], feature_shards[1]['first'].eval()) + self.assertAllEqual([6.0, 7.0], feature_shards[1]['second'].eval()) + self.assertAllEqual([10.0], label_shards[0]['first'].eval()) + self.assertAllEqual([12.0], label_shards[0]['second'].eval()) + self.assertAllEqual([11], label_shards[1]['first'].eval()) + self.assertAllEqual([13.0], label_shards[1]['second'].eval()) + + +class TrainSpecTest(test_util.TensorFlowTestCase): + + expected_predictions = {} + + def create_estimator_spec(self, loss): + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.TRAIN, + loss=loss, + train_op=loss, # Not used; currently required. + predictions=self.expected_predictions) + + def create_constant_loss(self, loss_value): + return constant_op.constant(loss_value, dtype=dtypes.float64) + + def test_example(self): + with self.test_session() as session: + tower_losses = list(map(self.create_constant_loss, [2, 4, 6])) + tower_specs = list(map(self.create_estimator_spec, tower_losses)) + + expected_train_op = tower_losses[1] + + estimator_spec = replicate_model_fn._train_spec( + tower_specs, expected_train_op, aggregation_device='/gpu:0') + + self.assertEqual(expected_train_op, estimator_spec.train_op) + self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss)) + self.assertEqual(self.expected_predictions, estimator_spec.predictions) + + +class EvalSpecTest(test_util.TensorFlowTestCase): + + def create_estimator_spec(self, loss, metrics): + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.EVAL, loss=loss, eval_metric_ops=metrics) + + def create_constant_loss(self, loss_value): + return constant_op.constant(loss_value, dtype=dtypes.float64) + + def create_eval_metrics(self, noise): + predictions = np.array([0.1, 0.2, 0.3, 0.6 + noise]) + labels = np.array([0.1, 0.2, 0.3, 0.6]) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions), + 'auc': metrics_lib.auc(labels, predictions) + } + return metrics + + def test_example(self): + with self.test_session() as session: + tower_losses = map(self.create_constant_loss, [2, 4, 6]) + tower_metrics = map(self.create_eval_metrics, [0, 0.2, 0.3]) + tower_specs = [ + self.create_estimator_spec(l, m) + for l, m in zip(tower_losses, tower_metrics) + ] + session.run(variables.local_variables_initializer()) + + estimator_spec = replicate_model_fn._eval_spec( + tower_specs, aggregation_device='/device:GPU:0') + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + self.assertEqual('/device:CPU:0', accuracy.device) + self.assertEqual('/device:CPU:0', auc.device) + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + self.assertNear((12 - 2) / 12, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertEqual(2 + 4 + 6, session.run(estimator_spec.loss)) + + def test_handles_single_tower(self): + with self.test_session() as session: + tower_losses = map(self.create_constant_loss, [5]) + tower_metrics = map(self.create_eval_metrics, [0.2]) + tower_specs = [ + self.create_estimator_spec(l, m) + for l, m in zip(tower_losses, tower_metrics) + ] + session.run(variables.local_variables_initializer()) + + estimator_spec = replicate_model_fn._eval_spec( + tower_specs, aggregation_device='/device:GPU:0') + + accuracy, a = estimator_spec.eval_metric_ops['accuracy'] + auc, b = estimator_spec.eval_metric_ops['auc'] + + self.assertEqual('/device:CPU:0', accuracy.device) + self.assertEqual('/device:CPU:0', auc.device) + + session.run([a, b]) + accuracy = session.run(accuracy) + auc = session.run(auc) + + self.assertNear((4 - 1) / 4, accuracy, 0.01) + self.assertEqual(0, auc) + self.assertEqual(5, session.run(estimator_spec.loss)) + + +class PredictSpecTest(test_util.TensorFlowTestCase): + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(0.25, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = math_ops.add(np.array([features[0], features[0]]), c) + + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.PREDICT, + predictions={ + 'probabilities': predictions + }) + + def test_example(self): + with self.test_session() as session: + tower_specs = replicate_model_fn._get_loss_towers( + self.model_fn, + mode=None, + features=[[0.1], [0.2]], + labels=[[], []], + params=None, + config=None, + devices=['/gpu:0', '/gpu:1'], + local_ps_device='/gpu:0', + ) + session.run(variables.global_variables_initializer()) + + estimator_spec = replicate_model_fn._predict_spec( + tower_specs, aggregation_device='/gpu:0') + + self.assertEqual('/device:GPU:0', + estimator_spec.predictions['probabilities'].device) + self.assertAllClose({ + 'probabilities': np.array([0.35, 0.35, 0.45, 0.45]) + }, session.run(estimator_spec.predictions)) + + +class ReduceMetricVariablesTest(test_util.TensorFlowTestCase): + + def create_metric_variable(self, initial_value, name): + return variable_scope.variable( + initial_value, + trainable=False, + collections=[ops_lib.GraphKeys.METRIC_VARIABLES], + validate_shape=True, + name=name) + + def create_tower_metrics(self, tower_id): + with variable_scope.variable_scope('', reuse=(tower_id != 0)): + self.create_metric_variable(1.3 * (tower_id + 1), 'total') + self.create_metric_variable(2.3 * (tower_id + 1), 'count') + self.create_metric_variable( + np.array([3.3, 3.5, 3.7]) * (tower_id + 1), 'total') + + def test_example(self): + with self.test_session() as session: + for tower_id in range(3): + self.create_tower_metrics(tower_id) + + session.run( + variables.variables_initializer( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) + + session.run( + replicate_model_fn._reduce_metric_variables(number_of_towers=3)) + + # 1st tower = 1.3, 2.3, [3.3, 3.5, 3.7] + # 2nd tower = 2.6, 4.6, [6.6, 7.0, 7.4] + # 3rd tower = 3.9, 6.9, [9.9, 10.5, 11.1] + # Reduced = 7.8, 13.8, [19.8, 21.0, 22.2] + # Towers are accumulated in the first tower. + local_metrics = session.run( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)) + + self.assertNear(7.8, local_metrics[0], 0.01) + self.assertNear(13.8, local_metrics[1], 0.01) + self.assertAllClose([19.8, 21., 22.1], local_metrics[2], 0.01) + self.assertNear(0.0, local_metrics[3], 0.01) + self.assertNear(0.0, local_metrics[4], 0.01) + self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01) + self.assertNear(0.0, local_metrics[6], 0.01) + self.assertNear(0.0, local_metrics[7], 0.01) + self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01) + + def test_reduce_is_idempotent(self): + with self.test_session() as session: + for tower_id in range(3): + self.create_tower_metrics(tower_id) + + session.run( + variables.variables_initializer( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) + + for _ in range(20): + session.run( + replicate_model_fn._reduce_metric_variables(number_of_towers=3)) + + local_metrics = session.run( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)) + + self.assertNear(7.8, local_metrics[0], 0.01) + self.assertNear(13.8, local_metrics[1], 0.01) + self.assertAllClose([19.8, 21., 22.1], local_metrics[2], 0.01) + self.assertNear(0.0, local_metrics[3], 0.01) + self.assertNear(0.0, local_metrics[4], 0.01) + self.assertAllClose([0.0, 0.0, 0.0], local_metrics[5], 0.01) + self.assertNear(0.0, local_metrics[6], 0.01) + self.assertNear(0.0, local_metrics[7], 0.01) + self.assertAllClose([0.0, 0.0, 0.0], local_metrics[8], 0.01) + + def test_handles_single_tower(self): + with self.test_session() as session: + self.create_tower_metrics(0) + session.run( + variables.variables_initializer( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) + + session.run( + replicate_model_fn._reduce_metric_variables(number_of_towers=1)) + + local_metrics = session.run( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES)) + + self.assertNear(1.3, local_metrics[0], 0.01) + self.assertNear(2.3, local_metrics[1], 0.01) + self.assertAllClose([3.3, 3.5, 3.7], local_metrics[2], 0.01) + + def test_doesnt_accept_uneven_number_of_variables(self): + with self.test_session() as session: + for tower_id in range(3): + self.create_tower_metrics(tower_id) + self.create_metric_variable(-1.0, 'oddball') + + session.run( + variables.variables_initializer( + ops_lib.get_collection(ops_lib.GraphKeys.METRIC_VARIABLES))) + + with self.assertRaisesRegexp(ValueError, ''): + session.run( + replicate_model_fn._reduce_metric_variables(number_of_towers=3)) + + +class MergeExportOutputsTest(test_util.TensorFlowTestCase): + + def optimizer_fn(self): + return gradient_descent.GradientDescentOptimizer(1.0) + + def model_fn(self, mode, features, labels, params): + c = variable_scope.get_variable( + 'c', + initializer=constant_op.constant(10, dtype=dtypes.float64), + dtype=dtypes.float64) + + predictions = {'probabilities': math_ops.multiply(features, c)} + loss = losses.absolute_difference( + labels=labels, + predictions=predictions['probabilities'], + reduction=losses.Reduction.SUM) + + metrics = { + 'accuracy': metrics_lib.accuracy(labels, predictions['probabilities']), + 'auc': metrics_lib.auc(labels, predictions['probabilities']) + } + tensor_string_repr = str(features) + classes = constant_op.constant( + re.search('(split_inputs/split:[0-9])', tensor_string_repr).group(1), + dtype=dtypes.string) + + export_outputs = { + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: + export_output.PredictOutput(predictions), + 'classification_output': + export_output.ClassificationOutput(predictions['probabilities'], + classes), + 'classification_scores': + export_output.ClassificationOutput( + scores=predictions['probabilities']), + 'classification_classes': + export_output.ClassificationOutput(classes=classes), + 'regression_output': + export_output.RegressionOutput(predictions['probabilities']), + } + + return model_fn_lib.EstimatorSpec( + mode=mode, + loss=math_ops.reduce_sum(loss), + eval_metric_ops=metrics, + predictions=predictions, + train_op=loss, # This train_op isn't actually used. + export_outputs=export_outputs) + + def replicate_estimator_spec(self, session): + features = np.array([0.01, 0.002]) + labels = np.array([0.01, 0.02]) + + replicated_model_fn = replicate_model_fn.replicate_model_fn( + self.model_fn, self.optimizer_fn, devices=['/gpu:0', '/gpu:1']) + estimator_spec = replicated_model_fn(model_fn_lib.ModeKeys.PREDICT, + features, labels, {}) + session.run(variables.global_variables_initializer()) + return estimator_spec + + def test_merde_predict_output(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + { + 'probabilities': np.array([0.1, 0.02]) + }, + session.run(estimator_spec.export_outputs[ + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY].outputs)) + + def test_merge_classification_output_scores_classes(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + [0.1, 0.02], + session.run( + estimator_spec.export_outputs['classification_output'].scores)) + self.assertAllEqual( + [b'split_inputs/split:0', b'split_inputs/split:1'], + session.run( + estimator_spec.export_outputs['classification_output'].classes)) + + def test_merge_classification_output_scores(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + [0.1, 0.02], + session.run( + estimator_spec.export_outputs['classification_scores'].scores)) + self.assertEqual( + None, estimator_spec.export_outputs['classification_scores'].classes) + + def test_merge_classification_output_classes(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllEqual( + [b'split_inputs/split:0', b'split_inputs/split:1'], + session.run( + estimator_spec.export_outputs['classification_classes'].classes)) + self.assertEqual( + None, estimator_spec.export_outputs['classification_classes'].scores) + + def test_merge_regression_output(self): + with self.test_session() as session: + estimator_spec = self.replicate_estimator_spec(session) + self.assertAllClose( + [0.1, 0.02], + session.run(estimator_spec.export_outputs['regression_output'].value)) + + +class GetLocalDevicesTest(test_util.TensorFlowTestCase): + + def test_there_is_at_least_a_cpu(self): + self.assertTrue(replicate_model_fn._get_local_devices('CPU')) + + def test_there_is_no_xpu(self): + self.assertFalse( + replicate_model_fn._get_local_devices('XPU')) # XPU doesn't exist. + + def test_whether_there_is_a_gpu(self): + self.assertEqual( + len(replicate_model_fn._get_local_devices('GPU')), + test.is_gpu_available()) + + +class LocalDeviceSetterTest(test_util.TensorFlowTestCase): + + def test_vars_are_on_ps_but_ops_are_on_workers(self): + local_device_setter = replicate_model_fn._local_device_setter( + ps_device='/device:GPU:3', worker_device='/device:GPU:2') + + with ops_lib.device(local_device_setter): + c = variables.Variable(0.01) + self.assertEqual('/device:GPU:3', c.device) + + cc = variables.Variable(0.02) + self.assertEqual('/device:GPU:3', cc.device) + + ccc = variables.Variable(0.03) + self.assertEqual('/device:GPU:3', ccc.device) + + c_op = array_ops.concat(c, axis=0) + self.assertEqual('/device:GPU:2', c_op.device) + + cc_op = array_ops.concat(cc, axis=0) + self.assertEqual('/device:GPU:2', cc_op.device) + + +class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase): + + def test_example(self): + with self.test_session() as session: + total = replicate_model_fn._compute_sum_on_device( + [1.0, 2.0, 3.0, 4.0], device='/device:GPU:0', name='test_sum') + + self.assertEqual('/device:GPU:0', total.device) + self.assertEqual('test_sum', total.op.name) + self.assertEqual(10.0, session.run(total)) + + +class ConcatTensorDictsTest(test_util.TensorFlowTestCase): + + def test_example(self): + tensor_dicts = [ + { + 'a': np.array([1.0, 2.0]), + 'b': np.array([11.0]), + 'c': np.array([21.0]), + }, + { + 'a': np.array([3.0]), + 'b': np.array([12.0, 13.0]), + }, + { + 'b': np.array([14.0]), + }, + ] + + with self.test_session() as session: + self.assertAllClose({ + 'a': np.array([1.0, 2.0, 3.0]), + 'b': np.array([11.0, 12.0, 13.0, 14.0]), + 'c': np.array([21.0]), + }, session.run(replicate_model_fn._concat_tensor_dicts(*tensor_dicts))) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD index c468c544d372e8bfd6adfa49a58e9bf6c5ef0a8b..fe86a20ab1f69a0eaf9d7486142451dac6337274 100644 --- a/tensorflow/contrib/factorization/BUILD +++ b/tensorflow/contrib/factorization/BUILD @@ -8,6 +8,7 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) +load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") @@ -23,6 +24,7 @@ tf_custom_op_py_library( "python/ops/factorization_ops.py", "python/ops/gmm.py", "python/ops/gmm_ops.py", + "python/ops/kmeans.py", "python/ops/wals.py", ], dso = [ @@ -48,15 +50,22 @@ tf_custom_op_py_library( "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:linalg_ops", + "//tensorflow/python:logging_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", "//tensorflow/python:nn", "//tensorflow/python:platform", "//tensorflow/python:random_ops", "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:summary", + "//tensorflow/python:training", "//tensorflow/python:util", + "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", "//third_party/py/numpy", ], ) @@ -131,12 +140,17 @@ tf_py_test( ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", "//third_party/py/numpy", "//tensorflow/contrib/learn", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform", "//tensorflow/python:platform_test", + "//tensorflow/python:random_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:training", ], tags = [ "no_pip", # b/38283730 @@ -160,6 +174,7 @@ tf_py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform", "//tensorflow/python:platform_test", + "//tensorflow/python:random_seed", "//tensorflow/python:variables", ], tags = ["notsan"], # b/62863147 @@ -191,14 +206,41 @@ tf_py_test( "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:embedding_ops", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", + "//tensorflow/python:sparse_tensor", ], ) # Estimators tests +py_test( + name = "kmeans_test", + size = "medium", + srcs = ["python/ops/kmeans_test.py"], + srcs_version = "PY2AND3", + tags = ["notsan"], # b/67512932 + deps = [ + ":factorization_py", + ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:platform_benchmark", + "//tensorflow/python:random_ops", + "//tensorflow/python:training", + "//tensorflow/python/estimator:run_config", + "//third_party/py/numpy", + ], +) + tf_py_test( name = "wals_test", size = "large", @@ -208,20 +250,26 @@ tf_py_test( ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", ":factorization_ops_test_utils_py", "//third_party/py/numpy", + "//tensorflow/contrib/learn", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:embedding_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:platform_benchmark", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", + "//tensorflow/python:training", "//tensorflow/python:variables", ], tags = [ "manual", "noasan", # times out b/63678675 "nomsan", - "notsan", ], ) @@ -232,11 +280,13 @@ tf_py_test( additional_deps = [ ":factorization_py", ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", + ":gen_factorization_ops", "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", + "//tensorflow/python:sparse_tensor", ], ) @@ -260,10 +310,15 @@ tf_py_test( ":gen_factorization_ops", ":factorization_py_CYCLIC_DEPENDENCIES_THAT_NEED_TO_GO", "//third_party/py/numpy", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", + "//tensorflow/python:sparse_tensor", ], ) diff --git a/tensorflow/contrib/factorization/__init__.py b/tensorflow/contrib/factorization/__init__.py index 486c2ea9336d19fb7273d02502f9865adc6aefed..6112c9d8300fe219c8e172a5b70e4ce4cad04eb6 100644 --- a/tensorflow/contrib/factorization/__init__.py +++ b/tensorflow/contrib/factorization/__init__.py @@ -23,22 +23,24 @@ from tensorflow.contrib.factorization.python.ops.clustering_ops import * from tensorflow.contrib.factorization.python.ops.factorization_ops import * from tensorflow.contrib.factorization.python.ops.gmm import * from tensorflow.contrib.factorization.python.ops.gmm_ops import * +from tensorflow.contrib.factorization.python.ops.kmeans import * from tensorflow.contrib.factorization.python.ops.wals import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ - 'KMeans', 'COSINE_DISTANCE', - 'KMEANS_PLUS_PLUS_INIT', - 'RANDOM_INIT', - 'SQUARED_EUCLIDEAN_DISTANCE', - 'WALSModel', 'GMM', 'gmm', 'GmmAlgorithm', + 'KMeans', + 'KMEANS_PLUS_PLUS_INIT', + 'KMeansClustering', + 'RANDOM_INIT', + 'SQUARED_EUCLIDEAN_DISTANCE', 'WALSMatrixFactorization', + 'WALSModel', ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/factorization/g3doc/kmeans.md b/tensorflow/contrib/factorization/g3doc/kmeans.md index b55c9d09ad386b84623d3648c5be83cbba8bbff9..c1843f0bf0704503d43c28d186dc826f0677711f 100644 --- a/tensorflow/contrib/factorization/g3doc/kmeans.md +++ b/tensorflow/contrib/factorization/g3doc/kmeans.md @@ -24,7 +24,11 @@ the full-batch version. approach for computing the initial cluster assignments that is expensive but is typically less prone to getting stuck in bad local minima. -We provide distributed implementations of both full-batch and mini-batch -K-Means algorithm. Both K-Means++ and random initialization are supported. -The user can also choose between **Cosine** and **Squared Euclidean** distance -metrics. +**[k-MC2](https://www.aaai.org/ocs/index.php/AAAI/AAAI16/paper/view/12147/11759)** +provides a very fast seeding method that provides high quality centers +comparable to K-Means++ seeding. k-MC2 works particularly well if it is combined +with Mini-batch K-Means. + +We provide distributed implementations of both full-batch and mini-batch K-Means +algorithm. K-Means++, k-MC2 and random initialization are supported. The user +can also choose between **Cosine** and **Squared Euclidean** distance metrics. diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops.cc b/tensorflow/contrib/factorization/kernels/clustering_ops.cc index a2136c08bbc2e91f4587b1cdacbfe3b1d1073949..dd61f59585aee2e0245cfd6797b313b972c19bc5 100644 --- a/tensorflow/contrib/factorization/kernels/clustering_ops.cc +++ b/tensorflow/contrib/factorization/kernels/clustering_ops.cc @@ -224,6 +224,58 @@ class KmeansPlusPlusInitializationOp : public OpKernel { REGISTER_KERNEL_BUILDER(Name("KmeansPlusPlusInitialization").Device(DEVICE_CPU), KmeansPlusPlusInitializationOp); +// Implementation of one single Markov Chain for the k-MC^2 algorithm +class KMC2ChainInitializationOp : public OpKernel { + public: + explicit KMC2ChainInitializationOp(OpKernelConstruction* context) + : OpKernel(context) { + OP_REQUIRES_OK(context, + context->MatchSignature({DT_FLOAT, DT_INT64}, {DT_INT64})); + } + + void Compute(OpKernelContext* context) override { + const Tensor& distances_tensor = context->input(0); + const Tensor& seed_tensor = context->input(1); + OP_REQUIRES(context, TensorShapeUtils::IsVector(distances_tensor.shape()), + InvalidArgument("Input distances should be a vector.")); + OP_REQUIRES(context, TensorShapeUtils::IsScalar(seed_tensor.shape()), + InvalidArgument("Input seed should be a scalar.")); + const int64 num_points = distances_tensor.dim_size(0); + const int64 seed = seed_tensor.scalar()(); + OP_REQUIRES(context, num_points > 0, + InvalidArgument("Expected distances_tensor.size() > 0.")); + + random::PhiloxRandom random(seed); + random::SimplePhilox rng(&random); + + auto distances = distances_tensor.flat(); + // Set the initial state of the Markov chain to be the first candidate. + int64 selected_index = 0; + float selected_distance = distances(selected_index); + // Build a Markov chain of length num_points. + for (int64 i = 1; i < num_points; ++i) { + const float candidate_distance = distances(i); + // Set the next state of the Markov chain to be the candidate with + // probability min(1, candidate_distance/selected_distance). + if (candidate_distance > rng.RandFloat() * selected_distance) { + selected_index = i; + selected_distance = candidate_distance; + } + } + + Tensor* output_sampled_index_tensor; + OP_REQUIRES_OK(context, + context->allocate_output(0, TensorShape({}), + &output_sampled_index_tensor)); + auto output = output_sampled_index_tensor->scalar(); + // Return the last state of the Markov chain as the new center. + output() = selected_index; + } +}; + +REGISTER_KERNEL_BUILDER(Name("KMC2ChainInitialization").Device(DEVICE_CPU), + KMC2ChainInitializationOp); + // Operator for computing the nearest neighbors for a set of points. class NearestNeighborsOp : public OpKernel { public: diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc b/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc index c4a96b048db878169acc69b4d8caed5d4e04c18f..8172a7cebb81de70c530dbdd9ce0ca3eda4bc2ce 100644 --- a/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc +++ b/tensorflow/contrib/factorization/kernels/clustering_ops_test.cc @@ -116,6 +116,62 @@ RUN_BM_KmeansPlusPlusInitialization(k3RetriesPerSample); #undef RUN_BM_KmeansPlusPlusInitialization #undef BENCHMARK_KMEANS_PLUS_PLUS +Graph* SetUpKMC2Initialization(int num_points) { + Graph* g = new Graph(OpRegistry::Global()); + Tensor distances(DT_FLOAT, TensorShape({num_points})); + Tensor seed(DT_INT64, TensorShape({})); + distances.flat().setRandom(); + seed.flat().setConstant(12345); + + TF_CHECK_OK( + NodeBuilder("KMC2ChainInitializationOp", "KMC2ChainInitialization") + .Input(test::graph::Constant(g, distances)) + .Input(test::graph::Constant(g, seed)) + .Finalize(g, nullptr /* node */)); + return g; +} + +template +void BM_KMC2Initialization(int iters) { + testing::StopTiming(); + testing::ItemsProcessed(static_cast(iters) * num_points * num_dims * + num_to_sample); + testing::UseRealTime(); + Graph* g = SetUpKMC2Initialization(num_points); + testing::StartTiming(); + test::Benchmark("cpu", g).Run(iters); +} +#define BENCHMARK_KMC2(p, c, d) \ + void BM_KMC2Initialization_##p##_##c##_##d(int iters) { \ + BM_KMC2Initialization(iters); \ + } \ + BENCHMARK(BM_KMC2Initialization_##p##_##c##_##d); + +#define RUN_BM_KMC2Initialization \ + BENCHMARK_KMC2(k10Points, k2Centers, k100Dim); \ + BENCHMARK_KMC2(k10Points, k5Centers, k100Dim); \ + BENCHMARK_KMC2(k10Points, k10Centers, k100Dim); \ + BENCHMARK_KMC2(k100Points, k10Centers, k100Dim); \ + BENCHMARK_KMC2(k100Points, k20Centers, k100Dim); \ + BENCHMARK_KMC2(k100Points, k50Centers, k100Dim); \ + BENCHMARK_KMC2(k100Points, k100Centers, k100Dim); \ + BENCHMARK_KMC2(k1kPoints, k100Centers, k100Dim); \ + BENCHMARK_KMC2(k1kPoints, k200Centers, k100Dim); \ + BENCHMARK_KMC2(k1kPoints, k500Centers, k100Dim); \ + BENCHMARK_KMC2(k1kPoints, k1kCenters, k100Dim); \ + BENCHMARK_KMC2(k10kPoints, k100Centers, k100Dim); \ + BENCHMARK_KMC2(k10kPoints, k200Centers, k100Dim); \ + BENCHMARK_KMC2(k10kPoints, k500Centers, k100Dim); \ + BENCHMARK_KMC2(k10kPoints, k1kCenters, k100Dim); \ + BENCHMARK_KMC2(k1MPoints, k100Centers, k100Dim); \ + BENCHMARK_KMC2(k1MPoints, k200Centers, k100Dim); \ + BENCHMARK_KMC2(k1MPoints, k500Centers, k100Dim); \ + BENCHMARK_KMC2(k1MPoints, k1kCenters, k100Dim) + +RUN_BM_KMC2Initialization; +#undef RUN_BM_KMC2Initialization +#undef BENCHMARK_KMC2 + Graph* SetUpNearestNeighbors(int num_dims, int num_points, int num_centers, int k) { Graph* g = new Graph(OpRegistry::Global()); diff --git a/tensorflow/contrib/factorization/ops/clustering_ops.cc b/tensorflow/contrib/factorization/ops/clustering_ops.cc index f2dfcf7ed0fb05264b10dee9980a246a5f2e49fa..2686702c1d5768f661dac610c96089eb02e360d7 100644 --- a/tensorflow/contrib/factorization/ops/clustering_ops.cc +++ b/tensorflow/contrib/factorization/ops/clustering_ops.cc @@ -44,6 +44,25 @@ num_retries_per_sample: Scalar. For each row that is sampled, this parameter samples: Matrix of shape (num_to_sample, d). The sampled rows. )"); +REGISTER_OP("KMC2ChainInitialization") + .Input("distances: float32") + .Input("seed: int64") + .Output("index: int64") + .SetShapeFn(shape_inference::ScalarShape) + .Doc(R"( +Returns the index of a data point that should be added to the seed set. + +Entries in distances are assumed to be squared distances of candidate points to +the already sampled centers in the seed set. The op constructs one Markov chain +of the k-MC^2 algorithm and returns the index of one candidate point to be added +as an additional cluster center. + +distances: Vector with squared distances to the closest previously sampled + cluster center for each candidate point. +seed: Scalar. Seed for initializing the random number generator. +index: Scalar with the index of the sampled point. +)"); + REGISTER_OP("NearestNeighbors") .Input("points: float32") .Input("centers: float32") diff --git a/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py index 450f64063a2a357e422cd14761864d511c0e6cce..1322f7ce5f83d82c76040a30699137cd2bf491b5 100644 --- a/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py +++ b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py @@ -55,6 +55,63 @@ class KmeansPlusPlusInitializationTest(test.TestCase): self.runTestWithSeed(seed) +class KMC2InitializationTest(test.TestCase): + + def runTestWithSeed(self, seed): + with self.test_session(): + distances = np.zeros(1000).astype(np.float32) + distances[6] = 10e7 + distances[4] = 10e3 + + sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed) + self.assertEquals(sampled_point.eval(), 6) + distances[6] = 0.0 + sampled_point = clustering_ops.kmc2_chain_initialization(distances, seed) + self.assertEquals(sampled_point.eval(), 4) + + def testBasic(self): + for seed in range(100): + self.runTestWithSeed(seed) + + +class KMC2InitializationLargeTest(test.TestCase): + + def setUp(self): + self._distances = np.zeros(1001) + self._distances[500] = 100.0 + self._distances[1000] = 50.0 + + def testBasic(self): + with self.test_session(): + counts = {} + seed = 0 + for i in range(50): + sample = clustering_ops.kmc2_chain_initialization( + self._distances, seed + i).eval() + counts[sample] = counts.get(sample, 0) + 1 + self.assertEquals(len(counts), 2) + self.assertTrue(500 in counts) + self.assertTrue(1000 in counts) + self.assertGreaterEqual(counts[500], 5) + self.assertGreaterEqual(counts[1000], 5) + + +class KMC2InitializationCornercaseTest(test.TestCase): + + def setUp(self): + self._distances = np.zeros(10) + + def runTestWithSeed(self, seed): + with self.test_session(): + sampled_point = clustering_ops.kmc2_chain_initialization( + self._distances, seed) + self.assertEquals(sampled_point.eval(), 0) + + def testBasic(self): + for seed in range(100): + self.runTestWithSeed(seed) + + # A simple test that can be verified by hand. class NearestCentersTest(test.TestCase): diff --git a/tensorflow/contrib/factorization/python/ops/clustering_ops.py b/tensorflow/contrib/factorization/python/ops/clustering_ops.py index ac2fbcceaa48e97d8be3ec2af30cdd8222993aaa..96cc80ce241347ebca5b68140f1b1c8b9898ae72 100644 --- a/tensorflow/contrib/factorization/python/ops/clustering_ops.py +++ b/tensorflow/contrib/factorization/python/ops/clustering_ops.py @@ -50,6 +50,10 @@ COSINE_DISTANCE = 'cosine' RANDOM_INIT = 'random' KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus' +KMC2_INIT = 'kmc2' + +# The name of the variable holding the cluster centers. Used by the Estimator. +CLUSTERS_VAR_NAME = 'clusters' class KMeans(object): @@ -63,7 +67,8 @@ class KMeans(object): use_mini_batch=False, mini_batch_steps_per_iteration=1, random_seed=0, - kmeans_plus_plus_num_retries=2): + kmeans_plus_plus_num_retries=2, + kmc2_chain_length=200): """Creates an object for generating KMeans clustering graph. This class implements the following variants of K-means algorithm: @@ -92,7 +97,8 @@ class KMeans(object): exactly like a full-batch version. Args: - inputs: An input tensor or list of input tensors + inputs: An input tensor or list of input tensors. It is assumed that the + data points have been previously randomly permuted. num_clusters: An integer tensor specifying the number of clusters. This argument is ignored if initial_clusters is a tensor or numpy array. initial_clusters: Specifies the clusters used during initialization. One @@ -101,6 +107,7 @@ class KMeans(object): - a function f(inputs, k) that returns up to k centers from `inputs`. - "random": Choose centers randomly from `inputs`. - "kmeans_plus_plus": Use kmeans++ to choose centers from `inputs`. + - "kmc2": Use the fast k-MC2 algorithm to choose centers from `inputs`. In the last three cases, one batch of `inputs` may not yield `num_clusters` centers, in which case initialization will require multiple batches until enough centers are chosen. In the case of @@ -118,13 +125,17 @@ class KMeans(object): additional points to draw from the current distribution before selecting the best. If a negative value is specified, a heuristic is used to sample O(log(num_to_sample)) additional points. + kmc2_chain_length: Determines how many candidate points are used by the + k-MC2 algorithm to produce one new cluster centers. If a (mini-)batch + contains less points, one new cluster center is generated from the + (mini-)batch. Raises: ValueError: An invalid argument was passed to initial_clusters or distance_metric. """ if isinstance(initial_clusters, str) and initial_clusters not in [ - RANDOM_INIT, KMEANS_PLUS_PLUS_INIT + RANDOM_INIT, KMEANS_PLUS_PLUS_INIT, KMC2_INIT ]: raise ValueError( "Unsupported initialization algorithm '%s'" % initial_clusters) @@ -138,6 +149,7 @@ class KMeans(object): self._mini_batch_steps_per_iteration = int(mini_batch_steps_per_iteration) self._random_seed = random_seed self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries + self._kmc2_chain_length = kmc2_chain_length @classmethod def _distance_graph(cls, inputs, clusters, distance_metric): @@ -279,7 +291,7 @@ class KMeans(object): """ init_value = array_ops.constant([], dtype=dtypes.float32) cluster_centers = variable_scope.variable( - init_value, name='clusters', validate_shape=False) + init_value, name=CLUSTERS_VAR_NAME, validate_shape=False) cluster_centers_initialized = variable_scope.variable( False, dtype=dtypes.bool, name='initialized') @@ -299,9 +311,10 @@ class KMeans(object): else: cluster_centers_updated = cluster_centers update_in_steps = None - cluster_counts = (variable_scope.variable( - array_ops.ones([num_clusters], dtype=dtypes.int64)) - if self._use_mini_batch else None) + cluster_counts = ( + variable_scope.variable( + array_ops.ones([num_clusters], dtype=dtypes.int64)) + if self._use_mini_batch else None) return (cluster_centers, cluster_centers_initialized, cluster_counts, cluster_centers_updated, update_in_steps) @@ -356,7 +369,7 @@ class KMeans(object): init_op = _InitializeClustersOpFactory( self._inputs, num_clusters, initial_clusters, self._distance_metric, self._random_seed, self._kmeans_plus_plus_num_retries, - cluster_centers_var, cluster_centers_updated, + self._kmc2_chain_length, cluster_centers_var, cluster_centers_updated, cluster_centers_initialized).op() cluster_centers = cluster_centers_var @@ -517,8 +530,9 @@ class KMeans(object): array_ops.reshape(array_ops.shape(inp)[0], [-1])), [-1, 1]), cluster_idx, num_clusters)) with ops.colocate_with(cluster_centers, ignore_existing=True): - new_clusters_centers = math_ops.add_n(cluster_sums) / (math_ops.cast( - math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + epsilon) + new_clusters_centers = math_ops.add_n(cluster_sums) / ( + math_ops.cast(math_ops.add_n(cluster_counts), cluster_sums[0].dtype) + + epsilon) if self._clusters_l2_normalized(): new_clusters_centers = nn_impl.l2_normalize(new_clusters_centers, dim=1) return state_ops.assign(cluster_centers, new_clusters_centers) @@ -545,9 +559,12 @@ class _InitializeClustersOpFactory(object): cluster_centers_initialized := true """ + # TODO(ccolby): Refactor this class so that kmc2 isn't so much a special case. + def __init__(self, inputs, num_clusters, initial_clusters, distance_metric, - random_seed, kmeans_plus_plus_num_retries, cluster_centers, - cluster_centers_updated, cluster_centers_initialized): + random_seed, kmeans_plus_plus_num_retries, kmc2_chain_length, + cluster_centers, cluster_centers_updated, + cluster_centers_initialized): """Creates an op factory. Args: @@ -557,6 +574,7 @@ class _InitializeClustersOpFactory(object): distance_metric: See KMeans constructor. random_seed: See KMeans constructor. kmeans_plus_plus_num_retries: See KMeans constructor. + kmc2_chain_length: See KMeans constructor. cluster_centers: The TF variable holding the initial centers. It may already contain some centers when the op is executed. cluster_centers_updated: A second TF variable to hold a copy of the @@ -572,6 +590,7 @@ class _InitializeClustersOpFactory(object): self._distance_metric = distance_metric self._random_seed = random_seed self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries + self._kmc2_chain_length = kmc2_chain_length self._cluster_centers = cluster_centers self._cluster_centers_updated = cluster_centers_updated self._cluster_centers_initialized = cluster_centers_initialized @@ -601,6 +620,90 @@ class _InitializeClustersOpFactory(object): math_ops.to_int64(self._num_remaining), self._random_seed, self._kmeans_plus_plus_num_retries) + def _kmc2_multiple_centers(self): + """Adds new initial cluster centers using the k-MC2 algorithm. + + In each call to the op, the provided batch is split into subsets based on + the specified `kmc2_chain_length`. On each subset, a single Markov chain of + the k-MC2 algorithm is used to add *one* new center cluster center. If there + are less than `kmc2_chain_length` points in the subset, a single center is + added using one Markov chain on the full input. It is assumed that the + provided batch has previously been randomly permuted. Otherwise, k-MC2 may + return suboptimal centers. + + Returns: + An op that adds new cluster centers. + """ + # The op only operates on the first shard of data. + first_shard = self._inputs[0] + # Number of points in the input that can be used. + batch_size = array_ops.shape(first_shard)[0] + # Maximum number of subsets such that the size of each subset is at least + # `kmc2_chain_length`. Final subsets may be larger. + max_to_sample = math_ops.cast( + batch_size / self._kmc2_chain_length, dtype=dtypes.int32) + # We sample at least one new center and at most all remaining centers. + num_to_sample = math_ops.maximum( + math_ops.minimum(self._num_remaining, max_to_sample), 1) + + def _cond(i, _): + """Stopping condition for the while loop.""" + return math_ops.less(i, num_to_sample) + + def _body(i, _): + """Body that adds a single new center based on a subset.""" + + def _sample_random(): + """Returns a random point as a cluster center.""" + # By assumption the batch is reshuffled and _sample_random is always + # called for i=0. Hence, we simply return the first point. + new_center = array_ops.reshape(first_shard[0], [1, -1]) + if self._distance_metric == COSINE_DISTANCE: + new_center = nn_impl.l2_normalize(new_center, dim=1) + return new_center + + def _sample_kmc2_chain(): + """Returns previous centers as well as a new center sampled using k-MC2. + """ + # Extract the subset from the underlying batch. + start = i * self._kmc2_chain_length + end = start + self._kmc2_chain_length + subset = first_shard[start:end] + # Compute the distances from points in the subset to previous centers. + _, distances = gen_clustering_ops.nearest_neighbors( + subset, self._cluster_centers, 1) + # Sample index of new center using k-MC2 Markov chain. + new_center_index = gen_clustering_ops.kmc2_chain_initialization( + array_ops.squeeze(distances), self._random_seed) + # Extract actual new center. + newly_sampled_center = array_ops.reshape(subset[new_center_index], + [1, -1]) + # Return concatenation with previously sampled centers. + if self._distance_metric == COSINE_DISTANCE: + newly_sampled_center = nn_impl.l2_normalize( + newly_sampled_center, dim=1) + return array_ops.concat([self._cluster_centers, newly_sampled_center], + 0) + + # Obtain a random point if there are no previously sampled centers. + # Otherwise, construct a k-MC2 Markov chain. + new_centers = control_flow_ops.cond( + math_ops.equal(self._num_selected, 0), _sample_random, + _sample_kmc2_chain) + # Assign new cluster centers to underlying variable. + assigned_centers = state_ops.assign( + self._cluster_centers, new_centers, validate_shape=False) + if self._cluster_centers_updated is not self._cluster_centers: + assigned_centers = state_ops.assign( + self._cluster_centers_updated, + assigned_centers, + validate_shape=False) + return i + 1, self._num_clusters - array_ops.shape(assigned_centers)[0] + + # Add num_to_sample new data points. + _, num_remaining = control_flow_ops.while_loop(_cond, _body, [0, 0]) + return num_remaining + def _greedy_batch_sampler(self, sampler): # If the input dataset size is smaller than the number of centers # remaining, choose the entire input dataset as centers. This can happen @@ -654,7 +757,10 @@ class _InitializeClustersOpFactory(object): with ops.control_dependencies([ check_ops.assert_positive(self._num_remaining), ]): - num_now_remaining = self._add_new_centers() + if self._initial_clusters == KMC2_INIT: + num_now_remaining = self._kmc2_multiple_centers() + else: + num_now_remaining = self._add_new_centers() return control_flow_ops.cond( math_ops.equal(num_now_remaining, 0), lambda: state_ops.assign(self._cluster_centers_initialized, True), diff --git a/tensorflow/contrib/factorization/python/ops/kmeans.py b/tensorflow/contrib/factorization/python/ops/kmeans.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5413fc3f2642443621b33d325e3d8c893fd6ac --- /dev/null +++ b/tensorflow/contrib/factorization/python/ops/kmeans.py @@ -0,0 +1,397 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A canned Estimator for k-means clustering.""" + +# TODO(ccolby): Move clustering_ops.py into this file and streamline the code. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import time + +from tensorflow.contrib.factorization.python.ops import clustering_ops +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics +from tensorflow.python.ops import state_ops +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.summary import summary +from tensorflow.python.training import session_run_hook +from tensorflow.python.training import training_util + + +class _LossRelativeChangeHook(session_run_hook.SessionRunHook): + """Stops when the change in loss goes below a tolerance.""" + + def __init__(self, loss_tensor, tolerance): + """Creates a _LossRelativeChangeHook. + + Args: + loss_tensor: A scalar tensor of the loss value. + tolerance: A relative tolerance of loss change between iterations. + """ + self._loss_tensor = loss_tensor + self._tolerance = tolerance + self._prev_loss = None + + def before_run(self, run_context): + del run_context # unused + return session_run_hook.SessionRunArgs(self._loss_tensor) + + def after_run(self, run_context, run_values): + loss = run_values.results + assert loss is not None + if self._prev_loss: + relative_change = (abs(loss - self._prev_loss) / + (1 + abs(self._prev_loss))) + if relative_change < self._tolerance: + run_context.request_stop() + self._prev_loss = loss + + +class _InitializeClustersHook(session_run_hook.SessionRunHook): + """Initializes the cluster centers. + + The chief repeatedly invokes an initialization op until all cluster centers + are initialized. The workers wait for the initialization phase to complete. + """ + + def __init__(self, init_op, is_initialized_var, is_chief): + """Creates an _InitializeClustersHook. + + Args: + init_op: An op that, when run, will choose some initial cluster centers. + This op may need to be run multiple times to choose all the centers. + is_initialized_var: A boolean variable reporting whether all initial + centers have been chosen. + is_chief: A boolean specifying whether this task is the chief. + """ + self._init_op = init_op + self._is_initialized_var = is_initialized_var + self._is_chief = is_chief + + def after_create_session(self, session, coord): + del coord # unused + assert self._init_op.graph is ops.get_default_graph() + assert self._is_initialized_var.graph is self._init_op.graph + while True: + try: + if session.run(self._is_initialized_var): + break + elif self._is_chief: + session.run(self._init_op) + else: + time.sleep(1) + except RuntimeError as e: + logging.info(e) + + +def _parse_tensor_or_dict(features): + """Helper function to convert the input points into a usable format. + + Args: + features: The input points. + + Returns: + If `features` is a dict of `k` features, each of which is a vector of `n` + scalars, the return value is a Tensor of shape `(n, k)` representing `n` + input points, where the items in the `k` dimension are sorted + lexicographically by `features` key. If `features` is not a dict, it is + returned unmodified. + """ + if isinstance(features, dict): + keys = sorted(features.keys()) + with ops.colocate_with(features[keys[0]]): + features = array_ops.concat([features[k] for k in keys], axis=1) + return features + + +class _ModelFn(object): + """Model function for the estimator.""" + + def __init__(self, num_clusters, initial_clusters, distance_metric, + random_seed, use_mini_batch, mini_batch_steps_per_iteration, + kmeans_plus_plus_num_retries, relative_tolerance): + self._num_clusters = num_clusters + self._initial_clusters = initial_clusters + self._distance_metric = distance_metric + self._random_seed = random_seed + self._use_mini_batch = use_mini_batch + self._mini_batch_steps_per_iteration = mini_batch_steps_per_iteration + self._kmeans_plus_plus_num_retries = kmeans_plus_plus_num_retries + self._relative_tolerance = relative_tolerance + + def model_fn(self, features, mode, config): + """Model function for the estimator. + + Note that this does not take a `1abels` arg. This works, but `input_fn` must + return either `features` or, equivalently, `(features, None)`. + + Args: + features: The input points. See @{tf.estimator.Estimator}. + mode: See @{tf.estimator.Estimator}. + config: See @{tf.estimator.Estimator}. + + Returns: + A @{tf.estimator.EstimatorSpec} (see @{tf.estimator.Estimator}) specifying + this behavior: + * `train_op`: Execute one mini-batch or full-batch run of Lloyd's + algorithm. + * `loss`: The sum of the squared distances from each input point to its + closest center. + * `eval_metric_ops`: Maps `SCORE` to `loss`. + * `predictions`: Maps `ALL_DISTANCES` to the distance from each input + point to each cluster center; maps `CLUSTER_INDEX` to the index of + the closest cluster center for each input point. + """ + # input_points is a single Tensor. Therefore, the sharding functionality + # in clustering_ops is unused, and some of the values below are lists of a + # single item. + input_points = _parse_tensor_or_dict(features) + + # Let N = the number of input_points. + # all_distances: A list of one matrix of shape (N, num_clusters). Each value + # is the distance from an input point to a cluster center. + # model_predictions: A list of one vector of shape (N). Each value is the + # cluster id of an input point. + # losses: Similar to cluster_idx but provides the distance to the cluster + # center. + # is_initialized: scalar indicating whether the initial cluster centers + # have been chosen; see init_op. + # cluster_centers_var: a Variable containing the cluster centers. + # init_op: an op to choose the initial cluster centers. A single worker + # repeatedly executes init_op until is_initialized becomes True. + # training_op: an op that runs an iteration of training, either an entire + # Lloyd iteration or a mini-batch of a Lloyd iteration. Multiple workers + # may execute this op, but only after is_initialized becomes True. + (all_distances, model_predictions, losses, is_initialized, init_op, + training_op) = clustering_ops.KMeans( + inputs=input_points, + num_clusters=self._num_clusters, + initial_clusters=self._initial_clusters, + distance_metric=self._distance_metric, + use_mini_batch=self._use_mini_batch, + mini_batch_steps_per_iteration=self._mini_batch_steps_per_iteration, + random_seed=self._random_seed, + kmeans_plus_plus_num_retries=self._kmeans_plus_plus_num_retries + ).training_graph() + + loss = math_ops.reduce_sum(losses) + summary.scalar('loss/raw', loss) + + incr_step = state_ops.assign_add(training_util.get_global_step(), 1) + training_op = control_flow_ops.with_dependencies([training_op, incr_step], + loss) + + training_hooks = [ + _InitializeClustersHook(init_op, is_initialized, config.is_chief) + ] + if self._relative_tolerance is not None: + training_hooks.append( + _LossRelativeChangeHook(loss, self._relative_tolerance)) + + return model_fn_lib.EstimatorSpec( + mode=mode, + predictions={ + KMeansClustering.ALL_DISTANCES: all_distances[0], + KMeansClustering.CLUSTER_INDEX: model_predictions[0], + }, + loss=loss, + train_op=training_op, + eval_metric_ops={KMeansClustering.SCORE: metrics.mean(loss)}, + training_hooks=training_hooks) + + +# TODO(agarwal,ands): support sharded input. +class KMeansClustering(estimator.Estimator): + """An Estimator for K-Means clustering.""" + + # Valid values for the distance_metric constructor argument. + SQUARED_EUCLIDEAN_DISTANCE = clustering_ops.SQUARED_EUCLIDEAN_DISTANCE + COSINE_DISTANCE = clustering_ops.COSINE_DISTANCE + + # Values for initial_clusters constructor argument. + RANDOM_INIT = clustering_ops.RANDOM_INIT + KMEANS_PLUS_PLUS_INIT = clustering_ops.KMEANS_PLUS_PLUS_INIT + + # Metric returned by evaluate(): The sum of the squared distances from each + # input point to its closest center. + SCORE = 'score' + + # Keys returned by predict(). + # ALL_DISTANCES: The distance from each input point to each cluster center. + # CLUSTER_INDEX: The index of the closest cluster center for each input point. + CLUSTER_INDEX = 'cluster_index' + ALL_DISTANCES = 'all_distances' + + def __init__(self, + num_clusters, + model_dir=None, + initial_clusters=RANDOM_INIT, + distance_metric=SQUARED_EUCLIDEAN_DISTANCE, + random_seed=0, + use_mini_batch=True, + mini_batch_steps_per_iteration=1, + kmeans_plus_plus_num_retries=2, + relative_tolerance=None, + config=None): + """Creates an Estimator for running KMeans training and inference. + + This Estimator implements the following variants of the K-means algorithm: + + If `use_mini_batch` is False, it runs standard full batch K-means. Each + training step runs a single iteration of K-Means and must process the full + input at once. To run in this mode, the `input_fn` passed to `train` must + return the entire input dataset. + + If `use_mini_batch` is True, it runs a generalization of the mini-batch + K-means algorithm. It runs multiple iterations, where each iteration is + composed of `mini_batch_steps_per_iteration` steps. Each training step + accumulates the contribution from one mini-batch into temporary storage. + Every `mini_batch_steps_per_iteration` steps, the cluster centers are + updated and the temporary storage cleared for the next iteration. Note + that: + * If `mini_batch_steps_per_iteration=1`, the algorithm reduces to the + standard K-means mini-batch algorithm. + * If `mini_batch_steps_per_iteration = num_inputs / batch_size`, the + algorithm becomes an asynchronous version of the full-batch algorithm. + However, there is no guarantee by this implementation that each input + is seen exactly once per iteration. Also, different updates are applied + asynchronously without locking. So this asynchronous version may not + behave exactly like a full-batch version. + + Args: + num_clusters: An integer tensor specifying the number of clusters. This + argument is ignored if `initial_clusters` is a tensor or numpy array. + model_dir: The directory to save the model results and log files. + initial_clusters: Specifies how the initial cluster centers are chosen. + One of the following: + * a tensor or numpy array with the initial cluster centers. + * a callable `f(inputs, k)` that selects and returns up to `k` centers + from an input batch. `f` is free to return any number of centers + from `0` to `k`. It will be invoked on successive input batches + as necessary until all `num_clusters` centers are chosen. + * `KMeansClustering.RANDOM_INIT`: Choose centers randomly from an input + batch. If the batch size is less than `num_clusters` then the + entire batch is chosen to be initial cluster centers and the + remaining centers are chosen from successive input batches. + * `KMeansClustering.KMEANS_PLUS_PLUS_INIT`: Use kmeans++ to choose + centers from the first input batch. If the batch size is less + than `num_clusters`, a TensorFlow runtime error occurs. + distance_metric: The distance metric used for clustering. One of: + * `KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE`: Euclidean distance + between vectors `u` and `v` is defined as `||u - v||_2` which is + the square root of the sum of the absolute squares of the elements' + difference. + * `KMeansClustering.COSINE_DISTANCE`: Cosine distance between vectors + `u` and `v` is defined as `1 - (u . v) / (||u||_2 ||v||_2)`. + random_seed: Python integer. Seed for PRNG used to initialize centers. + use_mini_batch: A boolean specifying whether to use the mini-batch k-means + algorithm. See explanation above. + mini_batch_steps_per_iteration: The number of steps after which the + updated cluster centers are synced back to a master copy. Used only if + `use_mini_batch=True`. See explanation above. + kmeans_plus_plus_num_retries: For each point that is sampled during + kmeans++ initialization, this parameter specifies the number of + additional points to draw from the current distribution before selecting + the best. If a negative value is specified, a heuristic is used to + sample `O(log(num_to_sample))` additional points. Used only if + `initial_clusters=KMeansClustering.KMEANS_PLUS_PLUS_INIT`. + relative_tolerance: A relative tolerance of change in the loss between + iterations. Stops learning if the loss changes less than this amount. + This may not work correctly if `use_mini_batch=True`. + config: See @{tf.estimator.Estimator}. + + Raises: + ValueError: An invalid argument was passed to `initial_clusters` or + `distance_metric`. + """ + if isinstance(initial_clusters, str) and initial_clusters not in [ + KMeansClustering.RANDOM_INIT, KMeansClustering.KMEANS_PLUS_PLUS_INIT + ]: + raise ValueError( + "Unsupported initialization algorithm '%s'" % initial_clusters) + if distance_metric not in [ + KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + KMeansClustering.COSINE_DISTANCE + ]: + raise ValueError("Unsupported distance metric '%s'" % distance_metric) + super(KMeansClustering, self).__init__( + model_fn=_ModelFn( + num_clusters, initial_clusters, distance_metric, random_seed, + use_mini_batch, mini_batch_steps_per_iteration, + kmeans_plus_plus_num_retries, relative_tolerance).model_fn, + model_dir=model_dir, + config=config) + + def _predict_one_key(self, input_fn, predict_key): + for result in self.predict(input_fn=input_fn, predict_keys=[predict_key]): + yield result[predict_key] + + def predict_cluster_index(self, input_fn): + """Finds the index of the closest cluster center to each input point. + + Args: + input_fn: Input points. See @{tf.estimator.Estimator.predict}. + + Yields: + The index of the closest cluster center for each input point. + """ + for index in self._predict_one_key(input_fn, + KMeansClustering.CLUSTER_INDEX): + yield index + + def score(self, input_fn): + """Returns the sum of squared distances to nearest clusters. + + Note that this function is different from the corresponding one in sklearn + which returns the negative sum. + + Args: + input_fn: Input points. See @{tf.estimator.Estimator.evaluate}. Only one + batch is retrieved. + + Returns: + The sum of the squared distance from each point in the first batch of + inputs to its nearest cluster center. + """ + return self.evaluate(input_fn=input_fn, steps=1)[KMeansClustering.SCORE] + + def transform(self, input_fn): + """Transforms each input point to its distances to all cluster centers. + + Note that if `distance_metric=KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE`, + this + function returns the squared Euclidean distance while the corresponding + sklearn function returns the Euclidean distance. + + Args: + input_fn: Input points. See @{tf.estimator.Estimator.predict}. + + Yields: + The distances from each input point to each cluster center. + """ + for distances in self._predict_one_key(input_fn, + KMeansClustering.ALL_DISTANCES): + yield distances + + def cluster_centers(self): + """Returns the cluster centers.""" + return self.get_variable_value(clustering_ops.CLUSTERS_VAR_NAME) diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4709d7942583f1406a3fa0ff3a078d0283872ea6 --- /dev/null +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -0,0 +1,575 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for KMeans.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import time + +import numpy as np +from sklearn.cluster import KMeans as SklearnKMeans + +# pylint: disable=g-import-not-at-top +from tensorflow.contrib.factorization.python.ops import kmeans as kmeans_lib +from tensorflow.python.estimator import run_config +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import benchmark +from tensorflow.python.platform import flags +from tensorflow.python.platform import test +from tensorflow.python.training import input as input_lib +from tensorflow.python.training import queue_runner + +FLAGS = flags.FLAGS + + +def normalize(x): + return x / np.sqrt(np.sum(x * x, axis=-1, keepdims=True)) + + +def cosine_similarity(x, y): + return np.dot(normalize(x), np.transpose(normalize(y))) + + +def make_random_centers(num_centers, num_dims, center_norm=500): + return np.round( + np.random.rand(num_centers, num_dims).astype(np.float32) * center_norm) + + +def make_random_points(centers, num_points, max_offset=20): + num_centers, num_dims = centers.shape + assignments = np.random.choice(num_centers, num_points) + offsets = np.round( + np.random.randn(num_points, num_dims).astype(np.float32) * max_offset) + return (centers[assignments] + offsets, assignments, np.add.reduce( + offsets * offsets, 1)) + + +class KMeansTestBase(test.TestCase): + + def input_fn(self, + batch_size=None, + points=None, + randomize=None, + num_epochs=None): + """Returns an input_fn that randomly selects batches from given points.""" + batch_size = batch_size or self.batch_size + points = points if points is not None else self.points + num_points = points.shape[0] + if randomize is None: + randomize = (self.use_mini_batch and + self.mini_batch_steps_per_iteration <= 1) + + def _fn(): + x = constant_op.constant(points) + if batch_size == num_points: + return input_lib.limit_epochs(x, num_epochs=num_epochs), None + if randomize: + indices = random_ops.random_uniform( + constant_op.constant([batch_size]), + minval=0, + maxval=num_points - 1, + dtype=dtypes.int32, + seed=10) + else: + # We need to cycle through the indices sequentially. We create a queue + # to maintain the list of indices. + q = data_flow_ops.FIFOQueue(num_points, dtypes.int32, ()) + + # Conditionally initialize the Queue. + def _init_q(): + with ops.control_dependencies( + [q.enqueue_many(math_ops.range(num_points))]): + return control_flow_ops.no_op() + + init_q = control_flow_ops.cond(q.size() <= 0, _init_q, + control_flow_ops.no_op) + with ops.control_dependencies([init_q]): + offsets = q.dequeue_many(batch_size) + with ops.control_dependencies([q.enqueue_many(offsets)]): + indices = array_ops.identity(offsets) + batch = array_ops.gather(x, indices) + return (input_lib.limit_epochs(batch, num_epochs=num_epochs), None) + + return _fn + + @staticmethod + def config(tf_random_seed): + return run_config.RunConfig().replace(tf_random_seed=tf_random_seed) + + @property + def initial_clusters(self): + return kmeans_lib.KMeansClustering.KMEANS_PLUS_PLUS_INIT + + @property + def batch_size(self): + return self.num_points + + @property + def use_mini_batch(self): + return False + + @property + def mini_batch_steps_per_iteration(self): + return 1 + + +class KMeansTest(KMeansTestBase): + + def setUp(self): + np.random.seed(3) + self.num_centers = 5 + self.num_dims = 2 + self.num_points = 1000 + self.true_centers = make_random_centers(self.num_centers, self.num_dims) + self.points, _, self.scores = make_random_points(self.true_centers, + self.num_points) + self.true_score = np.add.reduce(self.scores) + + def _kmeans(self, relative_tolerance=None): + return kmeans_lib.KMeansClustering( + self.num_centers, + initial_clusters=self.initial_clusters, + distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + use_mini_batch=self.use_mini_batch, + mini_batch_steps_per_iteration=self.mini_batch_steps_per_iteration, + random_seed=24, + relative_tolerance=relative_tolerance) + + def test_clusters(self): + kmeans = self._kmeans() + kmeans.train(input_fn=self.input_fn(), steps=1) + clusters = kmeans.cluster_centers() + self.assertAllEqual(list(clusters.shape), [self.num_centers, self.num_dims]) + + def test_fit(self): + kmeans = self._kmeans() + kmeans.train(input_fn=self.input_fn(), steps=1) + score1 = kmeans.score(input_fn=self.input_fn(batch_size=self.num_points)) + steps = 10 * self.num_points // self.batch_size + kmeans.train(input_fn=self.input_fn(), steps=steps) + score2 = kmeans.score(input_fn=self.input_fn(batch_size=self.num_points)) + self.assertTrue(score1 > score2) + self.assertNear(self.true_score, score2, self.true_score * 0.05) + + def test_monitor(self): + if self.use_mini_batch: + # We don't test for use_mini_batch case since the loss value can be noisy. + return + kmeans = kmeans_lib.KMeansClustering( + self.num_centers, + initial_clusters=self.initial_clusters, + distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + use_mini_batch=self.use_mini_batch, + mini_batch_steps_per_iteration=self.mini_batch_steps_per_iteration, + config=self.config(14), + random_seed=12, + relative_tolerance=1e-4) + + kmeans.train( + input_fn=self.input_fn(), + # Force it to train until the relative tolerance monitor stops it. + steps=None) + score = kmeans.score(input_fn=self.input_fn(batch_size=self.num_points)) + self.assertNear(self.true_score, score, self.true_score * 0.01) + + def test_infer(self): + kmeans = self._kmeans() + # Make a call to fit to initialize the cluster centers. + max_steps = 1 + kmeans.train(input_fn=self.input_fn(), max_steps=max_steps) + clusters = kmeans.cluster_centers() + + # Make a small test set + num_points = 10 + points, true_assignments, true_offsets = make_random_points( + clusters, num_points) + input_fn = self.input_fn(batch_size=num_points, points=points, num_epochs=1) + # Test predict + assignments = list(kmeans.predict_cluster_index(input_fn)) + self.assertAllEqual(assignments, true_assignments) + + # Test score + score = kmeans.score(input_fn=lambda: (constant_op.constant(points), None)) + self.assertNear(score, np.sum(true_offsets), 0.01 * score) + + # Test transform + transform = list(kmeans.transform(input_fn)) + true_transform = np.maximum( + 0, + np.sum(np.square(points), axis=1, keepdims=True) - + 2 * np.dot(points, np.transpose(clusters)) + np.transpose( + np.sum(np.square(clusters), axis=1, keepdims=True))) + self.assertAllClose(transform, true_transform, rtol=0.05, atol=10) + + +class KMeansTestMultiStageInit(KMeansTestBase): + + def test_random(self): + points = np.array( + [[1, 2], [3, 4], [5, 6], [7, 8], [9, 0]], dtype=np.float32) + kmeans = kmeans_lib.KMeansClustering( + num_clusters=points.shape[0], + initial_clusters=kmeans_lib.KMeansClustering.RANDOM_INIT, + distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + use_mini_batch=True, + mini_batch_steps_per_iteration=100, + random_seed=24, + relative_tolerance=None) + kmeans.train( + input_fn=self.input_fn(batch_size=1, points=points, randomize=False), + steps=1) + clusters = kmeans.cluster_centers() + self.assertAllEqual(points, clusters) + + def test_kmeans_plus_plus_batch_just_right(self): + points = np.array([[1, 2]], dtype=np.float32) + kmeans = kmeans_lib.KMeansClustering( + num_clusters=points.shape[0], + initial_clusters=kmeans_lib.KMeansClustering.KMEANS_PLUS_PLUS_INIT, + distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + use_mini_batch=True, + mini_batch_steps_per_iteration=100, + random_seed=24, + relative_tolerance=None) + kmeans.train( + input_fn=self.input_fn(batch_size=1, points=points, randomize=False), + steps=1) + clusters = kmeans.cluster_centers() + self.assertAllEqual(points, clusters) + + def test_kmeans_plus_plus_batch_too_small(self): + points = np.array( + [[1, 2], [3, 4], [5, 6], [7, 8], [9, 0]], dtype=np.float32) + kmeans = kmeans_lib.KMeansClustering( + num_clusters=points.shape[0], + initial_clusters=kmeans_lib.KMeansClustering.KMEANS_PLUS_PLUS_INIT, + distance_metric=kmeans_lib.KMeansClustering.SQUARED_EUCLIDEAN_DISTANCE, + use_mini_batch=True, + mini_batch_steps_per_iteration=100, + random_seed=24, + relative_tolerance=None) + with self.assertRaisesOpError(AssertionError): + kmeans.train( + input_fn=self.input_fn(batch_size=4, points=points, randomize=False), + steps=1) + + +class MiniBatchKMeansTest(KMeansTest): + + @property + def batch_size(self): + return 50 + + @property + def use_mini_batch(self): + return True + + +class FullBatchAsyncKMeansTest(KMeansTest): + + @property + def batch_size(self): + return 50 + + @property + def use_mini_batch(self): + return True + + @property + def mini_batch_steps_per_iteration(self): + return self.num_points // self.batch_size + + +class KMeansCosineDistanceTest(KMeansTestBase): + + def setUp(self): + self.points = np.array( + [[2.5, 0.1], [2, 0.2], [3, 0.1], [4, 0.2], [0.1, 2.5], [0.2, 2], + [0.1, 3], [0.2, 4]], + dtype=np.float32) + self.num_points = self.points.shape[0] + self.true_centers = np.array( + [ + normalize( + np.mean(normalize(self.points)[0:4, :], axis=0, + keepdims=True))[0], + normalize( + np.mean(normalize(self.points)[4:, :], axis=0, + keepdims=True))[0] + ], + dtype=np.float32) + self.true_assignments = np.array([0] * 4 + [1] * 4) + self.true_score = len(self.points) - np.tensordot( + normalize(self.points), self.true_centers[self.true_assignments]) + + self.num_centers = 2 + self.kmeans = kmeans_lib.KMeansClustering( + self.num_centers, + initial_clusters=kmeans_lib.KMeansClustering.RANDOM_INIT, + distance_metric=kmeans_lib.KMeansClustering.COSINE_DISTANCE, + use_mini_batch=self.use_mini_batch, + mini_batch_steps_per_iteration=self.mini_batch_steps_per_iteration, + config=self.config(3)) + + def test_fit(self): + max_steps = 10 * self.num_points // self.batch_size + self.kmeans.train(input_fn=self.input_fn(), max_steps=max_steps) + centers = normalize(self.kmeans.cluster_centers()) + centers = centers[centers[:, 0].argsort()] + true_centers = self.true_centers[self.true_centers[:, 0].argsort()] + self.assertAllClose(centers, true_centers, atol=0.04) + + def test_transform(self): + self.kmeans.train(input_fn=self.input_fn(), steps=10) + centers = normalize(self.kmeans.cluster_centers()) + true_transform = 1 - cosine_similarity(self.points, centers) + transform = list( + self.kmeans.transform( + input_fn=self.input_fn(batch_size=self.num_points, num_epochs=1))) + self.assertAllClose(transform, true_transform, atol=1e-3) + + def test_predict(self): + max_steps = 10 * self.num_points // self.batch_size + self.kmeans.train(input_fn=self.input_fn(), max_steps=max_steps) + centers = normalize(self.kmeans.cluster_centers()) + + assignments = list( + self.kmeans.predict_cluster_index( + input_fn=self.input_fn(num_epochs=1, batch_size=self.num_points))) + self.assertAllClose( + centers[assignments], + self.true_centers[self.true_assignments], + atol=1e-2) + + centers = centers[centers[:, 0].argsort()] + true_centers = self.true_centers[self.true_centers[:, 0].argsort()] + self.assertAllClose(centers, true_centers, atol=0.04) + score = self.kmeans.score( + input_fn=self.input_fn(batch_size=self.num_points)) + self.assertAllClose(score, self.true_score, atol=1e-2) + + def test_predict_kmeans_plus_plus(self): + # Most points are concetrated near one center. KMeans++ is likely to find + # the less populated centers. + points = np.array( + [[2.5, 3.5], [2.5, 3.5], [-2, 3], [-2, 3], [-3, -3], [-3.1, -3.2], + [-2.8, -3.], [-2.9, -3.1], [-3., -3.1], [-3., -3.1], [-3.2, -3.], + [-3., -3.]], + dtype=np.float32) + true_centers = np.array( + [ + normalize( + np.mean(normalize(points)[0:2, :], axis=0, keepdims=True))[0], + normalize( + np.mean(normalize(points)[2:4, :], axis=0, keepdims=True))[0], + normalize(np.mean(normalize(points)[4:, :], axis=0, + keepdims=True))[0] + ], + dtype=np.float32) + true_assignments = [0] * 2 + [1] * 2 + [2] * 8 + true_score = len(points) - np.tensordot( + normalize(points), true_centers[true_assignments]) + + kmeans = kmeans_lib.KMeansClustering( + 3, + initial_clusters=self.initial_clusters, + distance_metric=kmeans_lib.KMeansClustering.COSINE_DISTANCE, + use_mini_batch=self.use_mini_batch, + mini_batch_steps_per_iteration=self.mini_batch_steps_per_iteration, + config=self.config(3)) + kmeans.train( + input_fn=lambda: (constant_op.constant(points), None), steps=30) + + centers = normalize(kmeans.cluster_centers()) + self.assertAllClose( + sorted(centers.tolist()), sorted(true_centers.tolist()), atol=1e-2) + + def _input_fn(): + return (input_lib.limit_epochs( + constant_op.constant(points), num_epochs=1), None) + + assignments = list(kmeans.predict_cluster_index(input_fn=_input_fn)) + self.assertAllClose( + centers[assignments], true_centers[true_assignments], atol=1e-2) + + score = kmeans.score(input_fn=lambda: (constant_op.constant(points), None)) + self.assertAllClose(score, true_score, atol=1e-2) + + +class MiniBatchKMeansCosineTest(KMeansCosineDistanceTest): + + @property + def batch_size(self): + return 2 + + @property + def use_mini_batch(self): + return True + + +class FullBatchAsyncKMeansCosineTest(KMeansCosineDistanceTest): + + @property + def batch_size(self): + return 2 + + @property + def use_mini_batch(self): + return True + + @property + def mini_batch_steps_per_iteration(self): + return self.num_points // self.batch_size + + +class KMeansBenchmark(benchmark.Benchmark): + """Base class for benchmarks.""" + + def SetUp(self, + dimension=50, + num_clusters=50, + points_per_cluster=10000, + center_norm=500, + cluster_width=20): + np.random.seed(123456) + self.num_clusters = num_clusters + self.num_points = num_clusters * points_per_cluster + self.centers = make_random_centers( + self.num_clusters, dimension, center_norm=center_norm) + self.points, _, scores = make_random_points( + self.centers, self.num_points, max_offset=cluster_width) + self.score = float(np.sum(scores)) + + def _report(self, num_iters, start, end, scores): + print(scores) + self.report_benchmark( + iters=num_iters, + wall_time=(end - start) / num_iters, + extras={'true_sum_squared_distances': self.score, + 'fit_scores': scores}) + + def _fit(self, num_iters=10): + pass + + def benchmark_01_2dim_5center_500point(self): + self.SetUp(dimension=2, num_clusters=5, points_per_cluster=100) + self._fit() + + def benchmark_02_20dim_20center_10kpoint(self): + self.SetUp(dimension=20, num_clusters=20, points_per_cluster=500) + self._fit() + + def benchmark_03_100dim_50center_50kpoint(self): + self.SetUp(dimension=100, num_clusters=50, points_per_cluster=1000) + self._fit() + + def benchmark_03_100dim_50center_50kpoint_unseparated(self): + self.SetUp( + dimension=100, + num_clusters=50, + points_per_cluster=1000, + cluster_width=250) + self._fit() + + def benchmark_04_100dim_500center_500kpoint(self): + self.SetUp(dimension=100, num_clusters=500, points_per_cluster=1000) + self._fit(num_iters=4) + + def benchmark_05_100dim_500center_500kpoint_unseparated(self): + self.SetUp( + dimension=100, + num_clusters=500, + points_per_cluster=1000, + cluster_width=250) + self._fit(num_iters=4) + + +class TensorflowKMeansBenchmark(KMeansBenchmark): + + def _fit(self, num_iters=10): + scores = [] + start = time.time() + for i in range(num_iters): + print('Starting tensorflow KMeans: %d' % i) + tf_kmeans = kmeans_lib.KMeansClustering( + self.num_clusters, + initial_clusters=kmeans_lib.KMeansClustering.KMEANS_PLUS_PLUS_INIT, + kmeans_plus_plus_num_retries=int(math.log(self.num_clusters) + 2), + random_seed=i * 42, + relative_tolerance=1e-6, + config=self.config(3)) + tf_kmeans.train( + input_fn=lambda: (constant_op.constant(self.points), None), steps=50) + _ = tf_kmeans.cluster_centers() + scores.append( + tf_kmeans.score( + input_fn=lambda: (constant_op.constant(self.points), None))) + self._report(num_iters, start, time.time(), scores) + + +class SklearnKMeansBenchmark(KMeansBenchmark): + + def _fit(self, num_iters=10): + scores = [] + start = time.time() + for i in range(num_iters): + print('Starting sklearn KMeans: %d' % i) + sklearn_kmeans = SklearnKMeans( + n_clusters=self.num_clusters, + init='k-means++', + max_iter=50, + n_init=1, + tol=1e-4, + random_state=i * 42) + sklearn_kmeans.train(self.points) + scores.append(sklearn_kmeans.inertia_) + self._report(num_iters, start, time.time(), scores) + + +class KMeansTestQueues(test.TestCase): + + def input_fn(self): + + def _fn(): + queue = data_flow_ops.FIFOQueue( + capacity=10, dtypes=dtypes.float32, shapes=[10, 3]) + enqueue_op = queue.enqueue(array_ops.zeros([10, 3], dtype=dtypes.float32)) + queue_runner.add_queue_runner( + queue_runner.QueueRunner(queue, [enqueue_op])) + return queue.dequeue(), None + + return _fn + + # This test makes sure that there are no deadlocks when using a QueueRunner. + # Note that since cluster initialization is dependendent on inputs, if input + # is generated using a QueueRunner, one has to make sure that these runners + # are started before the initialization. + def test_queues(self): + kmeans = kmeans_lib.KMeansClustering(5) + kmeans.train(input_fn=self.input_fn(), steps=1) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/factorization/python/ops/wals.py b/tensorflow/contrib/factorization/python/ops/wals.py index 3e3ee5fa57f1356db98a17f9e17e60f01d85d3b9..3976395d78e9188dd56d5b3b32fa8a3daf43c37d 100644 --- a/tensorflow/contrib/factorization/python/ops/wals.py +++ b/tensorflow/contrib/factorization/python/ops/wals.py @@ -26,7 +26,6 @@ from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gen_math_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope @@ -38,31 +37,30 @@ from tensorflow.python.training import session_run_hook class _SweepHook(session_run_hook.SessionRunHook): """Keeps track of row/col sweeps, and runs prep ops before each sweep.""" - def __init__(self, is_row_sweep_var, train_op, num_rows, num_cols, - processed_row_indices, processed_col_indices, row_prep_ops, - col_prep_ops, cache_init_ops, completed_sweeps_var): + def __init__(self, is_row_sweep_var, train_ops, num_rows, num_cols, + input_row_indices, input_col_indices, row_prep_ops, + col_prep_ops, init_op, completed_sweeps_var): """Initializes SweepHook. Args: is_row_sweep_var: A Boolean tf.Variable, determines whether we are currently doing a row or column sweep. It is updated by the hook. - train_op: An op. All the ops created by the hook will have - control_dependencies on train_op. + train_ops: A list of ops. The ops created by this hook will have + control dependencies on `train_ops`. num_rows: int, the total number of rows to be processed. num_cols: int, the total number of columns to be processed. - processed_row_indices: A Tensor of type int64. The indices of the input - rows that are processed during the current sweep. All elements of - processed_row_indices must be in [0, num_rows). - processed_col_indices: A Tensor of type int64. The indices of the input + input_row_indices: A Tensor of type int64. The indices of the input rows + that are processed during the current sweep. All elements of + `input_row_indices` must be in [0, num_rows). + input_col_indices: A Tensor of type int64. The indices of the input columns that are processed during the current sweep. All elements of - processed_col_indices must be in [0, num_cols). + `input_col_indices` must be in [0, num_cols). row_prep_ops: list of ops, to be run before the beginning of each row sweep, in the given order. col_prep_ops: list of ops, to be run before the beginning of each column sweep, in the given order. - cache_init_ops: list of ops, to be run once before training, in the given - order. These are typically local initialization ops (such as cache - initialization). + init_op: op to be run once before training. This is typically a local + initialization op (such as cache initialization). completed_sweeps_var: An integer tf.Variable, indicates the number of completed sweeps. It is updated by the hook. """ @@ -70,55 +68,45 @@ class _SweepHook(session_run_hook.SessionRunHook): self._num_cols = num_cols self._row_prep_ops = row_prep_ops self._col_prep_ops = col_prep_ops - self._cache_init_ops = cache_init_ops + self._init_op = init_op self._is_row_sweep_var = is_row_sweep_var self._completed_sweeps_var = completed_sweeps_var - # Boolean variable that determines whether the cache_init_ops have been run. + # Boolean variable that determines whether the init_ops have been run. self._is_initialized = False - # Boolean variable that is set to True when a sweep is completed. - # Used to run the prep_ops at the beginning of a sweep, in before_run(). - self._is_sweep_done = False - # Ops to run jointly with train_op, responsible for updating - # _is_row_sweep_var and incrementing the global_step and completed_sweeps - # counters. They have control_dependencies on train_op. - self._fetches = self._create_switch_ops(processed_row_indices, - processed_col_indices, train_op) - - def _create_switch_ops(self, processed_row_indices, processed_col_indices, - train_op): + # Ops to run jointly with train_ops, responsible for updating + # `is_row_sweep_var` and incrementing the `global_step` and + # `completed_sweeps` counters. + self._update_op, self._is_sweep_done_var, self._switch_op = ( + self._create_hook_ops(input_row_indices, input_col_indices, train_ops)) + + def _create_hook_ops(self, input_row_indices, input_col_indices, train_ops): """Creates ops to update is_row_sweep_var, global_step and completed_sweeps. - Creates two boolean tensors processed_rows and processed_cols, which keep - track of which rows/cols have been processed during the current sweep. + Creates two boolean tensors `processed_rows` and `processed_cols`, which + keep track of which rows/cols have been processed during the current sweep. Returns ops that should be run after each row / col update. - - When is_row_sweep_var is True, it sets - processed_rows[processed_row_indices] to True. - - When is_row_sweep_var is False, it sets - processed_cols[processed_col_indices] to True . - When all rows or all cols have been processed, negates is_row_sweep_var, - increments the completed_sweeps counter, and resets processed_rows and - processed_cols to False. - All of the ops created by this function have control_dependencies on - train_op. + - When `self._is_row_sweep_var` is True, it sets + processed_rows[input_row_indices] to True. + - When `self._is_row_sweep_var` is False, it sets + processed_cols[input_col_indices] to True. Args: - processed_row_indices: A Tensor. The indices of the input rows that are + input_row_indices: A Tensor. The indices of the input rows that are processed during the current sweep. - processed_col_indices: A Tensor. The indices of the input columns that + input_col_indices: A Tensor. The indices of the input columns that are processed during the current sweep. - train_op: An op. All the ops created by this function have - control_dependencies on train_op. + train_ops: A list of ops. The ops created by this function have control + dependencies on `train_ops`. + Returns: - A list consisting of: - is_sweep_done: A Boolean tensor, determines whether the sweep is done, - i.e. all rows (during a row sweep) or all columns (during a column - sweep) have been processed. - switch_ops: An op that updates is_row_sweep_var when is_sweep_done is - True. Has control_dependencies on train_op. - incr_ops: An op that increments the global_step and completed_sweeps - counters. Has control_dependenciens on switch_ops. + A tuple consisting of: + update_op: An op to be run jointly with training. It updates the state + and increments counters (global step and completed sweeps). + is_sweep_done_var: A Boolean tf.Variable, specifies whether the sweep is + done, i.e. all rows (during a row sweep) or all columns (during a + column sweep) have been processed. + switch_op: An op to be run in `self.before_run` when the sweep is done. """ - processed_rows_init = array_ops.fill(dims=[self._num_rows], value=False) with ops.colocate_with(processed_rows_init): processed_rows = variable_scope.variable( @@ -133,97 +121,72 @@ class _SweepHook(session_run_hook.SessionRunHook): collections=[ops.GraphKeys.GLOBAL_VARIABLES], trainable=False, name="sweep_hook_processed_cols") - # After running the train_op, update processed_rows or processed_cols - # tensors, depending on whether we are currently doing a row or a col sweep - with ops.control_dependencies([train_op]): - - def get_row_update_op(): - with ops.colocate_with(processed_rows): - return state_ops.scatter_update(processed_rows, processed_row_indices, - array_ops.ones_like( - processed_row_indices, - dtype=dtypes.bool)) - - def get_col_update_op(): - with ops.colocate_with(processed_cols): - return state_ops.scatter_update(processed_cols, processed_col_indices, - array_ops.ones_like( - processed_col_indices, - dtype=dtypes.bool)) - - update_processed_op = control_flow_ops.cond( - self._is_row_sweep_var, get_row_update_op, get_col_update_op) - - # After update_processed_op, check whether we have completed a sweep. - # If this is the case, flip the is_row_sweep_var and reset processed_rows - # and processed_cols tensors. - with ops.control_dependencies([update_processed_op]): - - def get_switch_op(): - return state_ops.assign( - self._is_row_sweep_var, - gen_math_ops.logical_not(self._is_row_sweep_var)).op - - def get_reset_op(): - return control_flow_ops.group( - state_ops.assign(processed_rows, processed_rows_init).op, - state_ops.assign(processed_cols, processed_cols_init).op) - - is_sweep_done = control_flow_ops.cond( + switch_ops = control_flow_ops.group( + state_ops.assign( self._is_row_sweep_var, - lambda: math_ops.reduce_all(processed_rows), - lambda: math_ops.reduce_all(processed_cols), - name="sweep_hook_is_sweep_done") - switch_op = control_flow_ops.cond( - is_sweep_done, - get_switch_op, - control_flow_ops.no_op, - name="sweep_hook_switch_op") - reset_op = control_flow_ops.cond( - is_sweep_done, - get_reset_op, - control_flow_ops.no_op, - name="sweep_hook_reset_op") - switch_ops = control_flow_ops.group( - switch_op, reset_op, name="sweep_hook_switch_ops") - - with ops.control_dependencies([switch_ops]): - # Op to increment the completed_sweeps counter. - completed_sweeps_incr_op = control_flow_ops.cond( - is_sweep_done, - lambda: state_ops.assign_add(self._completed_sweeps_var, 1).op, - control_flow_ops.no_op, - name="completed_sweeps_incr") - - # Op to increment the global_step counter. - global_step = framework_variables.get_global_step() - if global_step is not None: - global_step_incr_op = state_ops.assign_add( - global_step, 1, name="global_step_incr").op - else: - global_step_incr_op = control_flow_ops.no_op( - name="global_step_incr") - - incr_ops = control_flow_ops.group( - completed_sweeps_incr_op, - global_step_incr_op, - name="counter_incr_ops") - - return [is_sweep_done, switch_ops, incr_ops] + math_ops.logical_not(self._is_row_sweep_var)), + state_ops.assign(processed_rows, processed_rows_init), + state_ops.assign(processed_cols, processed_cols_init)) + is_sweep_done_var = variable_scope.variable( + False, + collections=[ops.GraphKeys.GLOBAL_VARIABLES], + trainable=False, + name="is_sweep_done") + + # After running the `train_ops`, updates `processed_rows` or + # `processed_cols` tensors, depending on whether this is a row or col sweep. + with ops.control_dependencies(train_ops): + with ops.colocate_with(processed_rows): + update_processed_rows = state_ops.scatter_update( + processed_rows, + input_row_indices, + math_ops.logical_and( + self._is_row_sweep_var, + array_ops.ones_like(input_row_indices, dtype=dtypes.bool))) + with ops.colocate_with(processed_cols): + update_processed_cols = state_ops.scatter_update( + processed_cols, + input_col_indices, + math_ops.logical_and( + math_ops.logical_not(self._is_row_sweep_var), + array_ops.ones_like(input_col_indices, dtype=dtypes.bool))) + update_processed_op = control_flow_ops.group( + update_processed_rows, update_processed_cols) - def begin(self): - pass + with ops.control_dependencies([update_processed_op]): + is_sweep_done = math_ops.logical_or( + math_ops.reduce_all(processed_rows), + math_ops.reduce_all(processed_cols)) + # Increments global step. + global_step = framework_variables.get_global_step() + if global_step is not None: + global_step_incr_op = state_ops.assign_add( + global_step, 1, name="global_step_incr").op + else: + global_step_incr_op = control_flow_ops.no_op() + # Increments completed sweeps. + completed_sweeps_incr_op = state_ops.assign_add( + self._completed_sweeps_var, + math_ops.cast(is_sweep_done, dtypes.int32), + use_locking=True).op + update_ops = control_flow_ops.group( + global_step_incr_op, + completed_sweeps_incr_op, + state_ops.assign(is_sweep_done_var, is_sweep_done)) + + return update_ops, is_sweep_done_var, switch_ops def before_run(self, run_context): """Runs the appropriate prep ops, and requests running update ops.""" - # Run the appropriate cache_init and prep ops + # Runs the appropriate init ops and prep ops. sess = run_context.session + is_sweep_done = sess.run(self._is_sweep_done_var) if not self._is_initialized: - logging.info("SweepHook running cache init ops.") - for init_op in self._cache_init_ops: - sess.run(init_op) - - if self._is_sweep_done or not self._is_initialized: + logging.info("SweepHook running cache init op.") + sess.run(self._init_op) + if is_sweep_done: + sess.run(self._switch_op) + if is_sweep_done or not self._is_initialized: logging.info("SweepHook running sweep prep ops.") row_sweep = sess.run(self._is_row_sweep_var) prep_ops = self._row_prep_ops if row_sweep else self._col_prep_ops @@ -232,13 +195,12 @@ class _SweepHook(session_run_hook.SessionRunHook): self._is_initialized = True - # Request running the switch_ops and the incr_ops - logging.info("Partial fit starting.") - return session_run_hook.SessionRunArgs(fetches=self._fetches) + # Requests running `self._update_op` jointly with the training op. + logging.info("Next fit step starting.") + return session_run_hook.SessionRunArgs(fetches=[self._update_op]) def after_run(self, run_context, run_values): - self._is_sweep_done = run_values.results[0] - logging.info("Partial fit done.") + logging.info("Fit step done.") class _StopAtSweepHook(session_run_hook.SessionRunHook): @@ -360,19 +322,19 @@ def _wals_factorization_model_function(features, labels, mode, params): col_prep_ops = [ model.col_update_prep_gramian_op, model.initialize_col_update_op ] - cache_init_ops = [model.worker_init] + init_ops = [model.worker_init] sweep_hook = _SweepHook( is_row_sweep_var, - train_op, + [train_op, loss], params["num_rows"], params["num_cols"], input_row_indices, input_col_indices, row_prep_ops, col_prep_ops, - cache_init_ops, - completed_sweeps_var,) + init_ops, + completed_sweeps_var) training_hooks = [sweep_hook] if max_sweeps is not None: training_hooks.append(_StopAtSweepHook(max_sweeps)) diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index b5c1bb1151e78a8f19d3c91b57ef3bfd6152893d..8bd72b7025aad80e387171b93b9b264da3ed0f66 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -357,7 +357,7 @@ class WALSMatrixFactorizationTest(test.TestCase): self.assertNear( loss, true_loss, err=.001, - msg="""After row update, eval loss = {}, does not match the true + msg="""After col update, eval loss = {}, does not match the true loss = {}.""".format(loss, true_loss)) @@ -442,7 +442,7 @@ class SweepHookTest(test.TestCase): completed_sweeps_var = variables.Variable(0) sweep_hook = wals_lib._SweepHook( is_row_sweep_var, - self._train_op, + [self._train_op], self._num_rows, self._num_cols, self._input_row_indices_ph, @@ -465,11 +465,9 @@ class SweepHookTest(test.TestCase): 'False.') # Row sweep completed. mon_sess.run(self._train_op, ind_feed([3, 4], [0, 1, 2, 3, 4, 5, 6])) - self.assertFalse(sess.run(is_row_sweep_var), - msg='Row sweep is complete but is_row_sweep is True.') self.assertTrue(sess.run(completed_sweeps_var) == 1, msg='Completed sweeps should be equal to 1.') - self.assertTrue(sweep_hook._is_sweep_done, + self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), msg='Sweep is complete but is_sweep_done is False.') # Col init ops should run. Col sweep not completed. mon_sess.run(self._train_op, ind_feed([], [0, 1, 2, 3, 4])) @@ -478,13 +476,11 @@ class SweepHookTest(test.TestCase): self.assertFalse(sess.run(is_row_sweep_var), msg='Col sweep is not complete but is_row_sweep is ' 'True.') - self.assertFalse(sweep_hook._is_sweep_done, + self.assertFalse(sess.run(sweep_hook._is_sweep_done_var), msg='Sweep is not complete but is_sweep_done is True.') # Col sweep completed. mon_sess.run(self._train_op, ind_feed([], [4, 5, 6])) - self.assertTrue(sess.run(is_row_sweep_var), - msg='Col sweep is complete but is_row_sweep is False') - self.assertTrue(sweep_hook._is_sweep_done, + self.assertTrue(sess.run(sweep_hook._is_sweep_done_var), msg='Sweep is complete but is_sweep_done is False.') self.assertTrue(sess.run(completed_sweeps_var) == 2, msg='Completed sweeps should be equal to 2.') diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD index e205d92fbe2f45cafde76f79643eb85b6876d48b..7a5a4cb8c9499b950a3ad89be710e48474d5791e 100644 --- a/tensorflow/contrib/ffmpeg/BUILD +++ b/tensorflow/contrib/ffmpeg/BUILD @@ -89,6 +89,7 @@ tf_py_test( "@six_archive//:six", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", "//tensorflow/python:platform", ], data = [ @@ -105,6 +106,7 @@ tf_py_test( "@six_archive//:six", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", "//tensorflow/python:platform", ], data = [ diff --git a/tensorflow/contrib/ffmpeg/default/BUILD b/tensorflow/contrib/ffmpeg/default/BUILD index 05fc658d80f26b00f775211cf89f55ce18a4502d..949ae9ad9e4b045ee1b5cc82d49c0e7468c2005d 100644 --- a/tensorflow/contrib/ffmpeg/default/BUILD +++ b/tensorflow/contrib/ffmpeg/default/BUILD @@ -23,6 +23,18 @@ cc_library( ], ) +tf_cc_test( + name = "ffmpeg_lib_utility_test", + srcs = ["ffmpeg_lib_utility_test.cc"], + deps = [ + ":ffmpeg_lib", + "//tensorflow/core:framework_internal", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + tf_cc_test( name = "ffmpeg_lib_installed_test", srcs = ["ffmpeg_lib_test.cc"], diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc index 888f5c38a27d4f98fe2a68d8d4236d580b16e54d..545a4386d043af604a747b8b5a8103101812b177 100644 --- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc @@ -198,6 +198,14 @@ string BuildWavFile(int32 samples_per_second, int32 channel_count, return data; } +// Returns a unique number every time it is called. +int64 UniqueId() { + static mutex mu(LINKER_INITIALIZED); + static int64 id = 0; + mutex_lock l(mu); + return ++id; +} + } // namespace string GetTempFilename(const string& extension) { @@ -208,7 +216,19 @@ string GetTempFilename(const string& extension) { } struct stat statbuf; if (!stat(dir, &statbuf) && S_ISDIR(statbuf.st_mode)) { - return io::JoinPath(dir, StrCat("tmp_file_", getpid(), ".", extension)); + // UniqueId is added here because mkstemps is not as thread safe as it + // looks. https://github.com/tensorflow/tensorflow/issues/5804 shows + // the problem. + string tmp_filepath = io::JoinPath( + dir, + StrCat("tmp_file_tensorflow_", UniqueId(), "_XXXXXX.", extension)); + int fd = mkstemps(&tmp_filepath[0], extension.length() + 1); + if (fd < 0) { + LOG(FATAL) << "Failed to create temp file."; + } else { + close(fd); + return tmp_filepath; + } } } LOG(FATAL) << "No temp directory found."; diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..7176f3b550679555d5ab3b70f2b360a90eaee253 --- /dev/null +++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc @@ -0,0 +1,80 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h" + +#include +#include +#include +#include + +#include "tensorflow/core/lib/core/threadpool.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/mutex.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace ffmpeg { +namespace { + +TEST(FfmpegLibTest, TestTempDirectoryThreading) { + // Testing a fix for a bug that allowed different threads to create + // conflicting temp files. + // See github.com/tensorflow/tensorflow/issues/5804 for details. + const int32 kNumThreads = 10; + const int32 kNumWorkItems = 10000; + static constexpr size_t kStringsPerItem = 100; + Env* environment = Env::Default(); + thread::ThreadPool pool(environment, "test", kNumThreads); + + mutex mu; + std::vector temp_filenames; + temp_filenames.reserve(kNumWorkItems * kStringsPerItem); + + // Queue a large number of work items for the threads to process. Each work + // item creates a temp file and then deletes it. + for (int i = 0; i < kNumWorkItems; ++i) { + pool.Schedule([&mu, &temp_filenames, environment]() { + std::array buffer; + for (int32 j = 0; j < kStringsPerItem; ++j) { + buffer[j] = GetTempFilename("mp3"); + TF_QCHECK_OK(environment->DeleteFile(buffer[j])); + } + mutex_lock l(mu); + for (const auto& fn : buffer) { + temp_filenames.push_back(fn); + } + }); + } + + // Wait until all work items are complete. + while (true) { + mutex_lock l(mu); + if (temp_filenames.size() == kNumWorkItems * kStringsPerItem) { + break; + } + } + + // Check that no duplicates are created. + std::set unique_filenames; + mutex_lock l(mu); + for (const auto& fn : temp_filenames) { + ASSERT_TRUE(unique_filenames.insert(fn).second); + } +} + +} // namespace +} // namespace ffmpeg +} // namespace tensorflow diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD index 6b0599ddd2def8dd698a1bd152b5be926c6ddf3e..e8dad886a1409babdf4ea47b9cd05def1f1ce25e 100644 --- a/tensorflow/contrib/framework/BUILD +++ b/tensorflow/contrib/framework/BUILD @@ -10,9 +10,8 @@ package(default_visibility = [ "//tensorflow:__subpackages__", ]) -load("//tensorflow:tensorflow.bzl", "cuda_py_test") -load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library") load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.bzl", "tf_gen_op_libs") @@ -25,13 +24,16 @@ tf_custom_op_py_library( "python/framework/__init__.py", "python/framework/checkpoint_utils.py", "python/framework/experimental.py", + "python/framework/graph_util.py", "python/framework/tensor_util.py", "python/ops/__init__.py", + "python/ops/accumulate_n_v2.py", "python/ops/arg_scope.py", "python/ops/audio_ops.py", "python/ops/checkpoint_ops.py", "python/ops/ops.py", "python/ops/prettyprint_ops.py", + "python/ops/sort_ops.py", "python/ops/variables.py", ], dso = [ @@ -47,6 +49,7 @@ tf_custom_op_py_library( "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:audio_ops_gen", + "//tensorflow/python:check_ops", "//tensorflow/python:checkpoint_ops_gen", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework", @@ -56,13 +59,17 @@ tf_custom_op_py_library( "//tensorflow/python:logging_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", "//tensorflow/python:state_ops_gen", "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:tensor_util", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/eager:context", "//third_party/py/numpy", "@six_archive//:six", ], @@ -149,6 +156,43 @@ py_test( ], ) +py_test( + name = "accumulate_n_v2_test", + size = "small", + srcs = ["python/ops/accumulate_n_v2_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "accumulate_n_v2_eager_test", + size = "small", + srcs = ["python/ops/accumulate_n_v2_eager_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python/eager:backprop", + "//tensorflow/python/eager:context", + "//tensorflow/python/eager:tape", + "//third_party/py/numpy", + ], +) + py_test( name = "ops_test", size = "small", @@ -189,6 +233,17 @@ py_test( ], ) +py_test( + name = "graph_util_test", + srcs = ["python/framework/graph_util_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:platform", + ], +) + py_test( name = "tensor_util_test", srcs = ["python/framework/tensor_util_test.py"], @@ -214,7 +269,6 @@ py_test( deps = [ ":framework_py", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", @@ -222,6 +276,7 @@ py_test( "//tensorflow/python:nn_ops", "//tensorflow/python:partitioned_variables", "//tensorflow/python:platform", + "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", @@ -254,7 +309,6 @@ py_test( "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", - "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:partitioned_variables", @@ -266,6 +320,20 @@ py_test( ], ) +py_test( + name = "sort_ops_test", + size = "medium", + srcs = ["python/ops/sort_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":framework_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:random_ops", + "//third_party/py/numpy", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py index 2081a11f47d71106f8e57227f46639717a791855..3f592611830e40a30392239c85486a2fad15a2a2 100644 --- a/tensorflow/contrib/framework/__init__.py +++ b/tensorflow/contrib/framework/__init__.py @@ -37,6 +37,7 @@ See the @{$python/contrib.framework} guide. @@arg_scope @@add_arg_scope +@@current_arg_scope @@has_arg_scope @@arg_scoped_arguments @@ -78,6 +79,8 @@ See the @{$python/contrib.framework} guide. @@load_embedding_initializer @@load_linear_multiclass_bias_initializer @@load_variable_slot_initializer + +@@sort """ from __future__ import absolute_import diff --git a/tensorflow/contrib/framework/python/framework/__init__.py b/tensorflow/contrib/framework/python/framework/__init__.py index c8e6a4685498a4d89cef44f6a9a3acbe7557cb67..2d49771ab756359712a3ee0b23649c231678f952 100644 --- a/tensorflow/contrib/framework/python/framework/__init__.py +++ b/tensorflow/contrib/framework/python/framework/__init__.py @@ -21,6 +21,7 @@ from __future__ import print_function # pylint: disable=wildcard-import from tensorflow.contrib.framework.python.framework.checkpoint_utils import * from tensorflow.contrib.framework.python.framework.experimental import experimental +from tensorflow.contrib.framework.python.framework.graph_util import * from tensorflow.contrib.framework.python.framework.tensor_util import * # pylint: enable=wildcard-import from tensorflow.python.util import decorator_utils diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py new file mode 100644 index 0000000000000000000000000000000000000000..8ab8711db4650921e0d366a91adfe2f68b5a42f9 --- /dev/null +++ b/tensorflow/contrib/framework/python/framework/graph_util.py @@ -0,0 +1,128 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helpers to manipulate a tensor graph in python. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import copy +import six + +# pylint: disable=unused-import +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.python.framework.graph_util_impl import _assert_nodes_are_present +from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes +from tensorflow.python.framework.graph_util_impl import _extract_graph_summary +from tensorflow.python.framework.graph_util_impl import _node_name + +__all__ = ["fuse_op"] + + +def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes, + output_quantized, op_name, op_type): + """Fuse subgraph between input_nodes and output_nodes into a single custom op. + + Args: + graph_def: A graph_pb2.GraphDef proto. + input_nodes: input nodes to the subgraph to be fused. + output_nodes: output nodes to the subgraph to be fused. + output_dtypes: A list of output datatypes for the custom op + output_quantized: A boolean flag that indicates if output is quantized + op_name: fused op name. + op_type: fused op type. + Returns: + The GraphDef of the new graph. + + Raises: + TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto. + """ + + if not isinstance(graph_def, graph_pb2.GraphDef): + raise TypeError("graph_def must be a graph_pb2.GraphDef proto.") + + if isinstance(input_nodes, six.string_types): + raise TypeError("input_nodes must be a list.") + + if isinstance(output_nodes, six.string_types): + raise TypeError("output_nodes must be a list.") + + name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary( + graph_def) + _assert_nodes_are_present(name_to_node, input_nodes + output_nodes) + + # Nodes upto and including input_nodes + reachable_by_input = _bfs_for_reachable_nodes(input_nodes, name_to_input_name) + # Nodes upto and including output_nodes + reachable_by_output = _bfs_for_reachable_nodes(output_nodes, + name_to_input_name) + + # Set of nodes in the list input_nodes + input_nodes_set = set(input_nodes) + + # Set of nodes in the list output_nodes + output_nodes_set = set(output_nodes) + + nodes_post_output = [] + for node in graph_def.node: + n = _node_name(node.name) + if n in reachable_by_output: + if n not in reachable_by_input and n not in output_nodes_set: + # n is between input and output, i.e., part of the fused op + next_to_visit = [n] + while next_to_visit: + cur_node = next_to_visit[0] + del next_to_visit[0] + if cur_node in reachable_by_input and cur_node not in input_nodes_set: + raise TypeError("Node %s uses input %s not in input_nodes." % + (n, cur_node)) + if cur_node not in input_nodes_set: + next_to_visit += name_to_input_name[cur_node] + else: + nodes_post_output.append(n) + + # Add all nodes upto the input nodes + out = graph_pb2.GraphDef() + reachable_by_input_sorted = sorted( + list(reachable_by_input), key=lambda n: name_to_seq_num[n]) + for node in reachable_by_input_sorted: + out.node.extend([copy.deepcopy(name_to_node[node])]) + + # Add the custom op + new_node = node_def_pb2.NodeDef() + for node in input_nodes: + new_node.input.append(node) + new_node.attr["_output_types"].list.type[:] = output_dtypes + new_node.attr["_output_quantized"].b = output_quantized + new_node.op = op_type + new_node.name = op_name + out.node.extend([new_node]) + + # Add the nodes in the output of the custom op + for index, n in enumerate(output_nodes): + assert len(name_to_node[n].input) == 1 + new_node = copy.deepcopy(name_to_node[n]) + del new_node.input[:] + new_node.input.append(op_name + (":" + str(index) if index != 0 else "")) + out.node.extend([new_node]) + + # Add the nodes post output_nodes + for n in nodes_post_output: + out.node.extend([copy.deepcopy(name_to_node[n])]) + + out.library.CopyFrom(graph_def.library) + out.versions.CopyFrom(graph_def.versions) + return out diff --git a/tensorflow/contrib/framework/python/framework/graph_util_test.py b/tensorflow/contrib/framework/python/framework/graph_util_test.py new file mode 100644 index 0000000000000000000000000000000000000000..87b992e22e1ad3aa20389d0834eeb3a5972c676e --- /dev/null +++ b/tensorflow/contrib/framework/python/framework/graph_util_test.py @@ -0,0 +1,61 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""@graph_util tests.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python.framework import graph_util +from tensorflow.core.framework import graph_pb2 +from tensorflow.core.framework import node_def_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.python.platform import test + + +def GetNewNode(name, op, input_nodes): + new_node = node_def_pb2.NodeDef() + new_node.op = op + new_node.name = name + for node in input_nodes: + new_node.input.append(node) + return new_node + + +class GraphUtilTest(test.TestCase): + + def testGraphUtil(self): + graph_def = graph_pb2.GraphDef() + node_a = GetNewNode('A', 'Placeholder', []) + node_b = GetNewNode('B', 'Op1', ['A']) + node_c = GetNewNode('C', 'Op1', ['B']) + node_d = GetNewNode('D', 'Op1', ['C']) + node_e = GetNewNode('E', 'Op1', ['D']) + graph_def.node.extend([node_a, node_b, node_c, node_d, node_e]) + fused_graph_def = graph_util.fuse_op( + graph_def, ['A'], ['D'], [types_pb2.DT_FLOAT], True, 'FusedOp', 'Op2') + self.assertEqual(len(fused_graph_def.node), 4) + self.assertEqual(fused_graph_def.node[0].name, 'A') + self.assertEqual(fused_graph_def.node[1].name, 'FusedOp') + self.assertEqual(fused_graph_def.node[1].input[0], 'A') + self.assertEqual(fused_graph_def.node[1].op, 'Op2') + self.assertEqual(fused_graph_def.node[1].attr['_output_quantized'].b, True) + self.assertEqual(fused_graph_def.node[1].attr['_output_types'].list.type, + [types_pb2.DT_FLOAT]) + self.assertEqual(fused_graph_def.node[2].name, 'D') + self.assertEqual(fused_graph_def.node[3].name, 'E') + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/framework/python/framework/tensor_util.py b/tensorflow/contrib/framework/python/framework/tensor_util.py index e595e4d90bfd6bef9de5ac0724a18060e7458f8e..4e6eea8884731f3e14a7ae817296c3782d943527 100644 --- a/tensorflow/contrib/framework/python/framework/tensor_util.py +++ b/tensorflow/contrib/framework/python/framework/tensor_util.py @@ -77,10 +77,10 @@ def reduce_sum_n(tensors, name=None): return tensors[0] return math_ops.add_n(tensors, name=name_scope) -@deprecated(None, - "Please switch to tf.confusion_matrix.remove_squeezable_dimensions. Note " - "that order of the inputs and ouputs of labels and predictions have also " - "been switched.") +@deprecated( + None, "Please switch to remove_squeezable_dimensions from " + "tf.confusion_matrix. Note that the order of the inputs and outputs of " + "labels and predictions have also been switched.") def remove_squeezable_dimensions(predictions, labels, name=None): """Squeeze last dim if ranks of `predictions` and `labels` differ by 1. diff --git a/tensorflow/contrib/framework/python/ops/__init__.py b/tensorflow/contrib/framework/python/ops/__init__.py index edef37cf0c0719bf10a4c75c34adb30b9716cdcd..685bb94779762ce46ee342e7e0a182c54be64743 100644 --- a/tensorflow/contrib/framework/python/ops/__init__.py +++ b/tensorflow/contrib/framework/python/ops/__init__.py @@ -24,5 +24,6 @@ from tensorflow.contrib.framework.python.ops.arg_scope import * from tensorflow.contrib.framework.python.ops.checkpoint_ops import * from tensorflow.contrib.framework.python.ops.ops import * from tensorflow.contrib.framework.python.ops.prettyprint_ops import * +from tensorflow.contrib.framework.python.ops.sort_ops import * from tensorflow.contrib.framework.python.ops.variables import * # pylint: enable=wildcard-import diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..a0667bd489213cf366e27114a91e8699ed9e7428 --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2.py @@ -0,0 +1,111 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops that will eventually be folded into tensorflow/python/ops/math_ops.py +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +from tensorflow.python.eager import context +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import gen_math_ops +from tensorflow.python.ops import math_ops + + + +def accumulate_n_v2(inputs, shape=None, tensor_dtype=None, name=None): + """Returns the element-wise sum of a list of tensors. + + Optionally, pass `shape` and `tensor_dtype` for shape and type checking, + otherwise, these are inferred. + + `tf.accumulate_n_v2` performs the same operation as `tf.add_n`, but does not + wait for all of its inputs to be ready before beginning to sum. This can + save memory if inputs are ready at different times, since minimum temporary + storage is proportional to the output size rather than the inputs size. + + Unlike the original `accumulate_n`, `accumulate_n_v2` is differentiable. + + For example: + + ```python + a = tf.constant([[1, 2], [3, 4]]) + b = tf.constant([[5, 0], [0, 6]]) + tf.accumulate_n_v2([a, b, a]) # [[7, 4], [6, 14]] + + # Explicitly pass shape and type + tf.accumulate_n_v2([a, b, a], shape=[2, 2], tensor_dtype=tf.int32) + # [[7, 4], + # [6, 14]] + ``` + + Args: + inputs: A list of `Tensor` objects, each with same shape and type. + shape: Shape of elements of `inputs`. + tensor_dtype: The type of `inputs`. + name: A name for the operation (optional). + + Returns: + A `Tensor` of same shape and type as the elements of `inputs`. + + Raises: + ValueError: If `inputs` don't all have same shape and dtype or the shape + cannot be inferred. + """ + _INPUTS_ERR_MSG = ValueError("inputs must be a list of at least one Tensor" + "with the same dtype and shape") + if not inputs or not isinstance(inputs, (list, tuple)): + raise _INPUTS_ERR_MSG + inputs = ops.convert_n_to_tensor_or_indexed_slices(inputs) + if not all(isinstance(x, ops.Tensor) for x in inputs): + raise _INPUTS_ERR_MSG + if not all(x.dtype == inputs[0].dtype for x in inputs): + raise _INPUTS_ERR_MSG + if shape is not None: + shape = tensor_shape.as_shape(shape) + else: + shape = tensor_shape.unknown_shape() + for input_tensor in inputs: + if isinstance(input_tensor, ops.Tensor): + shape = shape.merge_with(input_tensor.get_shape()) + + # tensor_dtype is for safety only; operator's output type computed in C++ + if tensor_dtype is not None and tensor_dtype != inputs[0].dtype: + raise TypeError("tensor_dtype is {}, but input is of type {}" + .format(tensor_dtype, inputs[0].dtype)) + + if len(inputs) == 1 and name is None: + return inputs[0] + elif len(inputs) == 1 and name is not None: + return array_ops.identity(inputs[0], name=name) + elif context.in_eager_mode(): + # TemporaryVariable not currently supported in eager mode; fall back + # onto AddN for now. + # TODO(frreiss) remove this once the lifetime of eager variables gets + # addressed + return math_ops.add_n(inputs, name=name) + else: + return gen_math_ops._accumulate_nv2(inputs, name=name, shape=shape) + +# The following code should eventually be merged into +# tensorflow/python/ops/math_grad.py +@ops.RegisterGradient("AccumulateNV2") +def _AddNGrad(op, grad): + """Same as gradient for AddN. Copies the gradient to all inputs.""" + # Not broadcasting. + return [grad] * len(op.inputs) + diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c2229bb8ad3d5b38321d16f150ed94175ab9bdbe --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_eager_test.py @@ -0,0 +1,85 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for new version of accumulate_n op that will eventually go into +`ops.math_ops`. + +These test cases spefically exercise the `eager` APIs. They need to be in a +separate file from the remaining tests because eager mode is currently something +you can turn on but can't turn off for the lifetime of the current process.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 + +from tensorflow.python.eager import backprop +from tensorflow.python.eager import context as eager_context +from tensorflow.python.eager import tape + + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import gradients +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.platform import test + + + +class AccumulateNV2EagerTest(test_util.TensorFlowTestCase): + """Tests of the new, differentiable version of accumulate_n""" + + def testMinimalEagerMode(self): + forty = constant_op.constant(40) + two = constant_op.constant(2) + answer = av2.accumulate_n_v2([forty, two]) + self.assertEqual(42, answer.numpy()) + + + def testFloat(self): + np.random.seed(12345) + x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)] + tf_x = ops.convert_n_to_tensor(x) + with self.test_session(use_gpu=True): + self.assertAllClose(sum(x), av2.accumulate_n_v2(tf_x).numpy()) + self.assertAllClose(x[0] * 5, av2.accumulate_n_v2([tf_x[0]] * 5).numpy()) + + def testGrad(self): + np.random.seed(42) + num_inputs = 3 + input_vars = [ + resource_variable_ops.ResourceVariable(10.0 * np.random.random(), + name="t%d" % i) + for i in range(0, num_inputs) + ] + + def fn(first, second, third): + return av2.accumulate_n_v2([first, second, third]) + + grad_fn = backprop.gradients_function(fn) + grad = grad_fn(input_vars[0], input_vars[1], input_vars[2]) + self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 + [elem.numpy() for elem in grad]) + + + +if __name__ == "__main__": + ops.enable_eager_execution() + test.main() + diff --git a/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3386e849d5cb8516ab3b1f6cb0429be3fc2fc960 --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/accumulate_n_v2_test.py @@ -0,0 +1,123 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for new version of accumulate_n op that will eventually go into +`ops.math_ops`.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.framework.python.ops import accumulate_n_v2 as av2 + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes as dtypes_lib +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import gradients +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + + + +class AccumulateNV2Test(test_util.TensorFlowTestCase): + """Tests of the new, differentiable version of accumulate_n""" + + def testFloat(self): + np.random.seed(12345) + x = [np.random.random((1, 2, 3, 4, 5)) - 0.5 for _ in range(5)] + tf_x = ops.convert_n_to_tensor(x) + with self.test_session(use_gpu=True): + self.assertAllClose(sum(x), av2.accumulate_n_v2(tf_x).eval()) + self.assertAllClose(x[0] * 5, av2.accumulate_n_v2([tf_x[0]] * 5).eval()) + + def testInt(self): + np.random.seed(54321) + x = [np.random.randint(-128, 128, (5, 4, 3, 2, 1)) for _ in range(6)] + tf_x = ops.convert_n_to_tensor(x) + with self.test_session(use_gpu=True): + self.assertAllEqual(sum(x), av2.accumulate_n_v2(tf_x).eval()) + self.assertAllEqual(x[0] * 6, av2.accumulate_n_v2([tf_x[0]] * 6).eval()) + + def testGrad(self): + np.random.seed(42) + for num_inputs in range(1, 10): + with self.test_session(use_gpu=True) as sess: + input_vars = [ + variables.Variable(10.0 * np.random.random()) + for i in range(0, num_inputs) + ] + accum_n = av2.accumulate_n_v2(input_vars) + sess.run(variables.global_variables_initializer()) + accum_n_grad = gradients.gradients(accum_n, input_vars) + self.assertAllEqual(np.repeat(1.0, num_inputs), # d/dx (x + y + ...) = 1 + [g.eval() for g in accum_n_grad]) + + # The tests below used to be in a separate class under cwise_ops_test.py, + # which did not run in the default test target. + # Putting them here so that everything that exercises AccumulateNV2 is in + # one place and the default build runs all unit tests. + def testSimple(self): + with self.test_session(): + random_arrays = [ + np.random.rand(16, 16, 16, 16).astype(np.float32) for _ in range(20) + ] + random_tensors = [ + ops.convert_to_tensor( + x, dtype=dtypes_lib.float32) for x in random_arrays + ] + tf_val = av2.accumulate_n_v2(random_tensors) + np_val = random_arrays[0] + for random_array in random_arrays[1:]: + np_val += random_array + self.assertAllClose(np_val, tf_val.eval()) + + def testZeroArgs(self): + with self.test_session(): + with self.assertRaises(ValueError): + tf_val = av2.accumulate_n_v2([]) + tf_val.eval() + + def testWrongShape(self): + with self.test_session(): + with self.assertRaises(ValueError): + a = variables.Variable(0.2) + b = variables.Variable(0.1) + tf_val = av2.accumulate_n_v2([a,b], shape=[2,2]) # Should be shape=[] + + def testIncompatibleShapes(self): + with self.test_session(): + with self.assertRaises(ValueError): + a = variables.Variable(np.array([0.1,0.2])) + b = variables.Variable(np.array([[0.3],[0.4]])) + tf_val = av2.accumulate_n_v2([a,b]) + + def testWrongType(self): + with self.test_session(): + with self.assertRaises(TypeError): + a = variables.Variable(0.2, dtype=np.float32) + b = variables.Variable(0.1, dtype=np.float32) + tf_val = av2.accumulate_n_v2([a,b], tensor_dtype=np.int32) + + def testWrongTypeOneInput(self): + # Scenario that used to trigger a bug, even when testWrongType() worked + with self.test_session(): + with self.assertRaises(TypeError): + a = variables.Variable(0.2, dtype=np.float32) + tf_val = av2.accumulate_n_v2([a], tensor_dtype=np.int32) + + +if __name__ == "__main__": + googletest.main() diff --git a/tensorflow/contrib/framework/python/ops/arg_scope.py b/tensorflow/contrib/framework/python/ops/arg_scope.py index 9c194ec202ab6150278b26e844b9d3e97a7d6761..2bce00fde2459878a12027bb4d98bd3818bc92a2 100644 --- a/tensorflow/contrib/framework/python/ops/arg_scope.py +++ b/tensorflow/contrib/framework/python/ops/arg_scope.py @@ -67,6 +67,7 @@ from tensorflow.python.util import tf_decorator __all__ = ['arg_scope', 'add_arg_scope', + 'current_arg_scope', 'has_arg_scope', 'arg_scoped_arguments'] @@ -83,7 +84,7 @@ def _get_arg_stack(): return _ARGSTACK -def _current_arg_scope(): +def current_arg_scope(): stack = _get_arg_stack() return stack[-1] @@ -144,7 +145,7 @@ def arg_scope(list_ops_or_scope, **kwargs): raise TypeError('list_ops_or_scope must either be a list/tuple or reused' 'scope (i.e. dict)') try: - current_scope = _current_arg_scope().copy() + current_scope = current_arg_scope().copy() for op in list_ops_or_scope: key_op = _key_op(op) if not has_arg_scope(op): @@ -172,7 +173,7 @@ def add_arg_scope(func): A tuple with the decorated function func_with_args(). """ def func_with_args(*args, **kwargs): - current_scope = _current_arg_scope() + current_scope = current_arg_scope() current_args = kwargs key_func = _key_op(func) if key_func in current_scope: diff --git a/tensorflow/contrib/framework/python/ops/sort_ops.py b/tensorflow/contrib/framework/python/ops/sort_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..8f62f0ea7b9b561f235b9496ffda97a9f378d530 --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/sort_ops.py @@ -0,0 +1,113 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Support for sorting tensors. + +@@sort +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops as framework_ops +from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops + + +def sort(values, axis=-1, direction='ASCENDING', name=None): + """Sorts a tensor. + + Args: + values: 1-D or higher numeric `Tensor`. + axis: The axis along which to sort. The default is -1, which sorts the last + axis. + direction: The direction in which to sort the values (`'ASCENDING'` or + `'DESCENDING'`). + name: Optional name for the operation. + + Returns: + A `Tensor` with the same dtype and shape as `values`, with the elements + sorted along the given `axis`. + + Raises: + ValueError: If axis is not a constant scalar, or the direction is invalid. + """ + with framework_ops.name_scope(name, 'sort'): + if direction not in _SORT_IMPL: + raise ValueError('%s should be one of %s' % + (direction, ', '.join(sorted(_SORT_IMPL.keys())))) + # Axis must be an integer, not a Tensor. + axis = framework_ops.convert_to_tensor(axis, name='axis') + axis_static = tensor_util.constant_value(axis) + if axis.shape.ndims != 0 or axis_static is None: + raise ValueError('axis must be a constant scalar') + axis_static = int(axis_static) # Avoids NumPy casting error + + values = framework_ops.convert_to_tensor(values, name='values') + + return _SORT_IMPL[direction](values, axis_static) + + +def _descending_sort(values, axis): + """Sorts values in reverse using `top_k`. + + Args: + values: Tensor of numeric values. + axis: Index of the axis which values should be sorted along. + + Returns: + The sorted values. + """ + k = array_ops.shape(values)[axis] + rank = array_ops.rank(values) + # Fast path: sorting the last axis. + if axis == -1 or axis + 1 == values.get_shape().ndims: + return nn_ops.top_k(values, k)[0] + + # Otherwise, transpose the array. Swap axes `axis` and `rank - 1`. + if axis < 0: + # Make axis a Tensor with the real axis index if needed. + axis += rank + transposition = array_ops.concat( + [ + # Axes up to axis are unchanged. + math_ops.range(axis), + # Swap axis and rank - 1. + [rank - 1], + # Axes in [axis + 1, rank - 1) are unchanged. + math_ops.range(axis + 1, rank - 1), + # Swap axis and rank - 1. + [axis] + ], + axis=0) + top_k_input = array_ops.transpose(values, transposition) + values, unused_indices = nn_ops.top_k(top_k_input, k) + # transposition contains a single cycle of length 2 (swapping 2 elements), + # so it is an involution (it is its own inverse). + return array_ops.transpose(values, transposition) + + +def _ascending_sort(values, axis): + # Negate the values to get the ascending order from descending sort. + values_or_indices = _descending_sort(-values, axis) + return -values_or_indices + + +_SORT_IMPL = { + 'ASCENDING': _ascending_sort, + 'DESCENDING': _descending_sort, +} diff --git a/tensorflow/contrib/framework/python/ops/sort_ops_test.py b/tensorflow/contrib/framework/python/ops/sort_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..d08ae502f10d98ee14d8bea2f76b18bedb935cea --- /dev/null +++ b/tensorflow/contrib/framework/python/ops/sort_ops_test.py @@ -0,0 +1,95 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the sort wrapper.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.framework.python.ops import sort_ops +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import errors +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.platform import test + + +class SortTest(test.TestCase): + + def testRandom_lowDimensionality(self): + self._testRandom_lowDimensionality(negative_axis=False) + + def testRandom_lowDimensionality_negative(self): + self._testRandom_lowDimensionality(negative_axis=True) + + def _testRandom_lowDimensionality(self, negative_axis): + np.random.seed(42) + for _ in range(20): + rank = np.random.randint(1, 3) + shape = [np.random.randint(0, 20) for _ in range(rank)] + arr = np.random.random(shape) + sort_axis = np.random.choice(rank) + if negative_axis: + sort_axis = -1 - sort_axis + with self.test_session(): + self.assertAllEqual( + np.sort(arr, axis=sort_axis), + sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval()) + + def testRandom_highDimensionality(self): + np.random.seed(100) + for _ in range(20): + rank = np.random.randint(5, 15) + shape = [np.random.randint(1, 4) for _ in range(rank)] + arr = np.random.random(shape) + sort_axis = np.random.choice(rank) + with self.test_session(): + self.assertAllEqual( + np.sort(arr, axis=sort_axis), + sort_ops.sort(constant_op.constant(arr), axis=sort_axis).eval()) + + def testScalar(self): + # Create an empty scalar where the static shape is unknown. + zeros_length_1 = array_ops.zeros( + random_ops.random_uniform([1], minval=0, maxval=1, dtype=dtypes.int32), + dtype=dtypes.int32) + scalar = array_ops.zeros(zeros_length_1) + + sort = sort_ops.sort(scalar) + with self.test_session(): + with self.assertRaises(errors.InvalidArgumentError): + sort.eval() + + def testNegativeOutOfBounds_staticShape(self): + arr = constant_op.constant([3, 4, 5]) + with self.assertRaises(ValueError): + sort_ops.sort(arr, axis=-4) + + def testDescending(self): + arr = np.random.random((10, 5, 5)) + with self.test_session(): + self.assertAllEqual( + np.sort(arr, axis=0)[::-1], + sort_ops.sort( + constant_op.constant(arr), + axis=0, + direction='DESCENDING').eval()) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py index 1bd9a14a7f3e17b30b811b3b73e5915c0dd1ec59..b7668379686b4f0ba2a3e415ddb44b287659baaa 100644 --- a/tensorflow/contrib/framework/python/ops/variables.py +++ b/tensorflow/contrib/framework/python/ops/variables.py @@ -201,7 +201,7 @@ def variable(name, shape=None, dtype=None, initializer=None, else [ops.GraphKeys.GLOBAL_VARIABLES]) # Remove duplicates - collections = set(collections) + collections = list(set(collections)) getter = variable_scope.get_variable if custom_getter is not None: getter = functools.partial(custom_getter, diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD index 31917b40eb900dd6a0a6c1a83d00881dfe690c49..ce37672895b37275770d2f5410f662e9acf1de9d 100644 --- a/tensorflow/contrib/fused_conv/BUILD +++ b/tensorflow/contrib/fused_conv/BUILD @@ -38,7 +38,6 @@ tf_custom_op_py_library( ":fused_conv2d_bias_activation_op", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", @@ -49,6 +48,7 @@ tf_custom_op_py_library( "//tensorflow/python:nn_ops", "//tensorflow/python:platform", "//tensorflow/python:random_ops", + "//tensorflow/python:session", "//tensorflow/python:util", "//tensorflow/python:variables", "//third_party/py/numpy", @@ -69,7 +69,7 @@ tf_kernel_library( "//tensorflow/core:lib", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:stream_executor", - "//tensorflow/core/kernels:bounds_check_lib", + "//tensorflow/core/kernels:bounds_check", "//tensorflow/core/kernels:conv_2d_hdrs", "//tensorflow/core/kernels:conv_ops_gpu_hdrs", "//tensorflow/core/kernels:gpu_util_hdrs", diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc index 9275d5a22b2697c37414fba2f6176f708808e60c..88306094ab9947c9c78b03c0013f6afc88316803 100644 --- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc +++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc @@ -298,6 +298,17 @@ void LaunchFusedConv2DBiasActivationOp:: constexpr int rank = is_int8x4 ? 5 : 4; constexpr int vect = is_int8x4 ? 4 : 1; + if (is_int8x4) { + int cc_major, cc_minor; + stream->parent()->GetDeviceDescription().cuda_compute_capability(&cc_major, + &cc_minor); + OP_REQUIRES( + ctx, cc_major >= 6 && cc_minor >= 1, + errors::Unimplemented( + "FusedConv2DBiasActivation for int8 is only supported on GPUs with " + "compute capability 6.1 or later.")); + } + const int batch_size = GetTensorDim(conv_input_param, data_format, 'N'); int conv_input_rows = GetTensorDim(conv_input_param, data_format, 'H'); int conv_input_cols = GetTensorDim(conv_input_param, data_format, 'W'); @@ -434,11 +445,11 @@ void LaunchFusedConv2DBiasActivationOp:: .set_zero_padding_width(padding_cols / 2); Tensor maybe_transformed_filter; - const Tensor* filter; - if (is_int8x4) { - // We have already checked filter is OIHW_VECT_I in the constructor. - filter = &filter_param; - } else if (filter_format == FORMAT_HWIO) { + const Tensor* filter = &filter_param; + // For qint8, we have already checked filter is OIHW_VECT_I in the + // constructor, but we need to test for is_int8x4 so the if block doesn't + // generate code for qint8. + if (!is_int8x4 && filter_format == FORMAT_HWIO) { // Shuffle filter tensor from HWIO to OIHW: OP_REQUIRES_OK(ctx, ctx->allocate_temp( DataTypeToEnum::value, @@ -493,42 +504,37 @@ void LaunchFusedConv2DBiasActivationOp:: dnn::AlgorithmConfig algorithm_config; if (cudnn_use_autotune && !AutoTuneConvBiasActivation::GetInstance()->Find( fused_conv_parameters, &algorithm_config)) { - std::vector algorithms; + std::vector algorithms; CHECK(stream->parent()->GetConvolveAlgorithms( fused_conv_parameters.ShouldIncludeWinogradNonfusedAlgo(), &algorithms)); dnn::ProfileResult best_result; dnn::ProfileResult best_result_no_scratch; - // TODO(benbarsdell): Ideally this should not attempt using tensor op math - // if it's not enabled. - for (bool use_tensor_ops : {false, true}) { - for (auto algo_index : algorithms) { - // TODO(zhengxq): profile each algorithm multiple times to better - // accuracy. - dnn::AlgorithmDesc profile_algorithm(algo_index, use_tensor_ops); - CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); - dnn::ProfileResult profile_result; - bool cudnn_launch_status = - stream - ->ThenFusedConvolveWithAlgorithm( - conv_input_desc, conv_input_ptr, conv_input_scale, - filter_desc, filter_ptr, conv_desc, side_input_ptr, - side_input_scale, bias_desc, bias_ptr, - dnn::ActivationMode::kRelu, output_desc, &output_ptr, - &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm), - &profile_result) - .ok(); - if (cudnn_launch_status) { - if (profile_result.is_valid()) { - if (profile_result.elapsed_time_in_ms() < - best_result.elapsed_time_in_ms()) { - best_result = profile_result; - } - if (scratch_allocator.TotalByteSize() == 0 && - profile_result.elapsed_time_in_ms() < - best_result_no_scratch.elapsed_time_in_ms()) { - best_result_no_scratch = profile_result; - } + for (auto profile_algorithm : algorithms) { + // TODO(zhengxq): profile each algorithm multiple times to better + // accuracy. + CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx); + dnn::ProfileResult profile_result; + bool cudnn_launch_status = + stream + ->ThenFusedConvolveWithAlgorithm( + conv_input_desc, conv_input_ptr, conv_input_scale, + filter_desc, filter_ptr, conv_desc, side_input_ptr, + side_input_scale, bias_desc, bias_ptr, + dnn::ActivationMode::kRelu, output_desc, &output_ptr, + &scratch_allocator, dnn::AlgorithmConfig(profile_algorithm), + &profile_result) + .ok(); + if (cudnn_launch_status) { + if (profile_result.is_valid()) { + if (profile_result.elapsed_time_in_ms() < + best_result.elapsed_time_in_ms()) { + best_result = profile_result; + } + if (scratch_allocator.TotalByteSize() == 0 && + profile_result.elapsed_time_in_ms() < + best_result_no_scratch.elapsed_time_in_ms()) { + best_result_no_scratch = profile_result; } } } diff --git a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py index 3b8f7d6ed760647c4c61ce5ea60be1d7d17ddfa0..2a18f3eeecc7e0e69c54b219886a263136f01b2c 100644 --- a/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py +++ b/tensorflow/contrib/fused_conv/python/ops/fused_conv2d_bias_activation_op_test.py @@ -159,9 +159,12 @@ class FusedConv2DBiasActivationTest(test.TestCase): def _DtypesToTest(self, use_gpu): return [dtypes.float32] + def _FilterFormatsToTest(self, use_gpu): + return ["HWIO", "OIHW"] + def _SetupValuesForDevice(self, tensor_in_sizes, filter_in_sizes, bias, strides, padding, activation_mode, data_format, - dtype): + filter_format, dtype): """Verifies the output values of the convolution function. Args: @@ -174,6 +177,7 @@ class FusedConv2DBiasActivationTest(test.TestCase): padding: Padding type. activation_mode: Activation mode. data_format: Format of the data tensors. + filter_format: Filter format to use for the fused convolution. dtype: Data type for inputs and outputs. Returns: Symbolic tensor value and reference value that can be used to @@ -192,6 +196,9 @@ class FusedConv2DBiasActivationTest(test.TestCase): with self.test_session(use_gpu=True): t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=dtype) t2 = constant_op.constant(x2, shape=filter_in_sizes, dtype=dtype) + fused_t2 = t2 + if filter_format == "OIHW": + fused_t2 = HwioToOihw(t2) t3 = constant_op.constant(x3, shape=[bias_size], dtype=dtype) strides = [1] + strides + [1] if data_format == "NCHW": @@ -199,11 +206,12 @@ class FusedConv2DBiasActivationTest(test.TestCase): strides = test_util.NHWCToNCHW(strides) output = fused_conv2d_bias_activation_op.fused_conv2d_bias_activation( t1, - t2, + fused_t2, t3, strides=strides, padding=padding, data_format=data_format, + filter_format=filter_format, activation_mode=activation_mode) ref_conv_output = nn_ops.conv2d( t1, t2, strides=strides, padding=padding, data_format=data_format) @@ -268,9 +276,10 @@ class FusedConv2DBiasActivationTest(test.TestCase): ref_tensors = [] for (data_format, use_gpu) in GetTestConfigs(): for dtype in self._DtypesToTest(use_gpu): - result, expected = self._SetupValuesForDevice( - tensor_in_sizes, filter_in_sizes, bias, strides, padding, "Relu", - data_format, dtype) + for filter_format in self._FilterFormatsToTest(use_gpu): + result, expected = self._SetupValuesForDevice( + tensor_in_sizes, filter_in_sizes, bias, strides, padding, "Relu", + data_format, filter_format, dtype) tensors.append(result) ref_tensors.append(expected) with self.test_session() as sess: @@ -607,6 +616,10 @@ def NchwToNchwVectC(in_tensor): return array_ops.transpose(t, [0, 1, 3, 4, 2]) +def HwioToOihw(in_tensor): + return array_ops.transpose(in_tensor, [3, 2, 0, 1]) + + def SimulateFusedConv2dBiasActivationInt8(conv_input_scale, conv_input, kernel, padding, strides, side_input_scale, side_input, biases): diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 54dbb11b6ebcfac8f8d687863f85a8d890fd4fb3..1418c87023af0dbff890f46e10f0140d5b89e4b7 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -14,6 +14,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":estimator", ":eval", ":features", ":losses", @@ -86,6 +87,17 @@ py_library( ], ) +py_library( + name = "estimator", + srcs = ["python/estimator/__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":gan_estimator", + ":head", + "//tensorflow/python:util", + ], +) + py_library( name = "losses", srcs = ["python/losses/__init__.py"], @@ -190,6 +202,7 @@ py_library( "//tensorflow/python:embedding_ops", "//tensorflow/python:math_ops", "//tensorflow/python:tensor_util", + "//tensorflow/python:util", "//tensorflow/python:variable_scope", ], ) @@ -222,6 +235,7 @@ py_library( "//tensorflow/python:nn", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", + "//tensorflow/python:util", "//tensorflow/python:variable_scope", ], ) @@ -255,7 +269,10 @@ py_library( "python/features/python/clip_weights_impl.py", ], srcs_version = "PY2AND3", - deps = ["//tensorflow/contrib/opt:opt_py"], + deps = [ + "//tensorflow/contrib/opt:opt_py", + "//tensorflow/python:util", + ], ) py_test( @@ -369,6 +386,90 @@ py_test( ], ) +py_library( + name = "head", + srcs = [ + "python/estimator/python/head.py", + "python/estimator/python/head_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":namedtuples", + ":train", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:model_fn", + ], +) + +py_test( + name = "head_test", + srcs = ["python/estimator/python/head_test.py"], + shard_count = 1, + srcs_version = "PY2AND3", + deps = [ + ":head", + ":namedtuples", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:model_fn", + ], +) + +py_library( + name = "gan_estimator", + srcs = [ + "python/estimator/python/gan_estimator.py", + "python/estimator/python/gan_estimator_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":head", + ":namedtuples", + ":summaries", + ":train", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", + ], +) + +py_test( + name = "gan_estimator_test", + srcs = ["python/estimator/python/gan_estimator_test.py"], + shard_count = 1, + srcs_version = "PY2AND3", + tags = ["notsan"], + deps = [ + ":gan_estimator", + ":namedtuples", + ":tuple_losses", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/learn", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:numpy_io", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/gan/README.md b/tensorflow/contrib/gan/README.md index 10458a2458384c8f589183003256db24d69742d7..3ab84780705b35567169bd76fd3485ad355ba9d8 100644 --- a/tensorflow/contrib/gan/README.md +++ b/tensorflow/contrib/gan/README.md @@ -47,13 +47,14 @@ such as the Wasserstein loss, gradient penalty, mutual information penalty, etc * [evaluation](https://www.tensorflow.org/code/tensorflow/contrib/gan/python/eval/python/): Use `Inception Score` or `Frechet Distance` with a pretrained Inception -network to evaluate your unconditional generative model. You can also also use +network to evaluate your unconditional generative model. You can also use your own pretrained classifier for more specific performance numbers, or use other methods for evaluating conditional generative models. -* [examples](https://github.com/tensorflow/models/tree/master/gan/): +* examples (coming soon): See examples of how to use TFGAN to make GAN training easier, or use the more complicated examples to jumpstart your -own project. +own project. These include unconditional and conditional GANs, InfoGANs, +adversarial losses on existing networks, and image-to-image translation. ## Training a GAN model diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py index 67eee771d040995449329dde0b0cb990793176ec..dff361fdc42708ea69999c2def4721f9d49fcf14 100644 --- a/tensorflow/contrib/gan/__init__.py +++ b/tensorflow/contrib/gan/__init__.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function # Collapse TFGAN into a tiered namespace. +from tensorflow.contrib.gan.python import estimator from tensorflow.contrib.gan.python import eval # pylint:disable=redefined-builtin from tensorflow.contrib.gan.python import features from tensorflow.contrib.gan.python import losses @@ -33,6 +34,7 @@ from tensorflow.contrib.gan.python.train import * from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ + 'estimator', 'eval', 'features', 'losses', diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8c4a18228039cb4f2c06e0333f4b8408f1f631e9 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TFGAN grouped API. Please see README.md for details and usage.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Collapse `estimator` into a single namespace. +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.gan.python.estimator.python import gan_estimator +from tensorflow.contrib.gan.python.estimator.python import head + +from tensorflow.contrib.gan.python.estimator.python.gan_estimator import * +from tensorflow.contrib.gan.python.estimator.python.head import * +# pylint: enable=unused-import,wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'gan_estimator', + 'head', +] + gan_estimator.__all__ + head.__all__ +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..bc0e48540915d1de7e249f8640193366f37baa92 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""`tf.Learn` components for `GANEstimator`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.estimator.python.gan_estimator_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = gan_estimator_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..e89993991a389d68254a95aded2d771f4c2627be --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -0,0 +1,273 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A TFGAN-backed GAN Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import enum + +from tensorflow.contrib.framework.python.ops import variables as variable_lib +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python import train as tfgan_train +from tensorflow.contrib.gan.python.estimator.python import head as head_lib +from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import variable_scope + + +__all__ = [ + 'GANEstimator', + 'SummaryType' +] + + +class SummaryType(enum.IntEnum): + NONE = 0 + VARIABLES = 1 + IMAGES = 2 + IMAGE_COMPARISON = 3 + + +_summary_type_map = { + SummaryType.VARIABLES: tfgan_summaries.add_gan_model_summaries, + SummaryType.IMAGES: tfgan_summaries.add_gan_model_image_summaries, + SummaryType.IMAGE_COMPARISON: tfgan_summaries.add_image_comparison_summaries, # pylint:disable=line-too-long +} + + +# TODO(joelshor): For now, this only supports 1:1 generator:discriminator +# training sequentially. Find a nice way to expose options to the user without +# exposing internals. +class GANEstimator(estimator.Estimator): + """An estimator for Generative Adversarial Networks (GANs). + + This Estimator is backed by TFGAN. + + Example: + + ```python + import tensorflow as tf + tfgan = tf.contrib.gan + + # See TFGAN's `train.py` for a description of the generator and + # discriminator API. + def generator_fn(generator_inputs): + ... + return generated_data + + def discriminator_fn(data, conditioning): + ... + return logits + + # Create GAN estimator. + gan_estimator = estimator.GANEstimator( + model_dir, + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_loss_fn=tfgan.losses.wasserstein_generator_loss, + discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, + generator_optimizer=tf.train.AdamOptimizier(0.1, 0.5), + discriminator_optimizer=tf.train.AdamOptimizier(0.1, 0.5)) + + # Train estimator. + gan_estimator.train(train_input_fn, steps) + + # Evaluate resulting estimator. + gan_estimator.evaluate(eval_input_fn) + + # Generate samples from generator. + predictions = np.array([ + x for x in gan_estimator.predict(predict_input_fn)]) + ``` + """ + + def __init__(self, + model_dir=None, + generator_fn=None, + discriminator_fn=None, + generator_loss_fn=None, + discriminator_loss_fn=None, + generator_optimizer=None, + discriminator_optimizer=None, + add_summaries=None, + use_loss_summaries=True, + config=None): + """Initializes a GANEstimator instance. + + Args: + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + generator_fn: A python function that takes a Tensor, Tensor list, or + Tensor dictionary as inputs and returns the outputs of the GAN + generator. See `TFGAN` for more details and examples. + discriminator_fn: A python function that takes the output of + `generator_fn` or real data in the GAN setup, and `generator_inputs`. + Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details + and examples. + generator_loss_fn: The loss function on the generator. Takes a `GANModel` + tuple. + discriminator_loss_fn: The loss function on the discriminator. Takes a + `GANModel` tuple. + generator_optimizer: The optimizer for generator updates, or a function + that takes no arguments and returns an optimizer. This function will + be called when the default graph is the `GANEstimator`'s graph, so + utilities like `tf.contrib.framework.get_or_create_global_step` will + work. + discriminator_optimizer: Same as `generator_optimizer`, but for the + discriminator updates. + add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. + use_loss_summaries: If `True`, add loss summaries. If `False`, does not. + If `None`, uses defaults. + config: `RunConfig` object to configure the runtime settings. + """ + # TODO(joelshor): Explicitly validate inputs. + + def _model_fn(features, labels, mode): + gopt = (generator_optimizer() if callable(generator_optimizer) else + generator_optimizer) + dopt = (discriminator_optimizer() if callable(discriminator_optimizer) + else discriminator_optimizer) + gan_head = head_lib.gan_head( + generator_loss_fn, discriminator_loss_fn, gopt, dopt, + use_loss_summaries) + return _gan_model_fn( + features, labels, mode, generator_fn, discriminator_fn, gan_head, + add_summaries) + + super(GANEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) + + +def _use_check_shapes(real_data): + """Determines whether TFGAN should check Tensor shapes.""" + return isinstance(real_data, ops.Tensor) + + +def _gan_model_fn( + features, + labels, + mode, + generator_fn, + discriminator_fn, + head, + add_summaries=None, + generator_scope_name='Generator'): + """The `model_fn` for the GAN estimator. + + We make the following convention: + features -> TFGAN's `generator_inputs` + labels -> TFGAN's `real_data` + + Args: + features: A dictionary to feed to generator. In the unconditional case, + this might be just `noise`. In the conditional GAN case, this + might be the generator's conditioning. The `generator_fn` determines + what the required keys are. + labels: Real data. Can be any structure, as long as `discriminator_fn` + can accept it for the first argument. + mode: Defines whether this is training, evaluation or prediction. + See `ModeKeys`. + generator_fn: A python lambda that takes `generator_inputs` as inputs and + returns the outputs of the GAN generator. + discriminator_fn: A python lambda that takes `real_data`/`generated data` + and `generator_inputs`. Outputs a Tensor in the range [-inf, inf]. + head: A `Head` instance suitable for GANs. + add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. + generator_scope_name: The name of the generator scope. We need this to be + the same for GANModels produced by TFGAN's `train.gan_model` and the + manually constructed ones for predictions. + + Returns: + `ModelFnOps` + + Raises: + ValueError: If `labels` isn't `None` during prediction. + """ + real_data = labels + generator_inputs = features + + if mode == model_fn_lib.ModeKeys.TRAIN: + gan_model = _make_train_gan_model( + generator_fn, discriminator_fn, real_data, generator_inputs, + generator_scope_name, add_summaries) + elif mode == model_fn_lib.ModeKeys.EVAL: + gan_model = _make_eval_gan_model( + generator_fn, discriminator_fn, real_data, generator_inputs, + generator_scope_name, add_summaries) + else: + if real_data is not None: + raise ValueError('`labels` must be `None` when mode is `predict`. ' + 'Instead, found %s' % real_data) + gan_model = _make_prediction_gan_model( + generator_inputs, generator_fn, generator_scope_name) + + return head.create_estimator_spec( + features=None, + mode=mode, + logits=gan_model, + labels=None) + + +def _make_train_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries): + """Make a `GANModel` for training.""" + gan_model = tfgan_train.gan_model( + generator_fn, + discriminator_fn, + real_data, + generator_inputs, + generator_scope=generator_scope, + check_shapes=_use_check_shapes(real_data)) + if add_summaries: + if not isinstance(add_summaries, (tuple, list)): + add_summaries = [add_summaries] + with ops.name_scope(None): + for summary_type in add_summaries: + _summary_type_map[summary_type](gan_model) + + return gan_model + + +def _make_eval_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries): + """Make a `GANModel` for evaluation.""" + return _make_train_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries) + + +def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope): + """Make a `GANModel` from just the generator.""" + with variable_scope.variable_scope(generator_scope) as gen_scope: + generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs) # pylint:disable=protected-access + generated_data = generator_fn(generator_inputs) + generator_variables = variable_lib.get_trainable_variables(gen_scope) + + return tfgan_tuples.GANModel( + generator_inputs, + generated_data, + generator_variables, + gen_scope, + generator_fn, + real_data=None, + discriminator_real_outputs=None, + discriminator_gen_outputs=None, + discriminator_variables=None, + discriminator_scope=None, + discriminator_fn=None) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..1bfdce9ee94d4d05d5186cd999361662bc0e3f85 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -0,0 +1,327 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for TFGAN's estimator.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib import layers +from tensorflow.contrib.gan.python import namedtuples +from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl as estimator +from tensorflow.contrib.gan.python.losses.python import tuple_losses as losses +from tensorflow.contrib.learn.python.learn.learn_io import graph_io +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import input as input_lib +from tensorflow.python.training import learning_rate_decay +from tensorflow.python.training import monitored_session +from tensorflow.python.training import training +from tensorflow.python.training import training_util + + +def generator_fn(noise_dict): + noise = noise_dict['x'] + return layers.fully_connected(noise, noise.shape[1].value) + + +def discriminator_fn(data, _): + return layers.fully_connected(data, 1) + + +def mock_head(testcase, expected_generator_inputs, expected_real_data, + generator_scope_name): + """Returns a mock head that validates logits values and variable names.""" + discriminator_scope_name = 'Discriminator' # comes from TFGAN defaults + generator_var_names = set([ + '%s/fully_connected/weights:0' % generator_scope_name, + '%s/fully_connected/biases:0' % generator_scope_name]) + discriminator_var_names = set([ + '%s/fully_connected/weights:0' % discriminator_scope_name, + '%s/fully_connected/biases:0' % discriminator_scope_name]) + + def _create_estimator_spec(features, mode, logits, labels): + gan_model = logits # renaming for clarity + is_predict = mode == model_fn_lib.ModeKeys.PREDICT + testcase.assertIsNone(features) + testcase.assertIsNone(labels) + testcase.assertIsInstance(gan_model, namedtuples.GANModel) + + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + expected_var_names = (generator_var_names if is_predict else + generator_var_names | discriminator_var_names) + testcase.assertItemsEqual(expected_var_names, + [var.name for var in trainable_vars]) + + assertions = [] + def _or_none(x): + return None if is_predict else x + testcase.assertEqual(expected_generator_inputs, gan_model.generator_inputs) + # TODO(joelshor): Add check on `generated_data`. + testcase.assertItemsEqual( + generator_var_names, + set([x.name for x in gan_model.generator_variables])) + testcase.assertEqual(generator_scope_name, gan_model.generator_scope.name) + testcase.assertEqual(generator_fn, gan_model.generator_fn) + testcase.assertEqual(_or_none(expected_real_data), gan_model.real_data) + # TODO(joelshor): Add check on `discriminator_real_outputs`. + # TODO(joelshor): Add check on `discriminator_gen_outputs`. + if is_predict: + testcase.assertIsNone(gan_model.discriminator_scope) + else: + testcase.assertEqual(discriminator_scope_name, + gan_model.discriminator_scope.name) + testcase.assertEqual(_or_none(discriminator_fn), gan_model.discriminator_fn) + + with ops.control_dependencies(assertions): + if mode == model_fn_lib.ModeKeys.TRAIN: + return model_fn_lib.EstimatorSpec( + mode=mode, loss=array_ops.zeros([]), + train_op=control_flow_ops.no_op(), training_hooks=[]) + elif mode == model_fn_lib.ModeKeys.EVAL: + return model_fn_lib.EstimatorSpec( + mode=mode, predictions=gan_model.generated_data, + loss=array_ops.zeros([])) + elif mode == model_fn_lib.ModeKeys.PREDICT: + return model_fn_lib.EstimatorSpec( + mode=mode, predictions=gan_model.generated_data) + else: + testcase.fail('Invalid mode: {}'.format(mode)) + + head = test.mock.NonCallableMagicMock(spec=head_lib._Head) + head.create_estimator_spec = test.mock.MagicMock( + wraps=_create_estimator_spec) + + return head + + +class GANModelFnTest(test.TestCase): + """Tests that _gan_model_fn passes expected logits to mock head.""" + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_logits_helper(self, mode): + """Tests that the expected logits are passed to mock head.""" + with ops.Graph().as_default(): + training_util.get_or_create_global_step() + generator_inputs = {'x': array_ops.zeros([5, 4])} + real_data = (None if mode == model_fn_lib.ModeKeys.PREDICT else + array_ops.zeros([5, 4])) + generator_scope_name = 'generator' + head = mock_head(self, + expected_generator_inputs=generator_inputs, + expected_real_data=real_data, + generator_scope_name=generator_scope_name) + estimator_spec = estimator._gan_model_fn( + features=generator_inputs, + labels=real_data, + mode=mode, + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_scope_name=generator_scope_name, + head=head) + with monitored_session.MonitoredTrainingSession( + checkpoint_dir=self._model_dir) as sess: + if mode == model_fn_lib.ModeKeys.TRAIN: + sess.run(estimator_spec.train_op) + elif mode == model_fn_lib.ModeKeys.EVAL: + sess.run(estimator_spec.loss) + elif mode == model_fn_lib.ModeKeys.PREDICT: + sess.run(estimator_spec.predictions) + else: + self.fail('Invalid mode: {}'.format(mode)) + + def test_logits_predict(self): + self._test_logits_helper(model_fn_lib.ModeKeys.PREDICT) + + def test_logits_eval(self): + self._test_logits_helper(model_fn_lib.ModeKeys.EVAL) + + def test_logits_train(self): + self._test_logits_helper(model_fn_lib.ModeKeys.TRAIN) + + +# TODO(joelshor): Add pandas test. +class GANEstimatorIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow( + self, train_input_fn, eval_input_fn, predict_input_fn, prediction_size, + lr_decay=False): + def make_opt(): + gstep = training_util.get_or_create_global_step() + lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) + return training.GradientDescentOptimizer(lr) + + gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) + dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) + est = estimator.GANEstimator( + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + generator_optimizer=gopt, + discriminator_optimizer=dopt, + model_dir=self._model_dir) + + # TRAIN + num_steps = 10 + est.train(train_input_fn, steps=num_steps) + + # EVALUTE + scores = est.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + # PREDICT + predictions = np.array([x for x in est.predict(predict_input_fn)]) + + self.assertAllEqual(prediction_size, predictions.shape) + + def test_numpy_input_fn(self): + """Tests complete flow with numpy_input_fn.""" + input_dim = 4 + batch_size = 5 + data = np.zeros([batch_size, input_dim]) + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + batch_size=batch_size, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + prediction_size=[batch_size, input_dim]) + + def test_numpy_input_fn_lrdecay(self): + """Tests complete flow with numpy_input_fn.""" + input_dim = 4 + batch_size = 5 + data = np.zeros([batch_size, input_dim]) + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + batch_size=batch_size, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + prediction_size=[batch_size, input_dim], + lr_decay=True) + + def test_input_fn_from_parse_example(self): + """Tests complete flow with input_fn constructed from parse_example.""" + input_dim = 4 + batch_size = 6 + data = np.zeros([batch_size, input_dim]) + + serialized_examples = [] + for datum in data: + example = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'x': feature_pb2.Feature( + float_list=feature_pb2.FloatList(value=datum)), + 'y': feature_pb2.Feature( + float_list=feature_pb2.FloatList(value=datum)), + })) + serialized_examples.append(example.SerializeToString()) + + feature_spec = { + 'x': parsing_ops.FixedLenFeature([input_dim], dtypes.float32), + 'y': parsing_ops.FixedLenFeature([input_dim], dtypes.float32), + } + def _train_input_fn(): + feature_map = parsing_ops.parse_example( + serialized_examples, feature_spec) + _, features = graph_io.queue_parsed_features(feature_map) + labels = features.pop('y') + return features, labels + def _eval_input_fn(): + feature_map = parsing_ops.parse_example( + input_lib.limit_epochs(serialized_examples, num_epochs=1), + feature_spec) + _, features = graph_io.queue_parsed_features(feature_map) + labels = features.pop('y') + return features, labels + def _predict_input_fn(): + feature_map = parsing_ops.parse_example( + input_lib.limit_epochs(serialized_examples, num_epochs=1), + feature_spec) + _, features = graph_io.queue_parsed_features(feature_map) + features.pop('y') + return features, None + + self._test_complete_flow( + train_input_fn=_train_input_fn, + eval_input_fn=_eval_input_fn, + predict_input_fn=_predict_input_fn, + prediction_size=[batch_size, input_dim]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/estimator/python/head.py b/tensorflow/contrib/gan/python/estimator/python/head.py new file mode 100644 index 0000000000000000000000000000000000000000..3225d6f41a1c17bfc8c57494dd683aaab45b10f3 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/head.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""`tf.Learn` components for `GANEstimator`'s loss.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.estimator.python import head_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.estimator.python.head_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = head_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..204c646e194319c0e63599da0b2a4909ef270ef3 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -0,0 +1,206 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A TFGAN-backed GAN Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python import train as tfgan_train +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.canned import head +from tensorflow.python.framework import ops + +__all__ = [ + 'GANHead', + 'gan_head', +] + + +def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, + discriminator_optimizer, use_loss_summaries=True, + get_hooks_fn=tfgan_train.get_sequential_train_hooks(), + name=None): + """Creates a `GANHead`. + + Args: + generator_loss_fn: A TFGAN loss function for the generator. Takes a + `GANModel` and returns a scalar. + discriminator_loss_fn: Same as `generator_loss_fn`, but for the + discriminator. + generator_optimizer: The optimizer for generator updates. + discriminator_optimizer: Same as `generator_optimizer`, but for the + discriminator updates. + use_loss_summaries: If `True`, add loss summaries. If `False`, does not. + If `None`, uses defaults. + get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list + of hooks. + name: name of the head. If provided, summary and metrics keys will be + suffixed by `"/" + name`. + + Returns: + An instance of `GANHead`. + """ + return GANHead(generator_loss_fn=generator_loss_fn, + discriminator_loss_fn=discriminator_loss_fn, + generator_optimizer=generator_optimizer, + discriminator_optimizer=discriminator_optimizer, + use_loss_summaries=use_loss_summaries, + get_hooks_fn=get_hooks_fn, + name=name) + + +class GANHead(head._Head): # pylint: disable=protected-access + """`Head` for a GAN.""" + + def __init__(self, generator_loss_fn, discriminator_loss_fn, + generator_optimizer, discriminator_optimizer, + use_loss_summaries=True, + get_hooks_fn=tfgan_train.get_sequential_train_hooks(), + name=None): + """`Head` for GAN training. + + Args: + generator_loss_fn: A TFGAN loss function for the generator. Takes a + `GANModel` and returns a scalar. + discriminator_loss_fn: Same as `generator_loss_fn`, but for the + discriminator. + generator_optimizer: The optimizer for generator updates. + discriminator_optimizer: Same as `generator_optimizer`, but for the + discriminator updates. + use_loss_summaries: If `True`, add loss summaries. If `False`, does not. + If `None`, uses defaults. + get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list + of hooks. + name: name of the head. If provided, summary and metrics keys will be + suffixed by `"/" + name`. + """ + # TODO(joelshor): Validate inputs. + + if use_loss_summaries in [True, False]: + generator_loss_fn = functools.partial( + generator_loss_fn, add_summaries=use_loss_summaries) + discriminator_loss_fn = functools.partial( + discriminator_loss_fn, add_summaries=use_loss_summaries) + self._generator_loss_fn = generator_loss_fn + self._discriminator_loss_fn = discriminator_loss_fn + self._generator_optimizer = generator_optimizer + self._discriminator_optimizer = discriminator_optimizer + self._get_hooks_fn = get_hooks_fn + + @property + def name(self): + return self._name + + @property + def logits_dimension(self): + return None + + def create_loss(self, features, mode, logits, labels): + """Returns a GANLoss tuple from the provided GANModel. + + See `Head` for more details. + + Args: + features: Input `dict` of `Tensor` objects. Unused. + mode: Estimator's `ModeKeys`. + logits: A GANModel tuple. + labels: Must be `None`. + + Returns: + A GANLoss tuple. + + """ + _validate_logits_and_labels(logits, labels) + del mode, labels, features # unused for this head. + gan_model = logits # rename variable for clarity + return tfgan_tuples.GANLoss( + generator_loss=self._generator_loss_fn(gan_model), + discriminator_loss=self._discriminator_loss_fn(gan_model)) + + def create_estimator_spec( + self, features, mode, logits, labels=None, + train_op_fn=tfgan_train.gan_train_ops): + """Returns `EstimatorSpec` that a model_fn can return. + + See `Head` for more details. + + Args: + features: Must be `None`. + mode: Estimator's `ModeKeys`. + logits: A GANModel tuple. + labels: Must be `None`. + train_op_fn: Function that takes a GANModel, GANLoss, generator optimizer, + and discriminator optimizer, and returns a `GANTrainOps` tuple. For + example, this function can come from TFGAN's `train.py` library, or can + be custom. + + Returns: + `EstimatorSpec`. + + Raises: + ValueError: If `features` isn't `None`. + ValueError: If `train_op_fn` isn't provided in train mode. + """ + _validate_logits_and_labels(logits, labels) + if features is not None: + raise ValueError('`features` should be `None`. Instead, found: %s' % + features) + gan_model = logits # rename variable for clarity + with ops.name_scope('GANHead'): + if mode == model_fn_lib.ModeKeys.PREDICT: + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.PREDICT, + predictions=gan_model.generated_data) + elif mode == model_fn_lib.ModeKeys.EVAL: + gan_loss = self.create_loss( + features=None, mode=mode, logits=gan_model, labels=None) + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.EVAL, + predictions=gan_model.generated_data, + loss=scalar_loss, + # TODO(joelshor): Add metrics. If head name provided, append it to + # metric keys. + eval_metric_ops={}) + elif mode == model_fn_lib.ModeKeys.TRAIN: + if train_op_fn is None: + raise ValueError('train_op_fn can not be None.') + gan_loss = self.create_loss(None, mode, gan_model, None) + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + train_ops = train_op_fn(gan_model, gan_loss, self._generator_optimizer, + self._discriminator_optimizer) + training_hooks = self._get_hooks_fn(train_ops) + return model_fn_lib.EstimatorSpec( + loss=scalar_loss, + mode=model_fn_lib.ModeKeys.TRAIN, + train_op=train_ops.global_step_inc_op, + training_hooks=training_hooks) + else: + raise ValueError('Mode not recognized: %s' % mode) + + +def _validate_logits_and_labels(logits, labels): + if labels is not None: + raise ValueError('`GANHead`\'s `create_estimator_spec` input `labels` must ' + 'be `None`. Instead, found: %s' % labels) + + if not isinstance(logits, tfgan_tuples.GANModel): + raise ValueError('`GANHead`\'s `create_estimator_spec` input `logits` must ' + 'be an instnace of a `GANModel`. Instead, found: %s' % + logits) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py new file mode 100644 index 0000000000000000000000000000000000000000..8168f005cd1105886390a2384a936663c83fa5f5 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -0,0 +1,85 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for TFGAN's head.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python.estimator.python import head + +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test +from tensorflow.python.training import training + + +def dummy_loss(gan_model, add_summaries=True): # pylint:disable=unused-argument + return math_ops.reduce_sum(gan_model.discriminator_real_outputs - + gan_model.discriminator_gen_outputs) + + +def get_gan_model(): + # TODO(joelshor): Find a better way of creating a variable scope. + with variable_scope.variable_scope('generator') as gen_scope: + gen_var = variable_scope.get_variable('dummy_var', initializer=0.0) + with variable_scope.variable_scope('discriminator') as dis_scope: + dis_var = variable_scope.get_variable('dummy_var', initializer=0.0) + return tfgan_tuples.GANModel( + generator_inputs=None, + generated_data=array_ops.ones([3, 4]), + generator_variables=[gen_var], + generator_scope=gen_scope, + generator_fn=None, + real_data=None, + discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var, + discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var, + discriminator_variables=[dis_var], + discriminator_scope=dis_scope, + discriminator_fn=None) + + +class GANHeadTest(test.TestCase): + + def setUp(self): + super(GANHeadTest, self).setUp() + self.gan_head = head.gan_head( + generator_loss_fn=dummy_loss, + discriminator_loss_fn=dummy_loss, + generator_optimizer=training.GradientDescentOptimizer(1.0), + discriminator_optimizer=training.GradientDescentOptimizer(1.0)) + self.assertTrue(isinstance(self.gan_head, head.GANHead)) + + def _test_modes_helper(self, mode): + self.gan_head.create_estimator_spec( + features=None, + mode=mode, + logits=get_gan_model()) + + def test_modes_predict(self): + self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT) + + def test_modes_eval(self): + self._test_modes_helper(model_fn_lib.ModeKeys.EVAL) + + def test_modes_train(self): + self._test_modes_helper(model_fn_lib.ModeKeys.TRAIN) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index 151fecdca0ce4045b95c48d9db051d9f0903c96c..bb65f05b5a17e9a872e41d1dcb05aeb3cd6f6f40 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -16,6 +16,11 @@ These methods come from https://arxiv.org/abs/1606.03498 and https://arxiv.org/abs/1706.08500. + +NOTE: This implementation uses the same weights as in +https://github.com/openai/improved-gan/blob/master/inception_score/model.py, +but is more numerically stable and is an unbiased estimator of the true +Inception score even when splitting the inputs into batches. """ from __future__ import absolute_import @@ -54,17 +59,16 @@ __all__ = [ 'classifier_score', 'frechet_inception_distance', 'frechet_classifier_distance', + 'INCEPTION_DEFAULT_IMAGE_SIZE', ] -INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v3_2017_09_13.tar.gz' -INCEPTION_FROZEN_GRAPH = 'frozen_inception_v3.pb' -INCEPTION_V3_INPUT = 'inputs' -INCEPTION_V3_OUTPUT = 'InceptionV3/Logits/SpatialSqueeze:0' -INCEPTION_V3_FINAL_POOL = 'InceptionV3/Logits/AvgPool_1a_8x8/AvgPool:0' -_INCEPTION_V3_NUM_CLASSES = 1001 -_INCEPTION_V3_FINAL_POOL_SIZE = 2048 -INCEPTION_V3_DEFAULT_IMG_SIZE = 299 +INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz' +INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb' +INCEPTION_INPUT = 'Mul:0' +INCEPTION_OUTPUT = 'logits:0' +INCEPTION_FINAL_POOL = 'pool_3:0' +INCEPTION_DEFAULT_IMAGE_SIZE = 299 def _validate_images(images, image_size): @@ -75,12 +79,13 @@ def _validate_images(images, image_size): return images -def _matrix_square_root(mat, eps=1e-10): - """Compute symmetric square root of matrix. +def _symmetric_matrix_square_root(mat, eps=1e-10): + """Compute square root of a symmetric matrix. + + Note that this is different from an elementwise square root. We want to + compute M' where M' = sqrt(mat) such that M' * M' = mat. - Equivalent to matrix square root when matrix is invertible; note that this is - different from an elementwise square root. We want to compute M' where M' = - sqrt(mat) such that M' * M' = mat. + Also note that this method **only** works for symmetric matrices. Args: mat: Matrix to take the square root of. @@ -101,46 +106,37 @@ def _matrix_square_root(mat, eps=1e-10): math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True) -# Convenience preprocessing function, with fixed defaults. -# NOTE: Floating-point inputs are expected to be in [0, 1]. -# Copied from /tensorflow_models/slim/preprocessing/inception_preprocessing.py. def preprocess_image( - image, height=INCEPTION_V3_DEFAULT_IMG_SIZE, - width=INCEPTION_V3_DEFAULT_IMG_SIZE, central_fraction=0.875, scope=None): - """Prepare one image for evaluation. + images, height=INCEPTION_DEFAULT_IMAGE_SIZE, + width=INCEPTION_DEFAULT_IMAGE_SIZE, scope=None): + """Prepare a batch of images for evaluation. - If height and width are specified it would output an image with that size by - applying resize_bilinear. + This is the preprocessing portion of the graph from + http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz. - If central_fraction is specified it would crop the central fraction of the - input image. + Note that it expects Tensors in [0, 255]. This function maps pixel values to + [-1, 1] and resizes to match the InceptionV1 network. Args: - image: 3-D Tensor of image. If dtype is tf.float32 then the range should be - [0, 1], otherwise it would converted to tf.float32 assuming that the range - is [0, MAX], where MAX is largest positive representable number for - int(8/16/32) data type (see `tf.image.convert_image_dtype` for details). - height: integer - width: integer - central_fraction: Optional Float, fraction of the image to crop. + images: 3-D or 4-D Tensor of images. Values are in [0, 255]. + height: Integer. Height of resized output image. + width: Integer. Width of resized output image. scope: Optional scope for name_scope. + Returns: - 3-D float Tensor of prepared image. + 3-D or 4-D float Tensor of prepared image(s). Values are in [-1, 1]. """ - with ops.name_scope(scope, 'eval_image', [image, height, width]): - if image.dtype != dtypes.float32: - image = image_ops.convert_image_dtype(image, dtype=dtypes.float32) - # Crop the central region of the image with an area containing 87.5% of - # the original image. - image = image_ops.central_crop(image, central_fraction=central_fraction) - - # Resize the image to the specified height and width. - image = array_ops.expand_dims(image, 0) - image = image_ops.resize_bilinear(image, [height, width], - align_corners=False) - image = array_ops.squeeze(image, [0]) - image = (image - 0.5) * 2.0 - return image + is_single = images.shape.ndims == 3 + with ops.name_scope(scope, 'preprocess', [images, height, width]): + if not images.dtype.is_floating: + images = math_ops.to_float(images) + if is_single: + images = array_ops.expand_dims(images, axis=0) + resized = image_ops.resize_bilinear(images, [height, width]) + resized = (resized - 128.0) / 128.0 + if is_single: + resized = array_ops.squeeze(resized, axis=0) + return resized def _kl_divergence(p, p_logits, q): @@ -210,9 +206,9 @@ def _default_graph_def_fn(): def run_inception(images, graph_def=None, default_graph_def_fn=_default_graph_def_fn, - image_size=INCEPTION_V3_DEFAULT_IMG_SIZE, - input_tensor=INCEPTION_V3_INPUT, - output_tensor=INCEPTION_V3_OUTPUT): + image_size=INCEPTION_DEFAULT_IMAGE_SIZE, + input_tensor=INCEPTION_INPUT, + output_tensor=INCEPTION_OUTPUT): """Run images through a pretrained Inception classifier. Args: @@ -301,7 +297,8 @@ def classifier_score(images, classifier_fn, num_batches=1): efficiently run them through the classifier network. Returns: - The classifier score. A floating-point scalar. + The classifier score. A floating-point scalar of the same type as the output + of `classifier_fn`. """ generated_images_list = array_ops.split( images, num_or_size_splits=num_batches) @@ -316,26 +313,77 @@ def classifier_score(images, classifier_fn, num_batches=1): name='RunClassifier') logits = array_ops.concat(array_ops.unstack(logits), 0) logits.shape.assert_has_rank(2) + + # Use maximum precision for best results. + logits_dtype = logits.dtype + if logits_dtype != dtypes.float64: + logits = math_ops.to_double(logits) + p = nn_ops.softmax(logits) q = math_ops.reduce_mean(p, axis=0) kl = _kl_divergence(p, logits, q) kl.shape.assert_has_rank(1) log_score = math_ops.reduce_mean(kl) + final_score = math_ops.exp(log_score) - return math_ops.exp(log_score) + if logits_dtype != dtypes.float64: + final_score = math_ops.cast(final_score, logits_dtype) + return final_score inception_score = functools.partial( classifier_score, classifier_fn=functools.partial( - run_inception, output_tensor=INCEPTION_V3_OUTPUT)) + run_inception, output_tensor=INCEPTION_OUTPUT)) + + +def trace_sqrt_product(sigma, sigma_v): + """Find the trace of the positive sqrt of product of covariance matrices. + + '_symmetric_matrix_square_root' only works for symmetric matrices, so we + cannot just take _symmetric_matrix_square_root(sigma * sigma_v). + ('sigma' and 'sigma_v' are symmetric, but their product is not necessarily). + + Let sigma = A A so A = sqrt(sigma), and sigma_v = B B. + We want to find trace(sqrt(sigma sigma_v)) = trace(sqrt(A A B B)) + Note the following properties: + (i) forall M1, M2: eigenvalues(M1 M2) = eigenvalues(M2 M1) + => eigenvalues(A A B B) = eigenvalues (A B B A) + (ii) if M1 = sqrt(M2), then eigenvalues(M1) = sqrt(eigenvalues(M2)) + => eigenvalues(sqrt(sigma sigma_v)) = sqrt(eigenvalues(A B B A)) + (iii) forall M: trace(M) = sum(eigenvalues(M)) + => trace(sqrt(sigma sigma_v)) = sum(eigenvalues(sqrt(sigma sigma_v))) + = sum(sqrt(eigenvalues(A B B A))) + = sum(eigenvalues(sqrt(A B B A))) + = trace(sqrt(A B B A)) + = trace(sqrt(A sigma_v A)) + A = sqrt(sigma). Both sigma and A sigma_v A are symmetric, so we **can** + use the _symmetric_matrix_square_root function to find the roots of these + matrices. + + Args: + sigma: a square, symmetric, real, positive semi-definite covariance matrix + sigma_v: same as sigma + + Returns: + The trace of the positive square root of sigma*sigma_v + """ + + # Note sqrt_sigma is called "A" in the proof above + sqrt_sigma = _symmetric_matrix_square_root(sigma) + + # This is sqrt(A sigma_v A) above + sqrt_a_sigmav_a = math_ops.matmul( + sqrt_sigma, math_ops.matmul(sigma_v, sqrt_sigma)) + + return math_ops.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) def frechet_classifier_distance(real_images, generated_images, classifier_fn, num_batches=1): - """Classifier distance for evaluating a conditional generative model. + """Classifier distance for evaluating a generative model. This is based on the Frechet Inception distance, but for an arbitrary classifier. @@ -351,6 +399,13 @@ def frechet_classifier_distance(real_images, Inception score, this is a true distance and utilizes information about real world images. + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + Args: real_images: Real images to use to compute Frechet Inception distance. generated_images: Generated images to use to compute Frechet Inception @@ -361,7 +416,8 @@ def frechet_classifier_distance(real_images, efficiently run them through the classifier network. Returns: - The Frechet Inception distance. A floating-point scalar. + The Frechet Inception distance. A floating-point scalar of the same type + as the output of `classifier_fn` """ real_images_list = array_ops.split( @@ -380,19 +436,24 @@ def frechet_classifier_distance(real_images, swap_memory=True, name='RunClassifier') + activations_dtype = activations.dtype # Split the activations by the real and generated images. real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0) # Ensure the activations have the right shapes. real_a = array_ops.concat(array_ops.unstack(real_a), 0) gen_a = array_ops.concat(array_ops.unstack(gen_a), 0) + if activations_dtype != dtypes.float64: + real_a = math_ops.to_double(real_a) + gen_a = math_ops.to_double(gen_a) + real_a.shape.assert_has_rank(2) gen_a.shape.assert_has_rank(2) # Compute mean and covariance matrices of activations. m = math_ops.reduce_mean(real_a, 0) m_v = math_ops.reduce_mean(gen_a, 0) - num_examples = math_ops.to_float(array_ops.shape(real_a)[0]) + num_examples = math_ops.to_double(array_ops.shape(real_a)[0]) # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T sigma = math_ops.matmul( @@ -401,13 +462,20 @@ def frechet_classifier_distance(real_images, sigma_v = math_ops.matmul( gen_a - m_v, gen_a - m_v, transpose_a=True) / (num_examples - 1) - # Take matrix square root of the product of covariance matrices. - sqcc = _matrix_square_root(math_ops.matmul(sigma, sigma_v)) + # Find the Tr(sqrt(sigma sigma_v)) component of FID + sqrt_trace_component = trace_sqrt_product(sigma, sigma_v) # Compute the two components of FID. - trace = math_ops.trace(sigma + sigma_v - 2.0 * sqcc) + + # First the covariance component. + # Here, note that trace(A + B) = trace(A) + trace(B) + trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component + + # Next the distance between means. mean = math_ops.square(linalg_ops.norm(m - m_v)) # This uses the L2 norm. fid = trace + mean + if activations_dtype != dtypes.float64: + fid = math_ops.cast(fid, activations_dtype) return fid @@ -415,4 +483,4 @@ def frechet_classifier_distance(real_images, frechet_inception_distance = functools.partial( frechet_classifier_distance, classifier_fn=functools.partial( - run_inception, output_tensor=INCEPTION_V3_FINAL_POOL)) + run_inception, output_tensor=INCEPTION_FINAL_POOL)) diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index 9e8776f3a4c59b167a5587b91a5a38e8c296a68c..92e0a995748c1c4c2ddfff0daae59be5a6eaefb4 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -23,6 +23,7 @@ import tarfile import tempfile import numpy as np +from scipy import linalg as scp_linalg from google.protobuf import text_format @@ -49,32 +50,25 @@ def _expected_inception_score(logits): return np.exp(np.mean(per_example_logincscore)) -def _approximate_matrix_sqrt(mat, eps=1e-8): - # Unlike tensorflow, numpy's return order is (u, s, v) - u, s, v = np.linalg.svd(mat) - si = np.where(s < eps, s, np.sqrt(s)) - # Note the "v" returned by numpy is actually v = V^T - # (when referencing the SVD equation A = U S V^T) - # This is unlike Tensorflow which returns v = V - return np.dot(np.dot(u, np.diag(si)), v) - - def _expected_fid(real_imgs, gen_imgs): m = np.mean(real_imgs, axis=0) m_v = np.mean(gen_imgs, axis=0) sigma = np.cov(real_imgs, rowvar=False) sigma_v = np.cov(gen_imgs, rowvar=False) - sqcc = _approximate_matrix_sqrt(np.dot(sigma, sigma_v)) + sqcc = scp_linalg.sqrtm(np.dot(sigma, sigma_v)) mean = np.square(m - m_v).sum() trace = np.trace(sigma + sigma_v - 2 * sqcc) fid = mean + trace return fid +def _expected_trace_sqrt_product(sigma, sigma_v): + return np.trace(scp_linalg.sqrtm(np.dot(sigma, sigma_v))) + # A dummy GraphDef string with the minimum number of Ops. graphdef_string = """ node { - name: "inputs" + name: "Mul" op: "Placeholder" attr { key: "dtype" @@ -103,7 +97,7 @@ node { } } node { - name: "InceptionV3/Logits/SpatialSqueeze" + name: "logits" op: "Placeholder" attr { key: "dtype" @@ -126,7 +120,7 @@ node { } } node { - name: "InceptionV3/Logits/AvgPool_1a_8x8/AvgPool" + name: "pool_3" op: "Placeholder" attr { key: "dtype" @@ -188,7 +182,7 @@ class ClassifierMetricsTest(test.TestCase): img = array_ops.ones([batch_size, 299, 299, 3]) pool = _run_with_mock( classifier_metrics.run_inception, img, - output_tensor=classifier_metrics.INCEPTION_V3_FINAL_POOL) + output_tensor=classifier_metrics.INCEPTION_FINAL_POOL) self.assertTrue(isinstance(pool, ops.Tensor)) pool.shape.assert_is_compatible_with([batch_size, 2048]) @@ -268,8 +262,11 @@ class ClassifierMetricsTest(test.TestCase): def test_frechet_classifier_distance_value(self): """Test that `frechet_classifier_distance` gives the correct value.""" np.random.seed(0) - test_pool_real_a = np.float32(np.random.randn(64, 256)) - test_pool_gen_a = np.float32(np.random.randn(64, 256)) + + # Make num_examples > num_features to ensure scipy's sqrtm function + # doesn't return a complex matrix. + test_pool_real_a = np.float32(np.random.randn(512, 256)) + test_pool_gen_a = np.float32(np.random.randn(512, 256)) fid_op = _run_with_mock(classifier_metrics.frechet_classifier_distance, test_pool_real_a, test_pool_gen_a, @@ -280,13 +277,36 @@ class ClassifierMetricsTest(test.TestCase): expected_fid = _expected_fid(test_pool_real_a, test_pool_gen_a) - self.assertAllClose(expected_fid, actual_fid, 0.01) + self.assertAllClose(expected_fid, actual_fid, 0.0001) + + def test_trace_sqrt_product_value(self): + """Test that `trace_sqrt_product` gives the correct value.""" + np.random.seed(0) + + # Make num_examples > num_features to ensure scipy's sqrtm function + # doesn't return a complex matrix. + test_pool_real_a = np.float32(np.random.randn(512, 256)) + test_pool_gen_a = np.float32(np.random.randn(512, 256)) + + cov_real = np.cov(test_pool_real_a, rowvar=False) + cov_gen = np.cov(test_pool_gen_a, rowvar=False) + + trace_sqrt_prod_op = _run_with_mock(classifier_metrics.trace_sqrt_product, + cov_real, cov_gen) + + with self.test_session() as sess: + # trace_sqrt_product: tsp + actual_tsp = sess.run(trace_sqrt_prod_op) + + expected_tsp = _expected_trace_sqrt_product(cov_real, cov_gen) + + self.assertAllClose(actual_tsp, expected_tsp, 0.01) def test_preprocess_image_graph(self): """Test `preprocess_image` graph construction.""" incorrectly_sized_image = array_ops.zeros([520, 240, 3]) correct_image = classifier_metrics.preprocess_image( - image=incorrectly_sized_image) + images=incorrectly_sized_image) _run_with_mock(classifier_metrics.run_inception, array_ops.expand_dims(correct_image, 0)) diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index 940b5236276c3e06bf030e310f7453e93c7e3d32..508b4d20d8767f42246a0d0c87f911b7ac612f45 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -38,7 +38,7 @@ def _assert_is_image(data): data.shape[1:].assert_is_fully_defined() -def add_gan_model_image_summaries(gan_model, grid_size=10): +def add_gan_model_image_summaries(gan_model, grid_size=4): """Adds image summaries for real and fake images. Args: diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index 3f9d87f54ed678f8424e184435ab8509028ab33f..940762cf2aa0f473cd41d9d543e2773b565a5248 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -86,8 +86,9 @@ def wasserstein_generator_loss( discriminator_gen_outputs: Discriminator output on generated data. Expected to be in the range of (-inf, inf). weights: Optional `Tensor` whose rank is either 0, or the same rank as - `labels`, and must be broadcastable to `labels` (i.e., all dimensions must - be either `1`, or the same as the corresponding `losses` dimension). + `discriminator_gen_outputs`, and must be broadcastable to + `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or + the same as the corresponding dimension). scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. reduction: A `tf.losses.Reduction` to apply to loss. @@ -127,10 +128,12 @@ def wasserstein_discriminator_loss( discriminator_real_outputs: Discriminator output on real data. discriminator_gen_outputs: Discriminator output on generated data. Expected to be in the range of (-inf, inf). - real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale - the real loss. - generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to - rescale the generated loss. + real_weights: Optional `Tensor` whose rank is either 0, or the same rank as + `discriminator_real_outputs`, and must be broadcastable to + `discriminator_real_outputs` (i.e., all dimensions must be either `1`, or + the same as the corresponding dimension). + generated_weights: Same as `real_weights`, but for + `discriminator_gen_outputs`. scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. reduction: A `tf.losses.Reduction` to apply to loss. @@ -167,8 +170,8 @@ def wasserstein_discriminator_loss( # ACGAN losses from `Conditional Image Synthesis With Auxiliary Classifier GANs` # (https://arxiv.org/abs/1610.09585). def acgan_discriminator_loss( - discriminator_gen_classification_logits, discriminator_real_classification_logits, + discriminator_gen_classification_logits, one_hot_labels, label_smoothing=0.0, real_weights=1.0, @@ -189,18 +192,20 @@ def acgan_discriminator_loss( ACGAN: https://arxiv.org/abs/1610.09585 Args: - discriminator_gen_classification_logits: Classification logits for generated - data. discriminator_real_classification_logits: Classification logits for real data. + discriminator_gen_classification_logits: Classification logits for generated + data. one_hot_labels: A Tensor holding one-hot labels for the batch. label_smoothing: A float in [0, 1]. If greater than 0, smooth the labels for "discriminator on real data" as suggested in https://arxiv.org/pdf/1701.00160 - real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale - the real loss. - generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to - rescale the generated loss. + real_weights: Optional `Tensor` whose rank is either 0, or the same rank as + `discriminator_real_outputs`, and must be broadcastable to + `discriminator_real_outputs` (i.e., all dimensions must be either `1`, or + the same as the corresponding dimension). + generated_weights: Same as `real_weights`, but for + `discriminator_gen_classification_logits`. scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. reduction: A `tf.losses.Reduction` to apply to loss. @@ -212,21 +217,25 @@ def acgan_discriminator_loss( Raises: TypeError: If the discriminator does not output a tuple. """ - loss_on_generated = losses.softmax_cross_entropy( - one_hot_labels, discriminator_gen_classification_logits, - weights=generated_weights, scope=scope, loss_collection=None, - reduction=reduction) - loss_on_real = losses.softmax_cross_entropy( - one_hot_labels, discriminator_real_classification_logits, - weights=real_weights, label_smoothing=label_smoothing, scope=scope, - loss_collection=None, reduction=reduction) - loss = loss_on_generated + loss_on_real - util.add_loss(loss, loss_collection) + with ops.name_scope( + scope, 'acgan_discriminator_loss', + (discriminator_real_classification_logits, + discriminator_gen_classification_logits, one_hot_labels)) as scope: + loss_on_generated = losses.softmax_cross_entropy( + one_hot_labels, discriminator_gen_classification_logits, + weights=generated_weights, scope=scope, loss_collection=None, + reduction=reduction) + loss_on_real = losses.softmax_cross_entropy( + one_hot_labels, discriminator_real_classification_logits, + weights=real_weights, label_smoothing=label_smoothing, scope=scope, + loss_collection=None, reduction=reduction) + loss = loss_on_generated + loss_on_real + util.add_loss(loss, loss_collection) - if add_summaries: - summary.scalar('discriminator_gen_ac_loss', loss_on_generated) - summary.scalar('discriminator_real_ac_loss', loss_on_real) - summary.scalar('discriminator_ac_loss', loss) + if add_summaries: + summary.scalar('discriminator_gen_ac_loss', loss_on_generated) + summary.scalar('discriminator_real_ac_loss', loss_on_real) + summary.scalar('discriminator_ac_loss', loss) return loss @@ -255,8 +264,9 @@ def acgan_generator_loss( data. one_hot_labels: A Tensor holding one-hot labels for the batch. weights: Optional `Tensor` whose rank is either 0, or the same rank as - `labels`, and must be broadcastable to `labels` (i.e., all dimensions must - be either `1`, or the same as the corresponding `losses` dimension). + `discriminator_gen_classification_logits`, and must be broadcastable to + `discriminator_gen_classification_logits` (i.e., all dimensions must be + either `1`, or the same as the corresponding dimension). scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. reduction: A `tf.losses.Reduction` to apply to loss. @@ -269,12 +279,16 @@ def acgan_generator_loss( ValueError: if arg module not either `generator` or `discriminator` TypeError: if the discriminator does not output a tuple. """ - loss = losses.softmax_cross_entropy( - one_hot_labels, discriminator_gen_classification_logits, weights=weights, - scope=scope, loss_collection=loss_collection, reduction=reduction) + with ops.name_scope( + scope, 'acgan_generator_loss', + (discriminator_gen_classification_logits, one_hot_labels)) as scope: + loss = losses.softmax_cross_entropy( + one_hot_labels, discriminator_gen_classification_logits, + weights=weights, scope=scope, loss_collection=loss_collection, + reduction=reduction) - if add_summaries: - summary.scalar('generator_ac_loss', loss) + if add_summaries: + summary.scalar('generator_ac_loss', loss) return loss @@ -283,10 +297,9 @@ def acgan_generator_loss( # GANs` (https://arxiv.org/abs/1704.00028). -# TODO(joelshor): Figure out why this function can't be inside a name scope. def wasserstein_gradient_penalty( - generated_data, real_data, + generated_data, generator_inputs, discriminator_fn, discriminator_scope, @@ -302,8 +315,8 @@ def wasserstein_gradient_penalty( (https://arxiv.org/abs/1704.00028) for more details. Args: - generated_data: Output of the generator. real_data: Real data. + generated_data: Output of the generator. generator_inputs: Exact argument to pass to the generator, which is used as optional conditioning to the discriminator. discriminator_fn: A discriminator function that conforms to TFGAN API. @@ -311,8 +324,9 @@ def wasserstein_gradient_penalty( epsilon: A small positive number added for numerical stability when computing the gradient norm. weights: Optional `Tensor` whose rank is either 0, or the same rank as - `labels`, and must be broadcastable to `labels` (i.e., all dimensions must - be either `1`, or the same as the corresponding `losses` dimension). + `real_data` and `generated_data`, and must be broadcastable to + them (i.e., all dimensions must be either `1`, or the same as the + corresponding dimension). scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. reduction: A `tf.losses.Reduction` to apply to loss. @@ -324,46 +338,50 @@ def wasserstein_gradient_penalty( Raises: ValueError: If the rank of data Tensors is unknown. """ - if generated_data.shape.ndims is None: - raise ValueError('`generated_data` can\'t have unknown rank.') - if real_data.shape.ndims is None: - raise ValueError('`real_data` can\'t have unknown rank.') - - differences = generated_data - real_data - batch_size = differences.shape[0].value or array_ops.shape(differences)[0] - alpha_shape = [batch_size] + [1] * (differences.shape.ndims - 1) - alpha = random_ops.random_uniform(shape=alpha_shape) - interpolates = real_data + (alpha * differences) - - # Reuse variables if a discriminator scope already exists. - reuse = False if discriminator_scope is None else True - with variable_scope.variable_scope(discriminator_scope, 'gpenalty_dscope', - reuse=reuse): - disc_interpolates = discriminator_fn(interpolates, generator_inputs) - - if isinstance(disc_interpolates, tuple): - # ACGAN case: disc outputs more than one tensor - disc_interpolates = disc_interpolates[0] - - gradients = gradients_impl.gradients(disc_interpolates, interpolates)[0] - gradient_squares = math_ops.reduce_sum( - math_ops.square(gradients), axis=list(range(1, gradients.shape.ndims))) - # Propagate shape information, if possible. - if isinstance(batch_size, int): - gradient_squares.set_shape([ - batch_size] + gradient_squares.shape.as_list()[1:]) - # For numerical stability, add epsilon to the sum before taking the square - # root. Note tf.norm does not add epsilon. - slopes = math_ops.sqrt(gradient_squares + epsilon) - penalties = math_ops.square(slopes - 1.0) - penalty = losses.compute_weighted_loss( - penalties, weights, scope=scope, loss_collection=loss_collection, - reduction=reduction) + with ops.name_scope(scope, 'wasserstein_gradient_penalty', + (real_data, generated_data)) as scope: + real_data = ops.convert_to_tensor(real_data) + generated_data = ops.convert_to_tensor(generated_data) + if real_data.shape.ndims is None: + raise ValueError('`real_data` can\'t have unknown rank.') + if generated_data.shape.ndims is None: + raise ValueError('`generated_data` can\'t have unknown rank.') + + differences = generated_data - real_data + batch_size = differences.shape[0].value or array_ops.shape(differences)[0] + alpha_shape = [batch_size] + [1] * (differences.shape.ndims - 1) + alpha = random_ops.random_uniform(shape=alpha_shape) + interpolates = real_data + (alpha * differences) + + with ops.name_scope(None): # Clear scope so update ops are added properly. + # Reuse variables if variables already exists. + with variable_scope.variable_scope(discriminator_scope, 'gpenalty_dscope', + reuse=variable_scope.AUTO_REUSE): + disc_interpolates = discriminator_fn(interpolates, generator_inputs) + + if isinstance(disc_interpolates, tuple): + # ACGAN case: disc outputs more than one tensor + disc_interpolates = disc_interpolates[0] + + gradients = gradients_impl.gradients(disc_interpolates, interpolates)[0] + gradient_squares = math_ops.reduce_sum( + math_ops.square(gradients), axis=list(range(1, gradients.shape.ndims))) + # Propagate shape information, if possible. + if isinstance(batch_size, int): + gradient_squares.set_shape([ + batch_size] + gradient_squares.shape.as_list()[1:]) + # For numerical stability, add epsilon to the sum before taking the square + # root. Note tf.norm does not add epsilon. + slopes = math_ops.sqrt(gradient_squares + epsilon) + penalties = math_ops.square(slopes - 1.0) + penalty = losses.compute_weighted_loss( + penalties, weights, scope=scope, loss_collection=loss_collection, + reduction=reduction) - if add_summaries: - summary.scalar('gradient_penalty_loss', penalty) + if add_summaries: + summary.scalar('gradient_penalty_loss', penalty) - return penalty + return penalty # Original losses from `Generative Adversarial Nets` @@ -398,10 +416,11 @@ def minimax_discriminator_loss( label_smoothing: The amount of smoothing for positive labels. This technique is taken from `Improved Techniques for Training GANs` (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. - real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale - the real loss. - generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to - rescale the generated loss. + real_weights: Optional `Tensor` whose rank is either 0, or the same rank as + `real_data`, and must be broadcastable to `real_data` (i.e., all + dimensions must be either `1`, or the same as the corresponding + dimension). + generated_weights: Same as `real_weights`, but for `generated_data`. scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. reduction: A `tf.losses.Reduction` to apply to loss. @@ -460,8 +479,10 @@ def minimax_generator_loss( label_smoothing: The amount of smoothing for positive labels. This technique is taken from `Improved Techniques for Training GANs` (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. - weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale - the loss. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `discriminator_gen_outputs`, and must be broadcastable to + `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or + the same as the corresponding dimension). scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. reduction: A `tf.losses.Reduction` to apply to loss. @@ -504,10 +525,12 @@ def modified_discriminator_loss( label_smoothing: The amount of smoothing for positive labels. This technique is taken from `Improved Techniques for Training GANs` (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. - real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale - the real loss. - generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to - rescale the generated loss. + real_weights: Optional `Tensor` whose rank is either 0, or the same rank as + `discriminator_gen_outputs`, and must be broadcastable to + `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or + the same as the corresponding dimension). + generated_weights: Same as `real_weights`, but for + `discriminator_gen_outputs`. scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. reduction: A `tf.losses.Reduction` to apply to loss. @@ -532,7 +555,7 @@ def modified_generator_loss( discriminator_gen_outputs, label_smoothing=0.0, weights=1.0, - scope='generator_modified_loss', + scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, add_summaries=False): @@ -551,8 +574,9 @@ def modified_generator_loss( is taken from `Improved Techniques for Training GANs` (https://arxiv.org/abs/1606.03498). `0.0` means no smoothing. weights: Optional `Tensor` whose rank is either 0, or the same rank as - `labels`, and must be broadcastable to `labels` (i.e., all dimensions must - be either `1`, or the same as the corresponding `losses` dimension). + `discriminator_gen_outputs`, and must be broadcastable to `labels` (i.e., + all dimensions must be either `1`, or the same as the corresponding + dimension). scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. reduction: A `tf.losses.Reduction` to apply to loss. @@ -561,12 +585,15 @@ def modified_generator_loss( Returns: A loss Tensor. The shape depends on `reduction`. """ - loss = losses.sigmoid_cross_entropy( - array_ops.ones_like(discriminator_gen_outputs), discriminator_gen_outputs, - weights, label_smoothing, scope, loss_collection, reduction) + with ops.name_scope(scope, 'generator_modified_loss', + [discriminator_gen_outputs]) as scope: + loss = losses.sigmoid_cross_entropy( + array_ops.ones_like(discriminator_gen_outputs), + discriminator_gen_outputs, weights, label_smoothing, scope, + loss_collection, reduction) - if add_summaries: - summary.scalar('generator_modified_loss', loss) + if add_summaries: + summary.scalar('generator_modified_loss', loss) return loss @@ -598,8 +625,9 @@ def least_squares_generator_loss( real_label: The value that the generator is trying to get the discriminator to output on generated data. weights: Optional `Tensor` whose rank is either 0, or the same rank as - `labels`, and must be broadcastable to `labels` (i.e., all dimensions must - be either `1`, or the same as the corresponding `losses` dimension). + `discriminator_gen_outputs`, and must be broadcastable to + `discriminator_gen_outputs` (i.e., all dimensions must be either `1`, or + the same as the corresponding dimension). scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. reduction: A `tf.losses.Reduction` to apply to loss. @@ -649,10 +677,12 @@ def least_squares_discriminator_loss( to be in the range of (-inf, inf). real_label: The value that the discriminator tries to output for real data. fake_label: The value that the discriminator tries to output for fake data. - real_weights: A scalar or a `Tensor` of size [batch_size, K] used to rescale - the real loss. - generated_weights: A scalar or a `Tensor` of size [batch_size, K] used to - rescale the generated loss. + real_weights: Optional `Tensor` whose rank is either 0, or the same rank as + `discriminator_real_outputs`, and must be broadcastable to + `discriminator_real_outputs` (i.e., all dimensions must be either `1`, or + the same as the corresponding dimension). + generated_weights: Same as `real_weights`, but for + `discriminator_gen_outputs`. scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. reduction: A `tf.losses.Reduction` to apply to loss. @@ -721,7 +751,7 @@ def mutual_information_penalty( structured_generator_inputs, predicted_distributions, weights=1.0, - scope='generator_modified_loss', + scope=None, loss_collection=ops.GraphKeys.LOSSES, reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, add_summaries=False): @@ -736,9 +766,8 @@ def mutual_information_penalty( predicted_distributions: A list of tf.Distributions. Predicted by the recognizer, and used to evaluate the likelihood of the structured noise. List length should match `structured_generator_inputs`. - weights: Optional `Tensor` whose rank is either 0, or the same rank as - `labels`, and must be broadcastable to `labels` (i.e., all dimensions must - be either `1`, or the same as the corresponding `losses` dimension). + weights: Optional `Tensor` whose rank is either 0, or the same dimensions as + `structured_generator_inputs`. scope: The scope for the operations performed in computing the loss. loss_collection: collection to which this loss will be added. reduction: A `tf.losses.Reduction` to apply to loss. @@ -750,15 +779,16 @@ def mutual_information_penalty( _validate_information_penalty_inputs( structured_generator_inputs, predicted_distributions) - # Calculate the negative log-likelihood of the reconstructed noise. - log_probs = [math_ops.reduce_mean(dist.log_prob(noise)) for dist, noise in - zip(predicted_distributions, structured_generator_inputs)] - loss = -1 * losses.compute_weighted_loss( - log_probs, weights, scope, loss_collection=loss_collection, - reduction=reduction) + with ops.name_scope(scope, 'mutual_information_loss') as scope: + # Calculate the negative log-likelihood of the reconstructed noise. + log_probs = [math_ops.reduce_mean(dist.log_prob(noise)) for dist, noise in + zip(predicted_distributions, structured_generator_inputs)] + loss = -1 * losses.compute_weighted_loss( + log_probs, weights, scope, loss_collection=loss_collection, + reduction=reduction) - if add_summaries: - summary.scalar('mutual_information_penalty', loss) + if add_summaries: + summary.scalar('mutual_information_penalty', loss) return loss diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py index 3e003dd0f808f80dcc486e78e8e101ac6f198947..b5cd8c92ba180e981e0faf877021cb6d69dc34b4 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py @@ -274,8 +274,8 @@ class ACGANLossTest(test.TestCase): self._discriminator_real_classification_logits, 'one_hot_labels': self._one_hot_labels, } - self._generator_loss_name = 'softmax_cross_entropy_loss/value' - self._discriminator_loss_name = 'add' + self._generator_loss_name = 'acgan_generator_loss/value' + self._discriminator_loss_name = 'acgan_discriminator_loss/add' self._expected_g_loss = 3.84974 self._expected_d_loss = 9.43950 @@ -453,10 +453,11 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): 'discriminator_scope': self._scope, } self._expected_loss = 9.00000 - self._expected_op_name = 'weighted_loss/value' + self._expected_op_name = 'wasserstein_gradient_penalty/value' self._batch_size = 1 def _discriminator_fn(self, inputs, _): + ops.add_to_collection('fake_update_ops', constant_op.constant(1.0)) return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs def test_loss_with_placeholder(self): @@ -487,6 +488,26 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest): self.assertEqual( num_vars, len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))) + def test_works_with_get_collection(self): + """Tests that gradient penalty works inside other scopes.""" + # We ran the discriminator once in the setup, so there should be an op + # already in the collection. + self.assertEqual(1, len(ops.get_collection( + 'fake_update_ops', self._kwargs['discriminator_scope'].name))) + + # Make sure the op is added to the collection even if it's in a name scope. + with ops.name_scope('loss'): + tfgan_losses.wasserstein_gradient_penalty(**self._kwargs) + self.assertEqual(2, len(ops.get_collection( + 'fake_update_ops', self._kwargs['discriminator_scope'].name))) + + # Make sure the op is added to the collection even if it's in a variable + # scope. + with variable_scope.variable_scope('loss_vscope'): + tfgan_losses.wasserstein_gradient_penalty(**self._kwargs) + self.assertEqual(3, len(ops.get_collection( + 'fake_update_ops', self._kwargs['discriminator_scope'].name))) + class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest): """Tests for mutual_information_penalty.""" @@ -504,7 +525,7 @@ class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest): 'predicted_distributions': self._predicted_distributions, } self._expected_loss = 1.61610 - self._expected_op_name = 'mul' + self._expected_op_name = 'mutual_information_loss/mul' self._batch_size = 2 diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py index fca8063891fe53cb9a384fe6908eb6b1c61b90d7..b341f03a0ddaacca8b036189516c71908bee50eb 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py @@ -14,10 +14,41 @@ # ============================================================================== """TFGAN utilities for loss functions that accept GANModel namedtuples. -Example: +The losses and penalties in this file all correspond to losses in +`losses_impl.py`. Losses in that file take individual arguments, whereas in this +file they take a `GANModel` tuple. For example: + +losses_impl.py: + ```python + def wasserstein_discriminator_loss( + discriminator_real_outputs, + discriminator_gen_outputs, + real_weights=1.0, + generated_weights=1.0, + scope=None, + loss_collection=ops.GraphKeys.LOSSES, + reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False) + ``` + +tuple_losses_impl.py: + ```python + def wasserstein_discriminator_loss( + gan_model, + real_weights=1.0, + generated_weights=1.0, + scope=None, + loss_collection=ops.GraphKeys.LOSSES, + reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS, + add_summaries=False) + ``` + + + +Example usage: ```python - # `tfgan.losses.args` losses take individual arguments. - w_loss = tfgan.losses.args.wasserstein_discriminator_loss( + # `tfgan.losses.wargs` losses take individual arguments. + w_loss = tfgan.losses.wargs.wasserstein_discriminator_loss( discriminator_real_outputs, discriminator_gen_outputs) diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index a99e3fbec8dc2a07030aa9356be2b05cfb689b8e..48f5e8e47dbcd5d32c23806b967a0d1e7403d2f7 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -12,7 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Named tuples for TFGAN.""" +"""Named tuples for TFGAN. + +TFGAN training occurs in four steps, and each step communicates with the next +step via one of these named tuples. At each step, you can either use a TFGAN +helper function in `train.py`, or you can manually construct a tuple. +""" from __future__ import absolute_import from __future__ import division @@ -115,7 +120,7 @@ class GANLoss( """GANLoss contains the generator and discriminator losses. Args: - generator_loss: A tensor for the generator loss.. + generator_loss: A tensor for the generator loss. discriminator_loss: A tensor for the discriminator loss. """ diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index cdc4d78e5b235bbbe53cf2717ac48a156fa96845..ad2d5eb86cdab89273efbd4ddce45f6657b54406 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -14,7 +14,17 @@ # ============================================================================== """The TFGAN project provides a lightweight GAN training/testing framework. -See examples in `tensorflow_models` for details on how to use. +This file contains the core helper functions to create and train a GAN model. +See the README or examples in `tensorflow_models` for details on how to use. + +TFGAN training occurs in four steps: +1) Create a model +2) Add a loss +3) Create train ops +4) Run the train ops + +The functions in this file are organized around these four steps. Each function +corresponds to one of the steps. """ from __future__ import absolute_import @@ -48,19 +58,10 @@ __all__ = [ 'get_sequential_train_hooks', 'get_joint_train_hooks', 'get_sequential_train_steps', + 'RunTrainOpsHook', ] -def _convert_tensor_or_l_or_d(tensor_or_l_or_d): - """Convert input, list of inputs, or dictionary of inputs to Tensors.""" - if isinstance(tensor_or_l_or_d, (list, tuple)): - return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d] - elif isinstance(tensor_or_l_or_d, dict): - return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()} - else: - return ops.convert_to_tensor(tensor_or_l_or_d) - - def gan_model( # Lambdas defining models. generator_fn, @@ -133,20 +134,6 @@ def gan_model( discriminator_fn) -def _validate_distributions(distributions_l, noise_l): - if not isinstance(distributions_l, (tuple, list)): - raise ValueError('`predicted_distributions` must be a list. Instead, found ' - '%s.' % type(distributions_l)) - for dist in distributions_l: - if not isinstance(dist, ds.Distribution): - raise ValueError('Every element in `predicted_distributions` must be a ' - '`tf.Distribution`. Instead, found %s.' % type(dist)) - if len(distributions_l) != len(noise_l): - raise ValueError('Length of `predicted_distributions` %i must be the same ' - 'as the length of structured noise %i.' % - (len(distributions_l), len(noise_l))) - - def infogan_model( # Lambdas defining models. generator_fn, @@ -231,16 +218,6 @@ def infogan_model( predicted_distributions) -def _validate_acgan_discriminator_outputs(discriminator_output): - try: - a, b = discriminator_output - except (TypeError, ValueError): - raise TypeError( - 'A discriminator function for ACGAN must output a tuple ' - 'consisting of (discrimination logits, classification logits).') - return a, b - - def acgan_model( # Lambdas defining models. generator_fn, @@ -252,6 +229,7 @@ def acgan_model( # Optional scopes. generator_scope='Generator', discriminator_scope='Discriminator', + # Options. check_shapes=True): """Returns an ACGANModel contains all the pieces needed for ACGAN training. @@ -497,11 +475,10 @@ def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True): def gan_train_ops( - model, # GANModel - loss, # GANLoss + model, + loss, generator_optimizer, discriminator_optimizer, - # Optional check flags. check_for_unused_update_ops=True, # Optional args to pass directly to the `create_train_op`. **kwargs): @@ -801,3 +778,40 @@ def get_sequential_train_steps( return gen_loss + dis_loss, should_stop return sequential_train_steps + + +# Helpers + + +def _convert_tensor_or_l_or_d(tensor_or_l_or_d): + """Convert input, list of inputs, or dictionary of inputs to Tensors.""" + if isinstance(tensor_or_l_or_d, (list, tuple)): + return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d] + elif isinstance(tensor_or_l_or_d, dict): + return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()} + else: + return ops.convert_to_tensor(tensor_or_l_or_d) + + +def _validate_distributions(distributions_l, noise_l): + if not isinstance(distributions_l, (tuple, list)): + raise ValueError('`predicted_distributions` must be a list. Instead, found ' + '%s.' % type(distributions_l)) + for dist in distributions_l: + if not isinstance(dist, ds.Distribution): + raise ValueError('Every element in `predicted_distributions` must be a ' + '`tf.Distribution`. Instead, found %s.' % type(dist)) + if len(distributions_l) != len(noise_l): + raise ValueError('Length of `predicted_distributions` %i must be the same ' + 'as the length of structured noise %i.' % + (len(distributions_l), len(noise_l))) + + +def _validate_acgan_discriminator_outputs(discriminator_output): + try: + a, b = discriminator_output + except (TypeError, ValueError): + raise TypeError( + 'A discriminator function for ACGAN must output a tuple ' + 'consisting of (discrimination logits, classification logits).') + return a, b diff --git a/tensorflow/contrib/gdr/BUILD b/tensorflow/contrib/gdr/BUILD index bebcf079ba444946bf0377106cbafcbaa7e94e74..bdbe6f0a72621e59562fe113da101ff5a2b8c06d 100644 --- a/tensorflow/contrib/gdr/BUILD +++ b/tensorflow/contrib/gdr/BUILD @@ -85,7 +85,6 @@ tf_cuda_library( "//tensorflow/core/distributed_runtime:rendezvous_mgr_interface", "//tensorflow/core/distributed_runtime:worker", "//tensorflow/core/distributed_runtime:worker_cache", - "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime:worker_session", "//tensorflow/core/distributed_runtime/rpc:grpc_call", "//tensorflow/core/distributed_runtime/rpc:grpc_tensor_coding", @@ -104,6 +103,7 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core/distributed_runtime:base_rendezvous_mgr", + "//tensorflow/core/distributed_runtime:tensor_coding", "//tensorflow/core/distributed_runtime:worker_cache", "//tensorflow/core/distributed_runtime:worker_env", "//tensorflow/core/distributed_runtime:worker_interface", @@ -119,7 +119,6 @@ cc_library( ":gdr_memory_manager", ":gdr_rendezvous_mgr", ":gdr_worker", - "//tensorflow/core:lib_internal", "//tensorflow/core/distributed_runtime/rpc:grpc_server_lib", ], alwayslink = 1, diff --git a/tensorflow/contrib/graph_editor/BUILD b/tensorflow/contrib/graph_editor/BUILD index b4c53d3da655e2f52b5990ac0de3bc7ccc823bcc..967ad2fc090906e93f22c777816eede37f9a1b04 100644 --- a/tensorflow/contrib/graph_editor/BUILD +++ b/tensorflow/contrib/graph_editor/BUILD @@ -144,12 +144,12 @@ py_test( ":graph_editor_py", ":match", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:gradients", "//tensorflow/python:math_ops", + "//tensorflow/python:session", "//tensorflow/python:variables", "//third_party/py/numpy", ], diff --git a/tensorflow/contrib/graph_editor/reroute.py b/tensorflow/contrib/graph_editor/reroute.py index 42968ae63b769f7cea7385933fbadb0782cc86f3..7ffdbb7139281734917fdb715601b317eb58b82f 100644 --- a/tensorflow/contrib/graph_editor/reroute.py +++ b/tensorflow/contrib/graph_editor/reroute.py @@ -397,27 +397,57 @@ def swap_inputs(sgv0, sgv1): def reroute_inputs(sgv0, sgv1): - """Re-route all the inputs of sgv0 to sgv1 (see reroute_inputs).""" + """Re-route all the inputs of two subgraphs. + + Args: + sgv0: the first subgraph to have its inputs swapped. This argument is + converted to a subgraph using the same rules than the function + subgraph.make_view. + sgv1: the second subgraph to have its inputs swapped. This argument is + converted to a subgraph using the same rules than the function + subgraph.make_view. + Returns: + A tuple `(sgv0, sgv1)` of subgraph views with their inputs swapped. + Note that the function argument sgv0 and sgv1 are also modified in place. + Raises: + StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using + the same rules than the function subgraph.make_view. + """ return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.a2b) def swap_outputs(sgv0, sgv1): - """Swap all the outputs of sgv0 and sgv1 (see _reroute_outputs).""" + """Swap all the outputs of sgv0 and sgv1 (see reroute_outputs).""" return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.swap) def reroute_outputs(sgv0, sgv1): - """Re-route all the outputs of sgv0 to sgv1 (see _reroute_outputs).""" + """Re-route all the outputs of two operations. + + Args: + sgv0: the first subgraph to have its outputs swapped. This argument is + converted to a subgraph using the same rules than the function + subgraph.make_view. + sgv1: the second subgraph to have its outputs swapped. This argument is + converted to a subgraph using the same rules than the function + subgraph.make_view. + Returns: + A tuple `(sgv0, sgv1)` of subgraph views with their outputs swapped. + Note that the function argument sgv0 and sgv1 are also modified in place. + Raises: + StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using + the same rules than the function subgraph.make_view. + """ return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.a2b) def swap_ios(sgv0, sgv1): - """Swap the inputs and outputs of sgv1 to sgv0 (see _reroute).""" + """Swap the inputs and outputs of sgv1 to sgv0 (see _reroute_sgv).""" return _reroute_sgv(sgv0, sgv1, _RerouteMode.swap) def reroute_ios(sgv0, sgv1): - """Re-route the inputs and outputs of sgv0 to sgv1 (see _reroute).""" + """Re-route the inputs and outputs of sgv0 to sgv1 (see _reroute_sgv).""" return _reroute_sgv(sgv0, sgv1, _RerouteMode.a2b) diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index ab5776b9dd66bb082e9ca3922e8902bfebe6b0b8..ca00394388f67e2ed9508684a47b23c3ee9e79e8 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -191,14 +191,14 @@ class TransformTest(test.TestCase): # Extract the operations. replacement_ts = {w.value(): g} original_mul1_grad = (ops.get_default_graph(). - get_operation_by_name("grad/mul1_grad/mul_1")) + get_operation_by_name("grad/mul1_grad/Mul_1")) # Should not raise exception. res = ge.graph_replace(g, replacement_ts, dst_scope="res") # Extract the operations after graph_replace. result_mul1_grad = (ops.get_default_graph(). - get_operation_by_name("res/grad/mul1_grad/mul_1")) + get_operation_by_name("res/grad/mul1_grad/Mul_1")) # Make sure _original_ops are as expected. self.assertEquals(original_mul1_grad._original_op.name, u"mul1") diff --git a/tensorflow/contrib/graph_editor/util.py b/tensorflow/contrib/graph_editor/util.py index 959905e9826fe439112078a32fef9a5f5b96e9ac..30bc33b9ee42ba78bc7307c67c0fc0af9f3356ef 100644 --- a/tensorflow/contrib/graph_editor/util.py +++ b/tensorflow/contrib/graph_editor/util.py @@ -93,6 +93,8 @@ class ListView(object): # TODO(fkp): very generic code, it should be moved in a more generic place. def is_iterable(obj): """Return true if the object is iterable.""" + if isinstance(obj, tf_ops.Tensor): + return False try: _ = iter(obj) except Exception: # pylint: disable=broad-except diff --git a/tensorflow/contrib/grid_rnn/BUILD b/tensorflow/contrib/grid_rnn/BUILD index 7fbb9f024c589895aa2dff7b6f5d8ba8c399af48..d601a1ec6f7a219bcd461d819ab2dfc64135a3ae 100644 --- a/tensorflow/contrib/grid_rnn/BUILD +++ b/tensorflow/contrib/grid_rnn/BUILD @@ -31,14 +31,12 @@ cuda_py_tests( additional_deps = [ ":grid_rnn_py", "//third_party/py/numpy", - "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", "//tensorflow/python:init_ops", "//tensorflow/python:nn_ops", - "//tensorflow/python:platform_test", + "//tensorflow/python:rnn", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], diff --git a/tensorflow/contrib/hooks/BUILD b/tensorflow/contrib/hooks/BUILD index d81e868d4a922698e4755733b999112088fa2a0b..1b528d7afc1112f5dc0667ae299ade02bc8fd04b 100644 --- a/tensorflow/contrib/hooks/BUILD +++ b/tensorflow/contrib/hooks/BUILD @@ -19,30 +19,11 @@ py_library( ], srcs_version = "PY2AND3", deps = [ - "//tensorflow/core:protos_all_py", - "//tensorflow/python:client", - "//tensorflow/python:platform", "//tensorflow/python:training", "//tensorflow/python:util", ], ) -py_test( - name = "profiler_hook_test", - size = "small", - srcs = ["python/training/profiler_hook_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":hooks", - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:platform", - "//tensorflow/python:state_ops", - "//tensorflow/python:training", - ], -) - filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/hooks/python/training/profiler_hook.py b/tensorflow/contrib/hooks/python/training/profiler_hook.py index 35aa25edfde6f2ed7051ed75ff4f53f8732ae76e..6173aa0797138730e79b21bc9a1779d346edab6b 100644 --- a/tensorflow/contrib/hooks/python/training/profiler_hook.py +++ b/tensorflow/contrib/hooks/python/training/profiler_hook.py @@ -12,93 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Additional `SessionRunHook` implementations to complement those in -tensorflow/python/training. - -""" +"""Placeholder of ProfilerHook for backward compatibility.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os.path - -from tensorflow.core.protobuf import config_pb2 -from tensorflow.python.client import timeline -from tensorflow.python.platform import gfile -from tensorflow.python.platform import tf_logging as logging -from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer -from tensorflow.python.training.session_run_hook import SessionRunArgs -from tensorflow.python.training import session_run_hook -from tensorflow.python.training import training_util - - -class ProfilerHook(session_run_hook.SessionRunHook): - """Captures CPU/GPU profiling information every N steps or seconds. - - This produces files called "timeline-.json", which are in Chrome - Trace format. - - For more information see: - https://github.com/catapult-project/catapult/blob/master/tracing/README.md""" - - def __init__(self, - save_steps=None, - save_secs=None, - output_dir="", - show_dataflow=True, - show_memory=False): - """Initializes a hook that takes periodic profiling snapshots. - - Args: - save_steps: `int`, save profile traces every N steps. Exactly one of - `save_secs` and `save_steps` should be set. - save_secs: `int`, save profile traces every N seconds. - output_dir: `string`, the directory to save the profile traces to. - Defaults to the current directory. - show_dataflow: `bool`, if True, add flow events to the trace connecting - producers and consumers of tensors. - show_memory: `bool`, if True, add object snapshot events to the trace - showing the sizes and lifetimes of tensors. - """ - self._output_file = os.path.join(output_dir, "timeline-{}.json") - self._show_dataflow = show_dataflow - self._show_memory = show_memory - self._timer = SecondOrStepTimer(every_secs=save_secs, - every_steps=save_steps) - - def begin(self): - self._next_step = None - self._global_step_tensor = training_util.get_global_step() - if self._global_step_tensor is None: - raise RuntimeError( - "Global step should be created to use ProfilerHook.") - - def before_run(self, run_context): - self._request_summary = ( - self._next_step is None or - self._timer.should_trigger_for_step(self._next_step)) - requests = {"global_step": self._global_step_tensor} - opts = (config_pb2.RunOptions(trace_level=config_pb2.RunOptions.FULL_TRACE) - if self._request_summary else None) - - return SessionRunArgs(requests, options=opts) - - def after_run(self, run_context, run_values): - global_step = run_values.results["global_step"] - - if self._request_summary: - self._timer.update_last_triggered_step(global_step) - self._save(global_step, - self._output_file.format(global_step), - run_values.run_metadata.step_stats) - - self._next_step = global_step + 1 +from tensorflow.python.training import basic_session_run_hooks - def _save(self, step, save_path, step_stats): - logging.info("Saving timeline for %d into '%s'.", step, save_path) - with gfile.Open(save_path, "w") as f: - trace = timeline.Timeline(step_stats) - f.write(trace.generate_chrome_trace_format( - show_dataflow=self._show_dataflow, - show_memory=self._show_memory)) +ProfilerHook = basic_session_run_hooks.ProfilerHook # pylint: disable=invalid-name diff --git a/tensorflow/contrib/hooks/python/training/profiler_hook_test.py b/tensorflow/contrib/hooks/python/training/profiler_hook_test.py deleted file mode 100644 index e7ecb5eb2fcc56f14f3d5babe2c22652159afd76..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/hooks/python/training/profiler_hook_test.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright 2017 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Tests for profiler_hook.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import os.path -import shutil -import tempfile - -from tensorflow.contrib.framework.python.ops import variables -from tensorflow.contrib.hooks.python.training import ProfilerHook -from tensorflow.python.framework import ops -from tensorflow.python.ops import state_ops -from tensorflow.python.platform import gfile -from tensorflow.python.platform import test -from tensorflow.python.training import monitored_session - - -class ProfilerHookTest(test.TestCase): - - def setUp(self): - super(ProfilerHookTest, self).setUp() - self.output_dir = tempfile.mkdtemp() - self.graph = ops.Graph() - self.filepattern = os.path.join(self.output_dir, "timeline-*.json") - with self.graph.as_default(): - self.global_step = variables.get_or_create_global_step() - self.train_op = state_ops.assign_add(self.global_step, 1) - - def tearDown(self): - super(ProfilerHookTest, self).tearDown() - shutil.rmtree(self.output_dir, ignore_errors=True) - - def _count_timeline_files(self): - return len(gfile.Glob(self.filepattern)) - - def test_raise_in_both_secs_and_steps(self): - with self.assertRaises(ValueError): - ProfilerHook(save_secs=10, save_steps=20) - - def test_raise_in_none_secs_and_steps(self): - with self.assertRaises(ValueError): - ProfilerHook(save_secs=None, save_steps=None) - - def test_save_secs_saves_in_first_step(self): - with self.graph.as_default(): - hook = ProfilerHook(save_secs=2, output_dir=self.output_dir) - with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: - sess.run(self.train_op) - self.assertEqual(1, self._count_timeline_files()) - - @test.mock.patch('time.time') - def test_save_secs_saves_periodically(self, mock_time): - # Pick a fixed start time. - current_time = 1484863632.320497 - - with self.graph.as_default(): - mock_time.return_value = current_time - hook = ProfilerHook(save_secs=2, output_dir=self.output_dir) - with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: - sess.run(self.train_op) # Saved. - self.assertEqual(1, self._count_timeline_files()) - sess.run(self.train_op) # Not saved. - self.assertEqual(1, self._count_timeline_files()) - # Simulate 2.5 seconds of sleep. - mock_time.return_value = current_time + 2.5 - sess.run(self.train_op) # Saved. - - # Pretend some small amount of time has passed. - mock_time.return_value = current_time + 0.1 - sess.run(self.train_op) # Not saved. - # Edge test just before we should save the timeline. - mock_time.return_value = current_time + 1.9 - sess.run(self.train_op) # Not saved. - self.assertEqual(2, self._count_timeline_files()) - - mock_time.return_value = current_time + 4.5 - sess.run(self.train_op) # Saved. - self.assertEqual(3, self._count_timeline_files()) - - def test_save_steps_saves_in_first_step(self): - with self.graph.as_default(): - hook = ProfilerHook(save_secs=2, output_dir=self.output_dir) - with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: - sess.run(self.train_op) # Saved. - sess.run(self.train_op) # Not saved. - self.assertEqual(1, self._count_timeline_files()) - - def test_save_steps_saves_periodically(self): - with self.graph.as_default(): - hook = ProfilerHook(save_steps=2, output_dir=self.output_dir) - with monitored_session.SingularMonitoredSession(hooks=[hook]) as sess: - self.assertEqual(0, self._count_timeline_files()) - sess.run(self.train_op) # Saved. - self.assertEqual(1, self._count_timeline_files()) - sess.run(self.train_op) # Not saved. - self.assertEqual(1, self._count_timeline_files()) - sess.run(self.train_op) # Saved. - self.assertEqual(2, self._count_timeline_files()) - sess.run(self.train_op) # Not saved. - self.assertEqual(2, self._count_timeline_files()) - sess.run(self.train_op) # Saved. - self.assertEqual(3, self._count_timeline_files()) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD b/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD index 8c92e33bdf01a5aec33892fe140da5f762f05679..324035100df366b80f57af9052c4bd935655b248 100644 --- a/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD +++ b/tensorflow/contrib/hvx/clock_cycle_profiling/BUILD @@ -52,13 +52,9 @@ tf_cc_binary( "//tensorflow/core:android_tensorflow_test_lib", ], "//conditions:default": [ - "//tensorflow/core:core_cpu", "//tensorflow/core:lib", - "//tensorflow/core:framework", "//tensorflow/core:framework_internal", - "//tensorflow/core:protos_all_cc", "//tensorflow/core:tensorflow", - "//tensorflow/core:test", ], }), ) diff --git a/tensorflow/contrib/image/BUILD b/tensorflow/contrib/image/BUILD index a18f14112e469b1cf83a046fa65b87e5c69fb88b..157e97d237021d95c935a6be66aa57842b97125c 100755 --- a/tensorflow/contrib/image/BUILD +++ b/tensorflow/contrib/image/BUILD @@ -75,11 +75,13 @@ tf_custom_op_py_library( ":image_ops", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", + "//tensorflow/python:common_shapes", "//tensorflow/python:constant_op", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:util", ], ) @@ -143,12 +145,13 @@ py_library( srcs_version = "PY2AND3", deps = [ ":distort_image_ops", + ":single_image_random_dot_stereograms_py", "//tensorflow/contrib/util:util_py", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:image_ops", "//tensorflow/python:platform", "//tensorflow/python:random_ops", + "//tensorflow/python:util", ], ) @@ -211,6 +214,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":image_py", + ":single_image_random_dot_stereograms_ops", "//tensorflow/contrib/util:util_py", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", diff --git a/tensorflow/contrib/image/__init__.py b/tensorflow/contrib/image/__init__.py index 59a322d3ca6e7e53872f8e7e126e30923ddd77a0..d030dffadeb9d67f7ffcbc197a2a3feb9b3b122d 100755 --- a/tensorflow/contrib/image/__init__.py +++ b/tensorflow/contrib/image/__init__.py @@ -26,6 +26,8 @@ projective transforms (including rotation) are supported. @@random_yiq_hsv @@rotate @@transform +@@translate +@@translations_to_projective_transforms @@bipartite_match @@single_image_random_dot_stereograms """ @@ -41,6 +43,8 @@ from tensorflow.contrib.image.python.ops.image_ops import angles_to_projective_t from tensorflow.contrib.image.python.ops.image_ops import compose_transforms from tensorflow.contrib.image.python.ops.image_ops import rotate from tensorflow.contrib.image.python.ops.image_ops import transform +from tensorflow.contrib.image.python.ops.image_ops import translate +from tensorflow.contrib.image.python.ops.image_ops import translations_to_projective_transforms from tensorflow.contrib.image.python.ops.single_image_random_dot_stereograms import single_image_random_dot_stereograms from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py index b8a0706b61449ebebeb2f1dc98b438f9dd620aa3..b50177ae5651fbc15f292e11031411c2074357ec 100644 --- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py +++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py @@ -36,8 +36,8 @@ _DTYPES = set( class ImageOpsTest(test_util.TensorFlowTestCase): def test_zeros(self): - with self.test_session(): - for dtype in _DTYPES: + for dtype in _DTYPES: + with self.test_session(): for shape in [(5, 5), (24, 24), (2, 24, 24, 3)]: for angle in [0, 1, np.pi / 2.0]: image = array_ops.zeros(shape, dtype) @@ -46,8 +46,8 @@ class ImageOpsTest(test_util.TensorFlowTestCase): np.zeros(shape, dtype.as_numpy_dtype())) def test_rotate_even(self): - with self.test_session(): - for dtype in _DTYPES: + for dtype in _DTYPES: + with self.test_session(): image = array_ops.reshape( math_ops.cast(math_ops.range(36), dtype), (6, 6)) image_rep = array_ops.tile(image[None, :, :, None], [3, 1, 1, 1]) @@ -68,8 +68,8 @@ class ImageOpsTest(test_util.TensorFlowTestCase): [1, 7, 13, 19, 25, 31], [0, 6, 12, 18, 24, 30]]]) def test_rotate_odd(self): - with self.test_session(): - for dtype in _DTYPES: + for dtype in _DTYPES: + with self.test_session(): image = array_ops.reshape( math_ops.cast(math_ops.range(25), dtype), (5, 5)) image_rep = array_ops.tile(image[None, :, :, None], [3, 1, 1, 1]) @@ -87,9 +87,25 @@ class ImageOpsTest(test_util.TensorFlowTestCase): [22, 17, 12, 7, 2], [23, 18, 13, 8, 3], [24, 19, 14, 9, 4]]]) + def test_translate(self): + for dtype in _DTYPES: + with self.test_session(): + image = constant_op.constant( + [[1, 0, 1, 0], + [0, 1, 0, 1], + [1, 0, 1, 0], + [0, 1, 0, 1]], dtype=dtype) + translation = constant_op.constant([-1, -1], dtypes.float32) + image_translated = image_ops.translate(image, translation) + self.assertAllEqual(image_translated.eval(), + [[1, 0, 1, 0], + [0, 1, 0, 0], + [1, 0, 1, 0], + [0, 0, 0, 0]]) + def test_compose(self): - with self.test_session(): - for dtype in _DTYPES: + for dtype in _DTYPES: + with self.test_session(): image = constant_op.constant( [[1, 1, 1, 0], [1, 0, 0, 0], @@ -246,4 +262,3 @@ class BipartiteMatchTest(test_util.TensorFlowTestCase): if __name__ == "__main__": googletest.main() - diff --git a/tensorflow/contrib/image/python/ops/distort_image_ops.py b/tensorflow/contrib/image/python/ops/distort_image_ops.py index 39f023a2b40a1a8481217fe8fa191a5072e7a3ff..06e8e4ee720d04f4b29a25f833297bb17a7d239c 100644 --- a/tensorflow/contrib/image/python/ops/distort_image_ops.py +++ b/tensorflow/contrib/image/python/ops/distort_image_ops.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.image.ops import gen_distort_image_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -132,7 +133,7 @@ def adjust_hsv_in_yiq(image, orig_dtype = image.dtype flt_image = image_ops.convert_image_dtype(image, dtypes.float32) - rgb_altered = _distort_image_ops.adjust_hsv_in_yiq( + rgb_altered = gen_distort_image_ops.adjust_hsv_in_yiq( flt_image, delta_hue, scale_saturation, scale_value) return image_ops.convert_image_dtype(rgb_altered, orig_dtype) diff --git a/tensorflow/contrib/image/python/ops/image_ops.py b/tensorflow/contrib/image/python/ops/image_ops.py index aef3e385b57486d5cb3cb13d9e8b9519768abd7c..011ddeaa9a1eebaa507c9e0d33f9546ff3497166 100644 --- a/tensorflow/contrib/image/python/ops/image_ops.py +++ b/tensorflow/contrib/image/python/ops/image_ops.py @@ -37,16 +37,18 @@ _IMAGE_DTYPES = set( ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn) -def rotate(images, angles, interpolation="NEAREST"): +def rotate(images, angles, interpolation="NEAREST", name=None): """Rotate image(s) by the passed angle(s) in radians. Args: images: A tensor of shape (num_images, num_rows, num_columns, num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or - (num_rows, num_columns) (HW). + (num_rows, num_columns) (HW). The rank must be statically known (the + shape is not `TensorShape(None)`. angles: A scalar angle to rotate all images by, or (if images has rank 4) a vector of length num_images, with an angle for each image in the batch. interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". + name: The name of the op. Returns: Image(s) with the same type and shape as `images`, rotated by the given @@ -55,38 +57,77 @@ def rotate(images, angles, interpolation="NEAREST"): Raises: TypeError: If `image` is an invalid type. """ - image_or_images = ops.convert_to_tensor(images, name="images") - if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: - raise TypeError("Invalid dtype %s." % image_or_images.dtype) - if len(image_or_images.get_shape()) == 2: - images = image_or_images[None, :, :, None] - elif len(image_or_images.get_shape()) == 3: - images = image_or_images[None, :, :, :] - elif len(image_or_images.get_shape()) == 4: - images = image_or_images - else: - raise TypeError("Images should have rank between 2 and 4.") - - image_height = math_ops.cast(array_ops.shape(images)[1], dtypes.float32)[None] - image_width = math_ops.cast(array_ops.shape(images)[2], dtypes.float32)[None] - output = transform( - images, - angles_to_projective_transforms(angles, image_height, image_width), - interpolation=interpolation) - if len(image_or_images.get_shape()) == 2: - return output[0, :, :, 0] - elif len(image_or_images.get_shape()) == 3: - return output[0, :, :, :] - else: - return output + with ops.name_scope(name, "rotate"): + image_or_images = ops.convert_to_tensor(images) + if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: + raise TypeError("Invalid dtype %s." % image_or_images.dtype) + elif image_or_images.get_shape().ndims is None: + raise TypeError("image_or_images rank must be statically known") + elif len(image_or_images.get_shape()) == 2: + images = image_or_images[None, :, :, None] + elif len(image_or_images.get_shape()) == 3: + images = image_or_images[None, :, :, :] + elif len(image_or_images.get_shape()) == 4: + images = image_or_images + else: + raise TypeError("Images should have rank between 2 and 4.") + + image_height = math_ops.cast(array_ops.shape(images)[1], + dtypes.float32)[None] + image_width = math_ops.cast(array_ops.shape(images)[2], + dtypes.float32)[None] + output = transform( + images, + angles_to_projective_transforms(angles, image_height, image_width), + interpolation=interpolation) + if image_or_images.get_shape().ndims is None: + raise TypeError("image_or_images rank must be statically known") + elif len(image_or_images.get_shape()) == 2: + return output[0, :, :, 0] + elif len(image_or_images.get_shape()) == 3: + return output[0, :, :, :] + else: + return output + + +def translate(images, translations, interpolation="NEAREST", name=None): + """Translate image(s) by the passed vectors(s). + Args: + images: A tensor of shape (num_images, num_rows, num_columns, num_channels) + (NHWC), (num_rows, num_columns, num_channels) (HWC), or + (num_rows, num_columns) (HW). The rank must be statically known (the + shape is not `TensorShape(None)`. + translations: A vector representing [dx, dy] or (if images has rank 4) + a matrix of length num_images, with a [dx, dy] vector for each image in + the batch. + interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". + name: The name of the op. -def angles_to_projective_transforms(angles, image_height, image_width): + Returns: + Image(s) with the same type and shape as `images`, translated by the given + vector(s). Empty space due to the translation will be filled with zeros. + + Raises: + TypeError: If `image` is an invalid type. + """ + with ops.name_scope(name, "translate"): + return transform( + images, + translations_to_projective_transforms(translations), + interpolation=interpolation) + + +def angles_to_projective_transforms(angles, + image_height, + image_width, + name=None): """Returns projective transform(s) for the given angle(s). Args: angles: A scalar angle to rotate all images by, or (for batches of images) - a vector with an angle to rotate each image in the batch. + a vector with an angle to rotate each image in the batch. The rank must + be statically known (the shape is not `TensorShape(None)`. image_height: Height of the image(s) to be transformed. image_width: Width of the image(s) to be transformed. @@ -94,41 +135,89 @@ def angles_to_projective_transforms(angles, image_height, image_width): A tensor of shape (num_images, 8). Projective transforms which can be given to `tf.contrib.image.transform`. """ - angle_or_angles = ops.convert_to_tensor( - angles, name="angles", dtype=dtypes.float32) - if len(angle_or_angles.get_shape()) == 0: # pylint: disable=g-explicit-length-test - angles = angle_or_angles[None] - elif len(angle_or_angles.get_shape()) == 1: - angles = angle_or_angles - else: - raise TypeError("Angles should have rank 0 or 1.") - x_offset = ((image_width - 1) - (math_ops.cos(angles) * - (image_width - 1) - math_ops.sin(angles) * - (image_height - 1))) / 2.0 - y_offset = ((image_height - 1) - (math_ops.sin(angles) * - (image_width - 1) + math_ops.cos(angles) * - (image_height - 1))) / 2.0 - num_angles = array_ops.shape(angles)[0] - return array_ops.concat( - values=[ - math_ops.cos(angles)[:, None], - -math_ops.sin(angles)[:, None], - x_offset[:, None], - math_ops.sin(angles)[:, None], - math_ops.cos(angles)[:, None], - y_offset[:, None], - array_ops.zeros((num_angles, 2), dtypes.float32), - ], - axis=1) - - -def transform(images, transforms, interpolation="NEAREST"): + with ops.name_scope(name, "angles_to_projective_transforms"): + angle_or_angles = ops.convert_to_tensor( + angles, name="angles", dtype=dtypes.float32) + if len(angle_or_angles.get_shape()) == 0: # pylint: disable=g-explicit-length-test + angles = angle_or_angles[None] + elif len(angle_or_angles.get_shape()) == 1: + angles = angle_or_angles + else: + raise TypeError("Angles should have rank 0 or 1.") + x_offset = ((image_width - 1) - (math_ops.cos(angles) * + (image_width - 1) - math_ops.sin(angles) * + (image_height - 1))) / 2.0 + y_offset = ((image_height - 1) - (math_ops.sin(angles) * + (image_width - 1) + math_ops.cos(angles) * + (image_height - 1))) / 2.0 + num_angles = array_ops.shape(angles)[0] + return array_ops.concat( + values=[ + math_ops.cos(angles)[:, None], + -math_ops.sin(angles)[:, None], + x_offset[:, None], + math_ops.sin(angles)[:, None], + math_ops.cos(angles)[:, None], + y_offset[:, None], + array_ops.zeros((num_angles, 2), dtypes.float32), + ], + axis=1) + + +def translations_to_projective_transforms(translations, name=None): + """Returns projective transform(s) for the given translation(s). + + Args: + translations: A 2-element list representing [dx, dy] or a matrix of + 2-element lists representing [dx, dy] to translate for each image + (for a batch of images). The rank must be statically known (the shape + is not `TensorShape(None)`. + name: The name of the op. + + Returns: + A tensor of shape (num_images, 8) projective transforms which can be given + to `tf.contrib.image.transform`. + """ + with ops.name_scope(name, "translations_to_projective_transforms"): + translation_or_translations = ops.convert_to_tensor( + translations, name="translations", dtype=dtypes.float32) + if translation_or_translations.get_shape().ndims is None: + raise TypeError( + "translation_or_translations rank must be statically known") + elif len(translation_or_translations.get_shape()) == 1: + translations = translation_or_translations[None] + elif len(translation_or_translations.get_shape()) == 2: + translations = translation_or_translations + else: + raise TypeError("Translations should have rank 1 or 2.") + num_translations = array_ops.shape(translations)[0] + # The translation matrix looks like: + # [[1 0 -dx] + # [0 1 -dy] + # [0 0 1]] + # where the last entry is implicit. + # Translation matrices are always float32. + return array_ops.concat( + values=[ + array_ops.ones((num_translations, 1), dtypes.float32), + array_ops.zeros((num_translations, 1), dtypes.float32), + -translations[:, 0, None], + array_ops.zeros((num_translations, 1), dtypes.float32), + array_ops.ones((num_translations, 1), dtypes.float32), + -translations[:, 1, None], + array_ops.zeros((num_translations, 2), dtypes.float32), + ], + axis=1) + + +def transform(images, transforms, interpolation="NEAREST", name=None): """Applies the given transform(s) to the image(s). Args: images: A tensor of shape (num_images, num_rows, num_columns, num_channels) (NHWC), (num_rows, num_columns, num_channels) (HWC), or - (num_rows, num_columns) (HW). + (num_rows, num_columns) (HW). The rank must be statically known (the + shape is not `TensorShape(None)`. transforms: Projective transform matrix/matrices. A vector of length 8 or tensor of size N x 8. If one row of transforms is [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point @@ -146,34 +235,40 @@ def transform(images, transforms, interpolation="NEAREST"): Raises: TypeError: If `image` is an invalid type. """ - image_or_images = ops.convert_to_tensor(images, name="images") - transform_or_transforms = ops.convert_to_tensor( - transforms, name="transforms", dtype=dtypes.float32) - if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: - raise TypeError("Invalid dtype %s." % image_or_images.dtype) - if len(image_or_images.get_shape()) == 2: - images = image_or_images[None, :, :, None] - elif len(image_or_images.get_shape()) == 3: - images = image_or_images[None, :, :, :] - elif len(image_or_images.get_shape()) == 4: - images = image_or_images - else: - raise TypeError("Images should have rank between 2 and 4.") - - if len(transform_or_transforms.get_shape()) == 1: - transforms = transform_or_transforms[None] - elif len(transform_or_transforms.get_shape()) == 2: - transforms = transform_or_transforms - else: - raise TypeError("Transforms should have rank 1 or 2.") - output = gen_image_ops.image_projective_transform( - images, transforms, interpolation=interpolation.upper()) - if len(image_or_images.get_shape()) == 2: - return output[0, :, :, 0] - elif len(image_or_images.get_shape()) == 3: - return output[0, :, :, :] - else: - return output + with ops.name_scope(name, "transform"): + image_or_images = ops.convert_to_tensor(images, name="images") + transform_or_transforms = ops.convert_to_tensor( + transforms, name="transforms", dtype=dtypes.float32) + if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: + raise TypeError("Invalid dtype %s." % image_or_images.dtype) + elif image_or_images.get_shape().ndims is None: + raise TypeError("image_or_images rank must be statically known") + elif len(image_or_images.get_shape()) == 2: + images = image_or_images[None, :, :, None] + elif len(image_or_images.get_shape()) == 3: + images = image_or_images[None, :, :, :] + elif len(image_or_images.get_shape()) == 4: + images = image_or_images + else: + raise TypeError("Images should have rank between 2 and 4.") + + if len(transform_or_transforms.get_shape()) == 1: + transforms = transform_or_transforms[None] + elif transform_or_transforms.get_shape().ndims is None: + raise TypeError( + "transform_or_transforms rank must be statically known") + elif len(transform_or_transforms.get_shape()) == 2: + transforms = transform_or_transforms + else: + raise TypeError("Transforms should have rank 1 or 2.") + output = gen_image_ops.image_projective_transform( + images, transforms, interpolation=interpolation.upper()) + if len(image_or_images.get_shape()) == 2: + return output[0, :, :, 0] + elif len(image_or_images.get_shape()) == 3: + return output[0, :, :, :] + else: + return output def compose_transforms(*transforms): @@ -191,11 +286,12 @@ def compose_transforms(*transforms): order. """ assert transforms, "transforms cannot be empty" - composed = _flat_transforms_to_matrices(transforms[0]) - for tr in transforms[1:]: - # Multiply batches of matrices. - composed = math_ops.matmul(composed, _flat_transforms_to_matrices(tr)) - return _transform_matrices_to_flat(composed) + with ops.name_scope("compose_transforms"): + composed = _flat_transforms_to_matrices(transforms[0]) + for tr in transforms[1:]: + # Multiply batches of matrices. + composed = math_ops.matmul(composed, _flat_transforms_to_matrices(tr)) + return _transform_matrices_to_flat(composed) def _flat_transforms_to_matrices(transforms): @@ -211,8 +307,8 @@ def _flat_transforms_to_matrices(transforms): def _transform_matrices_to_flat(transform_matrices): # Flatten each matrix. - transforms = array_ops.reshape( - transform_matrices, constant_op.constant([-1, 9])) + transforms = array_ops.reshape(transform_matrices, + constant_op.constant([-1, 9])) # Divide each matrix by the last entry (normally 1). transforms /= transforms[:, 8:9] return transforms[:, :8] @@ -260,10 +356,10 @@ def _image_projective_transform_grad(op, grad): return [output, None] -def bipartite_match( - distance_mat, - num_valid_rows, - top_k=-1): +def bipartite_match(distance_mat, + num_valid_rows, + top_k=-1, + name="bipartite_match"): """Find bipartite matching based on a given distance matrix. A greedy bi-partite matching algorithm is used to obtain the matching with @@ -282,6 +378,7 @@ def bipartite_match( top_k: A scalar that specifies the number of top-k matches to retrieve. If set to be negative, then is set according to the maximum number of matches from `distance_mat`. + name: The name of the op. Returns: row_to_col_match_indices: A vector of length num_rows, which is the number @@ -292,7 +389,8 @@ def bipartite_match( If `col_to_row_match_indices[j]` is not -1, column j is matched to row `col_to_row_match_indices[j]`. """ - result = gen_image_ops.bipartite_match(distance_mat, num_valid_rows, top_k) + result = gen_image_ops.bipartite_match( + distance_mat, num_valid_rows, top_k, name=name) return result diff --git a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py index 79261c5e7501566537ee9492b5aa64570599e862..5cccf26028ca6bf269dbc67a33075351edecb407 100755 --- a/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py +++ b/tensorflow/contrib/image/python/ops/single_image_random_dot_stereograms.py @@ -17,6 +17,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.image.ops import gen_single_image_random_dot_stereograms_ops from tensorflow.contrib.util import loader from tensorflow.python.framework import ops from tensorflow.python.platform import resource_loader @@ -107,7 +108,7 @@ def single_image_random_dot_stereograms( 'depth_values' """ - result = _sirds_ops.single_image_random_dot_stereograms( + result = gen_single_image_random_dot_stereograms_ops.single_image_random_dot_stereograms( # pylint: disable=line-too-long depth_values=depth_values, hidden_surface_removal=hidden_surface_removal, convergence_dots_size=convergence_dots_size, diff --git a/tensorflow/contrib/input_pipeline/BUILD b/tensorflow/contrib/input_pipeline/BUILD index bb7857eb998beb89517985a401d5b7afe483d843..9d6b4d5d87e24d72b29ab33ee805fe0d068cc30a 100644 --- a/tensorflow/contrib/input_pipeline/BUILD +++ b/tensorflow/contrib/input_pipeline/BUILD @@ -67,9 +67,9 @@ tf_custom_op_py_library( "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:state_ops", + "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], diff --git a/tensorflow/contrib/kernel_methods/BUILD b/tensorflow/contrib/kernel_methods/BUILD index ae1402b0e6688a0f43278999d1d93282ea2a11a5..a2f320ab11291e4049c8367e1f133a4fbcb72a62 100644 --- a/tensorflow/contrib/kernel_methods/BUILD +++ b/tensorflow/contrib/kernel_methods/BUILD @@ -64,6 +64,7 @@ py_test( name = "kernel_estimators_test", srcs = ["python/kernel_estimators_test.py"], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":kernel_methods", "//tensorflow/contrib/layers:layers_py", diff --git a/tensorflow/contrib/kfac/BUILD b/tensorflow/contrib/kfac/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..9a5759bf14f753bbc50d3ef8f54ceab7daf745ab --- /dev/null +++ b/tensorflow/contrib/kfac/BUILD @@ -0,0 +1,38 @@ +# Description: +# Contains KfacOptimizer, an implementation of the K-FAC optimization +# algorithm in TensorFlow. +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "kfac", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:curvature_matrix_vector_products_lib", + "//tensorflow/contrib/kfac/python/ops:fisher_blocks_lib", + "//tensorflow/contrib/kfac/python/ops:fisher_estimator_lib", + "//tensorflow/contrib/kfac/python/ops:fisher_factors_lib", + "//tensorflow/contrib/kfac/python/ops:kfac_optimizer_lib", + "//tensorflow/contrib/kfac/python/ops:layer_collection_lib", + "//tensorflow/contrib/kfac/python/ops:loss_functions_lib", + "//tensorflow/contrib/kfac/python/ops:op_queue_lib", + "//tensorflow/contrib/kfac/python/ops:utils_lib", + "//tensorflow/python:util", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/kfac/README.md b/tensorflow/contrib/kfac/README.md new file mode 100644 index 0000000000000000000000000000000000000000..762a2f0b57e95e2fef3dd177070701afb410e93a --- /dev/null +++ b/tensorflow/contrib/kfac/README.md @@ -0,0 +1,89 @@ +# K-FAC: Kronecker-Factored Approximate Curvature + +**K-FAC in TensorFlow** is an implementation of [K-FAC][kfac-paper], an +approximate second-order optimization method, in TensorFlow. When applied to +feedforward and convolutional neural networks, K-FAC can converge `>3.5x` +faster in `>14x` fewer iterations than SGD with Momentum. + +[kfac-paper]: https://arxiv.org/abs/1503.05671 + +## What is K-FAC? + +K-FAC, short for "Kronecker-factored Approximate Curvature", is an approximation +to the [Natural Gradient][natural_gradient] algorithm designed specifically for +neural networks. It maintains a block-diagonal approximation to the [Fisher +Information matrix][fisher_information], whose inverse preconditions the +gradient. + +K-FAC can be used in place of SGD, Adam, and other `Optimizer` implementations. +Experimentally, K-FAC converges `>3.5x` faster than well-tuned SGD. + +Unlike most optimizers, K-FAC exploits structure in the model itself (e.g. "What +are the weights for layer i?"). As such, you must add some additional code while +constructing your model to use K-FAC. + +[natural_gradient]: http://www.mitpressjournals.org/doi/abs/10.1162/089976698300017746 +[fisher_information]: https://en.wikipedia.org/wiki/Fisher_information#Matrix_form + +## Why should I use K-FAC? + +K-FAC can take advantage of the curvature of the optimization problem, resulting +in **faster training**. For an 8-layer Autoencoder, K-FAC converges to the same +loss as SGD with Momentum in 3.8x fewer seconds and 14.7x fewer updates. See how +training loss changes as a function of number of epochs, steps, and seconds: + +![autoencoder](g3doc/autoencoder.png) + +## Is K-FAC for me? + +If you have a feedforward or convolutional model for classification that is +converging too slowly, K-FAC is for you. K-FAC can be used in your model if: + +* Your model defines a posterior distribution. +* Your model uses only fully-connected or convolutional layers (residual + connections OK). +* You are training on CPU or GPU. +* You can modify model code to register layers with K-FAC. + +## How do I use K-FAC? + +Using K-FAC requires three steps: + +1. Registering layer inputs, weights, and pre-activations with a + `LayerCollection`. +1. Minimizing the loss with a `KfacOptimizer`. +1. Keeping K-FAC's preconditioner updated. + +```python +# Build model. +w = tf.get_variable("w", ...) +b = tf.get_variable("b", ...) +logits = tf.matmul(x, w) + b +loss = tf.reduce_mean( + tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits)) + +# Register layers. +layer_collection = LayerCollection() +layer_collection.register_fully_connected((w, b), x, logits) +layer_collection.register_categorical_predictive_distribution(logits) + +# Construct training ops. +optimizer = KfacOptimizer(..., layer_collection=layer_collection) +train_op = optimizer.minimize(loss) + +# Minimize loss. +with tf.Session() as sess: + ... + sess.run([train_op, optimizer.cov_update_op, optimizer.inv_update_op]) +``` + +See [`examples/`](https://www.tensorflow.org/code/tensorflow/contrib/kfac/examples/) for runnable, end-to-end illustrations. + +## Authors + +- Alok Aggarwal +- Daniel Duckworth +- James Martens +- Matthew Johnson +- Olga Wichrowska +- Roger Grosse diff --git a/tensorflow/contrib/kfac/__init__.py b/tensorflow/contrib/kfac/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1ea354e6cdf3e78eaca1f3e5dff174ed489c752e --- /dev/null +++ b/tensorflow/contrib/kfac/__init__.py @@ -0,0 +1,46 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Kronecker-factored Approximate Curvature Optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long +from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products_lib as curvature_matrix_vector_products +from tensorflow.contrib.kfac.python.ops import estimator_lib as estimator +from tensorflow.contrib.kfac.python.ops import fisher_blocks_lib as fisher_blocks +from tensorflow.contrib.kfac.python.ops import fisher_factors_lib as fisher_factors +from tensorflow.contrib.kfac.python.ops import layer_collection_lib as layer_collection +from tensorflow.contrib.kfac.python.ops import loss_functions_lib as loss_functions +from tensorflow.contrib.kfac.python.ops import op_queue_lib as op_queue +from tensorflow.contrib.kfac.python.ops import optimizer_lib as optimizer +from tensorflow.contrib.kfac.python.ops import utils_lib as utils +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long + +_allowed_symbols = [ + "curvature_matrix_vector_products", + "estimator", + "fisher_blocks", + "fisher_factors", + "layer_collection", + "loss_functions", + "op_queue", + "optimizer", + "utils", +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/examples/BUILD b/tensorflow/contrib/kfac/examples/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..89965eda374b2b403f680fc77eb923d0e660d1e2 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/BUILD @@ -0,0 +1,72 @@ +package(default_visibility = [ + "//learning/brain/contrib/kfac/examples:__subpackages__", + "//tensorflow/contrib/kfac/examples:__subpackages__", +]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_binary( + name = "mlp_mnist_main", + srcs = ["mlp_mnist_main.py"], + srcs_version = "PY2AND3", + deps = [ + ":mlp", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "mlp", + srcs = ["mlp.py"], + srcs_version = "PY2AND3", + deps = [ + ":mnist", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "convnet_mnist_main", + srcs = ["convnet_mnist_main.py"], + srcs_version = "PY2AND3", + deps = [ + ":convnet", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "convnet", + srcs = ["convnet.py"], + srcs_version = "PY2AND3", + deps = [ + ":mlp", + ":mnist", + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) + +py_library( + name = "mnist", + srcs = ["mnist.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//third_party/py/numpy", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/kfac/examples/convnet.py b/tensorflow/contrib/kfac/examples/convnet.py new file mode 100644 index 0000000000000000000000000000000000000000..558bc294bc8ac129b3055ed46623c78a0d5a33e3 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/convnet.py @@ -0,0 +1,457 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train a ConvNet on MNIST using K-FAC. + +This library fits a 5-layer ConvNet on MNIST using K-FAC. The model has the +following structure, + +- Conv Layer: 5x5 kernel, 16 output channels. +- Max Pool: 3x3 kernel, stride 2. +- Conv Layer: 5x5 kernel, 16 output channels. +- Max Pool: 3x3 kernel, stride 2. +- Linear: 10 output dims. + +After 3k~6k steps, this should reach perfect accuracy on the training set. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import mlp +from tensorflow.contrib.kfac.examples import mnist + +lc = tf.contrib.kfac.layer_collection +oq = tf.contrib.kfac.op_queue +opt = tf.contrib.kfac.optimizer + +__all__ = [ + "conv_layer", + "max_pool_layer", + "linear_layer", + "build_model", + "minimize_loss_single_machine", + "minimize_loss_distributed", + "train_mnist_single_machine", + "train_mnist_distributed", +] + + +def conv_layer(layer_id, inputs, kernel_size, out_channels): + """Builds a convolutional layer with ReLU non-linearity. + + Args: + layer_id: int. Integer ID for this layer's variables. + inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row + corresponds to a single example. + kernel_size: int. Width and height of the convolution kernel. The kernel is + assumed to be square. + out_channels: int. Number of output features per pixel. + + Returns: + preactivations: Tensor of shape [num_examples, width, height, out_channels]. + Values of the layer immediately before the activation function. + activations: Tensor of shape [num_examples, width, height, out_channels]. + Values of the layer immediately after the activation function. + params: Tuple of (kernel, bias), parameters for this layer. + """ + # TODO(b/67004004): Delete this function and rely on tf.layers exclusively. + layer = tf.layers.Conv2D( + out_channels, + kernel_size=[kernel_size, kernel_size], + kernel_initializer=tf.random_normal_initializer(stddev=0.01), + padding="SAME", + name="conv_%d" % layer_id) + preactivations = layer(inputs) + activations = tf.nn.relu(preactivations) + + # layer.weights is a list. This converts it a (hashable) tuple. + return preactivations, activations, (layer.kernel, layer.bias) + + +def max_pool_layer(layer_id, inputs, kernel_size, stride): + """Build a max-pooling layer. + + Args: + layer_id: int. Integer ID for this layer's variables. + inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row + corresponds to a single example. + kernel_size: int. Width and height to pool over per input channel. The + kernel is assumed to be square. + stride: int. Step size between pooling operations. + + Returns: + Tensor of shape [num_examples, width/stride, height/stride, out_channels]. + Result of applying max pooling to 'inputs'. + """ + # TODO(b/67004004): Delete this function and rely on tf.layers exclusively. + with tf.variable_scope("pool_%d" % layer_id): + return tf.nn.max_pool( + inputs, [1, kernel_size, kernel_size, 1], [1, stride, stride, 1], + padding="SAME", + name="pool") + + +def linear_layer(layer_id, inputs, output_size): + """Builds the final linear layer for an MNIST classification problem. + + Args: + layer_id: int. Integer ID for this layer's variables. + inputs: Tensor of shape [num_examples, width, height, in_channels]. Each row + corresponds to a single example. + output_size: int. Number of output dims per example. + + Returns: + activations: Tensor of shape [num_examples, output_size]. Values of the + layer immediately after the activation function. + params: Tuple of (weights, bias), parameters for this layer. + """ + # TODO(b/67004004): Delete this function and rely on tf.layers exclusively. + pre, _, params = mlp.fc_layer(layer_id, inputs, output_size) + return pre, params + + +def build_model(examples, labels, num_labels, layer_collection): + """Builds a ConvNet classification model. + + Args: + examples: Tensor of shape [num_examples, num_features]. Represents inputs of + model. + labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted + by softmax for each example. + num_labels: int. Number of distinct values 'labels' can take on. + layer_collection: LayerCollection instance. Layers will be registered here. + + Returns: + loss: 0-D Tensor representing loss to be minimized. + accuracy: 0-D Tensor representing model's accuracy. + """ + # Build a ConvNet. For each layer with parameters, we'll keep track of the + # preactivations, activations, weights, and bias. + tf.logging.info("Building model.") + pre0, act0, params0 = conv_layer( + layer_id=0, inputs=examples, kernel_size=5, out_channels=16) + act1 = max_pool_layer(layer_id=1, inputs=act0, kernel_size=3, stride=2) + pre2, act2, params2 = conv_layer( + layer_id=2, inputs=act1, kernel_size=5, out_channels=16) + act3 = max_pool_layer(layer_id=3, inputs=act2, kernel_size=3, stride=2) + flat_act3 = tf.reshape(act3, shape=[-1, int(np.prod(act3.shape[1:4]))]) + logits, params4 = linear_layer( + layer_id=4, inputs=flat_act3, output_size=num_labels) + loss = tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=labels, logits=logits)) + accuracy = tf.reduce_mean( + tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32)) + + tf.summary.scalar("loss", loss) + tf.summary.scalar("accuracy", accuracy) + + # Register parameters. K-FAC needs to know about the inputs, outputs, and + # parameters of each conv/fully connected layer and the logits powering the + # posterior probability over classes. + tf.logging.info("Building LayerCollection.") + layer_collection.register_conv2d(params0, (1, 1, 1, 1), "SAME", examples, + pre0) + layer_collection.register_conv2d(params2, (1, 1, 1, 1), "SAME", act1, pre2) + layer_collection.register_fully_connected(params4, flat_act3, logits) + layer_collection.register_categorical_predictive_distribution( + logits, name="logits") + + return loss, accuracy + + +def minimize_loss_single_machine(loss, + accuracy, + layer_collection, + session_config=None): + """Minimize loss with K-FAC on a single machine. + + A single Session is responsible for running all of K-FAC's ops. + + Args: + loss: 0-D Tensor. Loss to be minimized. + accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. + layer_collection: LayerCollection instance describing model architecture. + Used by K-FAC to construct preconditioner. + session_config: None or tf.ConfigProto. Configuration for tf.Session(). + + Returns: + final value for 'accuracy'. + """ + # Train with K-FAC. + global_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=0.0001, + cov_ema_decay=0.95, + damping=0.001, + layer_collection=layer_collection, + momentum=0.9) + train_op = optimizer.minimize(loss, global_step=global_step) + + tf.logging.info("Starting training.") + with tf.train.MonitoredTrainingSession(config=session_config) as sess: + while not sess.should_stop(): + global_step_, loss_, accuracy_, _, _ = sess.run( + [global_step, loss, accuracy, train_op, optimizer.cov_update_op]) + + if global_step_ % 100 == 0: + sess.run(optimizer.inv_update_op) + + if global_step_ % 100 == 0: + tf.logging.info("global_step: %d | loss: %f | accuracy: %s", + global_step_, loss_, accuracy_) + + return accuracy_ + + +def _is_gradient_task(task_id, num_tasks): + """Returns True if this task should update the weights.""" + if num_tasks < 3: + return True + return 0 <= task_id < 0.6 * num_tasks + + +def _is_cov_update_task(task_id, num_tasks): + """Returns True if this task should update K-FAC's covariance matrices.""" + if num_tasks < 3: + return False + return 0.6 * num_tasks <= task_id < num_tasks - 1 + + +def _is_inv_update_task(task_id, num_tasks): + """Returns True if this task should update K-FAC's preconditioner.""" + if num_tasks < 3: + return False + return task_id == num_tasks - 1 + + +def _num_gradient_tasks(num_tasks): + """Number of tasks that will update weights.""" + if num_tasks < 3: + return num_tasks + return int(np.ceil(0.6 * num_tasks)) + + +def minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, master, + checkpoint_dir, loss, accuracy, layer_collection): + """Minimize loss with an synchronous implementation of K-FAC. + + Different tasks are responsible for different parts of K-FAC's Ops. The first + 60% of tasks update weights; the next 20% accumulate covariance statistics; + the last 20% invert the matrices used to precondition gradients. + + Args: + task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + num_worker_tasks: int. Number of workers in this distributed training setup. + num_ps_tasks: int. Number of parameter servers holding variables. If 0, + parameter servers are not used. + master: string. IP and port of TensorFlow runtime process. Set to empty + string to run locally. + checkpoint_dir: string or None. Path to store checkpoints under. + loss: 0-D Tensor. Loss to be minimized. + accuracy: dict mapping strings to 0-D Tensors. Additional accuracy to + run with each step. + layer_collection: LayerCollection instance describing model architecture. + Used by K-FAC to construct preconditioner. + + Returns: + final value for 'accuracy'. + + Raises: + ValueError: if task_id >= num_worker_tasks. + """ + with tf.device(tf.train.replica_device_setter(num_ps_tasks)): + global_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=0.0001, + cov_ema_decay=0.95, + damping=0.001, + layer_collection=layer_collection, + momentum=0.9) + inv_update_queue = oq.OpQueue(optimizer.inv_updates_dict.values()) + sync_optimizer = tf.train.SyncReplicasOptimizer( + opt=optimizer, + replicas_to_aggregate=_num_gradient_tasks(num_worker_tasks)) + train_op = sync_optimizer.minimize(loss, global_step=global_step) + + tf.logging.info("Starting training.") + is_chief = (task_id == 0) + hooks = [sync_optimizer.make_session_run_hook(is_chief)] + with tf.train.MonitoredTrainingSession( + master=master, + is_chief=is_chief, + checkpoint_dir=checkpoint_dir, + hooks=hooks, + stop_grace_period_secs=0) as sess: + while not sess.should_stop(): + # Choose which op this task is responsible for running. + if _is_gradient_task(task_id, num_worker_tasks): + learning_op = train_op + elif _is_cov_update_task(task_id, num_worker_tasks): + learning_op = optimizer.cov_update_op + elif _is_inv_update_task(task_id, num_worker_tasks): + # TODO(duckworthd): Running this op before cov_update_op has been run a + # few times can result in "InvalidArgumentError: Cholesky decomposition + # was not successful." Delay running this op until cov_update_op has + # been run a few times. + learning_op = inv_update_queue.next_op(sess) + else: + raise ValueError("Which op should task %d do?" % task_id) + + global_step_, loss_, accuracy_, _ = sess.run( + [global_step, loss, accuracy, learning_op]) + tf.logging.info("global_step: %d | loss: %f | accuracy: %s", global_step_, + loss_, accuracy_) + + return accuracy_ + + +def train_mnist_single_machine(data_dir, num_epochs, use_fake_data=False): + """Train a ConvNet on MNIST. + + Args: + data_dir: string. Directory to read MNIST examples from. + num_epochs: int. Number of passes to make over the training set. + use_fake_data: bool. If True, generate a synthetic dataset. + + Returns: + accuracy of model on the final minibatch of training data. + """ + # Load a dataset. + tf.logging.info("Loading MNIST into memory.") + examples, labels = mnist.load_mnist( + data_dir, + num_epochs=num_epochs, + batch_size=128, + use_fake_data=use_fake_data, + flatten_images=False) + + # Build a ConvNet. + layer_collection = lc.LayerCollection() + loss, accuracy = build_model( + examples, labels, num_labels=10, layer_collection=layer_collection) + + # Fit model. + return minimize_loss_single_machine(loss, accuracy, layer_collection) + + +def train_mnist_multitower(data_dir, num_epochs, num_towers, + use_fake_data=True): + """Train a ConvNet on MNIST. + + Args: + data_dir: string. Directory to read MNIST examples from. + num_epochs: int. Number of passes to make over the training set. + num_towers: int. Number of CPUs to split inference across. + use_fake_data: bool. If True, generate a synthetic dataset. + + Returns: + accuracy of model on the final minibatch of training data. + """ + # Load a dataset. + tf.logging.info("Loading MNIST into memory.") + tower_batch_size = 128 + batch_size = tower_batch_size * num_towers + tf.logging.info( + ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d " + "tower batch size.") % (batch_size, num_towers, tower_batch_size)) + examples, labels = mnist.load_mnist( + data_dir, + num_epochs=num_epochs, + batch_size=batch_size, + use_fake_data=use_fake_data, + flatten_images=False) + + # Split minibatch across towers. + examples = tf.split(examples, num_towers) + labels = tf.split(labels, num_towers) + + # Build an MLP. Each tower's layers will be added to the LayerCollection. + layer_collection = lc.LayerCollection() + tower_results = [] + for tower_id in range(num_towers): + with tf.device("/cpu:%d" % tower_id): + with tf.name_scope("tower%d" % tower_id): + with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)): + tf.logging.info("Building tower %d." % tower_id) + tower_results.append( + build_model(examples[tower_id], labels[tower_id], 10, + layer_collection)) + losses, accuracies = zip(*tower_results) + + # Average across towers. + loss = tf.reduce_mean(losses) + accuracy = tf.reduce_mean(accuracies) + + # Fit model. + session_config = tf.ConfigProto( + allow_soft_placement=False, device_count={ + "CPU": num_towers + }) + return minimize_loss_single_machine( + loss, accuracy, layer_collection, session_config=session_config) + + +def train_mnist_distributed(task_id, + num_worker_tasks, + num_ps_tasks, + master, + data_dir, + num_epochs, + use_fake_data=False): + """Train a ConvNet on MNIST. + + Args: + task_id: int. Integer in [0, num_worker_tasks). ID for this worker. + num_worker_tasks: int. Number of workers in this distributed training setup. + num_ps_tasks: int. Number of parameter servers holding variables. + master: string. IP and port of TensorFlow runtime process. + data_dir: string. Directory to read MNIST examples from. + num_epochs: int. Number of passes to make over the training set. + use_fake_data: bool. If True, generate a synthetic dataset. + + Returns: + accuracy of model on the final minibatch of training data. + """ + # Load a dataset. + tf.logging.info("Loading MNIST into memory.") + examples, labels = mnist.load_mnist( + data_dir, + num_epochs=num_epochs, + batch_size=128, + use_fake_data=use_fake_data, + flatten_images=False) + + # Build a ConvNet. + layer_collection = lc.LayerCollection() + with tf.device(tf.train.replica_device_setter(num_ps_tasks)): + loss, accuracy = build_model( + examples, labels, num_labels=10, layer_collection=layer_collection) + + # Fit model. + checkpoint_dir = None if data_dir is None else os.path.join(data_dir, "kfac") + return minimize_loss_distributed(task_id, num_worker_tasks, num_ps_tasks, + master, checkpoint_dir, loss, accuracy, + layer_collection) + + +if __name__ == "__main__": + tf.app.run() diff --git a/tensorflow/contrib/kfac/examples/convnet_mnist_main.py b/tensorflow/contrib/kfac/examples/convnet_mnist_main.py new file mode 100644 index 0000000000000000000000000000000000000000..b0c6fbde198850c76af0bc1600dc23e926227229 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/convnet_mnist_main.py @@ -0,0 +1,57 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train a ConvNet on MNIST using K-FAC. + +See convnet.py for details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import convnet + +FLAGS = None + + +def main(argv): + _ = argv + + if FLAGS.num_towers > 1: + convnet.train_mnist_multitower( + FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers) + else: + convnet.train_mnist_single_machine(FLAGS.data_dir, num_epochs=200) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_dir", + type=str, + default="/tmp/mnist", + help="Directory to store dataset in.") + parser.add_argument( + "--num_towers", + type=int, + default=1, + help="Number of CPUs to split minibatch across.") + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/kfac/examples/mlp.py b/tensorflow/contrib/kfac/examples/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..4275ceadc210ff471109b596e1c9aa260ce31ab5 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/mlp.py @@ -0,0 +1,241 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train an MLP on MNIST using K-FAC. + +This library fits a 3-layer, tanh-activated MLP on MNIST using K-FAC. After +~25k steps, this should reach perfect accuracy on the training set. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import mnist + +lc = tf.contrib.kfac.layer_collection +opt = tf.contrib.kfac.optimizer + +__all__ = [ + "fc_layer", + "train_mnist", + "train_mnist_multitower", +] + + +def fc_layer(layer_id, inputs, output_size): + """Builds a fully connected layer. + + Args: + layer_id: int. Integer ID for this layer's variables. + inputs: Tensor of shape [num_examples, input_size]. Each row corresponds + to a single example. + output_size: int. Number of output dimensions after fully connected layer. + + Returns: + preactivations: Tensor of shape [num_examples, output_size]. Values of the + layer immediately before the activation function. + activations: Tensor of shape [num_examples, output_size]. Values of the + layer immediately after the activation function. + params: Tuple of (weights, bias), parameters for this layer. + """ + # TODO(b/67004004): Delete this function and rely on tf.layers exclusively. + layer = tf.layers.Dense( + output_size, + kernel_initializer=tf.random_normal_initializer(), + name="fc_%d" % layer_id) + preactivations = layer(inputs) + activations = tf.nn.tanh(preactivations) + + # layer.weights is a list. This converts it a (hashable) tuple. + return preactivations, activations, (layer.kernel, layer.bias) + + +def build_model(examples, labels, num_labels, layer_collection): + """Builds an MLP classification model. + + Args: + examples: Tensor of shape [num_examples, num_features]. Represents inputs of + model. + labels: Tensor of shape [num_examples]. Contains integer IDs to be predicted + by softmax for each example. + num_labels: int. Number of distinct values 'labels' can take on. + layer_collection: LayerCollection instance describing model architecture. + + Returns: + loss: 0-D Tensor representing loss to be minimized. + accuracy: 0-D Tensor representing model's accuracy. + """ + # Build an MLP. For each layer, we'll keep track of the preactivations, + # activations, weights, and bias. + pre0, act0, params0 = fc_layer(layer_id=0, inputs=examples, output_size=128) + pre1, act1, params1 = fc_layer(layer_id=1, inputs=act0, output_size=64) + pre2, act2, params2 = fc_layer(layer_id=2, inputs=act1, output_size=32) + logits, _, params3 = fc_layer(layer_id=3, inputs=act2, output_size=num_labels) + loss = tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=labels, logits=logits)) + accuracy = tf.reduce_mean( + tf.cast(tf.equal(labels, tf.argmax(logits, axis=1)), dtype=tf.float32)) + + # Register parameters. K-FAC needs to know about the inputs, outputs, and + # parameters of each layer and the logits powering the posterior probability + # over classes. + tf.logging.info("Building LayerCollection.") + layer_collection.register_fully_connected(params0, examples, pre0) + layer_collection.register_fully_connected(params1, act0, pre1) + layer_collection.register_fully_connected(params2, act1, pre2) + layer_collection.register_fully_connected(params3, act2, logits) + layer_collection.register_categorical_predictive_distribution( + logits, name="logits") + + return loss, accuracy + + +def minimize(loss, accuracy, layer_collection, session_config=None): + """Minimize 'loss' with KfacOptimizer. + + Args: + loss: 0-D Tensor. Loss to be minimized. + accuracy: 0-D Tensor. Accuracy of classifier on current minibatch. + layer_collection: LayerCollection instance. Describes layers in model. + session_config: tf.ConfigProto. Configuration for tf.Session(). + + Returns: + accuracy of classifier on final minibatch. + """ + # Train with K-FAC. We'll use a decreasing learning rate that's cut in 1/2 + # every 10k iterations. + tf.logging.info("Building KFAC Optimizer.") + global_step = tf.train.get_or_create_global_step() + optimizer = opt.KfacOptimizer( + learning_rate=tf.train.exponential_decay( + 0.00002, global_step, 10000, 0.5, staircase=True), + cov_ema_decay=0.95, + damping=0.0001, + layer_collection=layer_collection, + momentum=0.99) + train_op = optimizer.minimize(loss, global_step=global_step) + + tf.logging.info("Starting training.") + with tf.train.MonitoredTrainingSession(config=session_config) as sess: + while not sess.should_stop(): + # K-FAC has 3 primary ops, + # - train_op: Update the weights with the minibatch's gradient. + # - cov_update_op: Update statistics used for building K-FAC's + # preconditioner matrix. + # - inv_update_op: Update preconditioner matrix using statistics. + # + # The first 2 of these are cheap and should be done with each step. The + # latter is more expensive, and should be updated ~100 iterations. + global_step_, loss_, accuracy_, _, _ = sess.run( + [global_step, loss, accuracy, train_op, optimizer.cov_update_op]) + + if global_step_ % 100 == 0: + sess.run(optimizer.inv_update_op) + + if global_step_ % 100 == 0: + tf.logging.info("global_step: %d | loss: %f | accuracy: %f", + global_step_, loss_, accuracy_) + + return accuracy_ + + +def train_mnist(data_dir, num_epochs, use_fake_data=False): + """Train an MLP on MNIST. + + Args: + data_dir: string. Directory to read MNIST examples from. + num_epochs: int. Number of passes to make over the training set. + use_fake_data: bool. If True, generate a synthetic dataset. + + Returns: + accuracy of model on the final minibatch of training data. + """ + # Load a dataset. + tf.logging.info("Loading MNIST into memory.") + examples, labels = mnist.load_mnist( + data_dir, + num_epochs=num_epochs, + batch_size=64, + flatten_images=True, + use_fake_data=use_fake_data) + + # Build an MLP. The model's layers will be added to the LayerCollection. + tf.logging.info("Building model.") + layer_collection = lc.LayerCollection() + loss, accuracy = build_model(examples, labels, 10, layer_collection) + + # Fit model. + minimize(loss, accuracy, layer_collection) + + +def train_mnist_multitower(data_dir, + num_epochs, + num_towers, + use_fake_data=False): + """Train an MLP on MNIST, splitting the minibatch across multiple towers. + + Args: + data_dir: string. Directory to read MNIST examples from. + num_epochs: int. Number of passes to make over the training set. + num_towers: int. Number of CPUs to split minibatch across. + use_fake_data: bool. If True, generate a synthetic dataset. + + Returns: + accuracy of model on the final minibatch of training data. + """ + # Load a dataset. + tower_batch_size = 64 + batch_size = tower_batch_size * num_towers + tf.logging.info( + ("Loading MNIST into memory. Using batch_size = %d = %d towers * %d " + "tower batch size.") % (batch_size, num_towers, tower_batch_size)) + examples, labels = mnist.load_mnist( + data_dir, + num_epochs=num_epochs, + batch_size=batch_size, + flatten_images=True, + use_fake_data=use_fake_data) + + # Split minibatch across towers. + examples = tf.split(examples, num_towers) + labels = tf.split(labels, num_towers) + + # Build an MLP. Each tower's layers will be added to the LayerCollection. + layer_collection = lc.LayerCollection() + tower_results = [] + for tower_id in range(num_towers): + with tf.device("/cpu:%d" % tower_id): + with tf.name_scope("tower%d" % tower_id): + with tf.variable_scope(tf.get_variable_scope(), reuse=(tower_id > 0)): + tf.logging.info("Building tower %d." % tower_id) + tower_results.append( + build_model(examples[tower_id], labels[tower_id], 10, + layer_collection)) + losses, accuracies = zip(*tower_results) + + # Average across towers. + loss = tf.reduce_mean(losses) + accuracy = tf.reduce_mean(accuracies) + + # Fit model. + session_config = tf.ConfigProto( + allow_soft_placement=False, device_count={ + "CPU": num_towers + }) + return minimize( + loss, accuracy, layer_collection, session_config=session_config) diff --git a/tensorflow/contrib/kfac/examples/mlp_mnist_main.py b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py new file mode 100644 index 0000000000000000000000000000000000000000..b318c71a568be2d717745579df24134ceb3b6a0b --- /dev/null +++ b/tensorflow/contrib/kfac/examples/mlp_mnist_main.py @@ -0,0 +1,56 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +r"""Train an MLP on MNIST using K-FAC. + +See mlp.py for details. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import sys + +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import mlp + +FLAGS = None + + +def main(argv): + _ = argv + if FLAGS.num_towers > 1: + mlp.train_mnist_multitower( + FLAGS.data_dir, num_epochs=200, num_towers=FLAGS.num_towers) + else: + mlp.train_mnist(FLAGS.data_dir, num_epochs=200) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--data_dir", + type=str, + default="/tmp/mnist", + help="Directory to store dataset in.") + parser.add_argument( + "--num_towers", + type=int, + default=1, + help="Number of CPUs to split minibatch across.") + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/kfac/examples/mnist.py b/tensorflow/contrib/kfac/examples/mnist.py new file mode 100644 index 0000000000000000000000000000000000000000..cf92c909f4b5201bc0ffda5703136f46c7058ec6 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/mnist.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities for loading MNIST into TensorFlow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +__all__ = [ + 'load_mnist', +] + + +def load_mnist(data_dir, + num_epochs, + batch_size, + flatten_images=True, + use_fake_data=False): + """Loads MNIST dataset into memory. + + Args: + data_dir: string. Directory to read MNIST examples from. + num_epochs: int. Number of passes to make over the dataset. + batch_size: int. Number of examples per minibatch. + flatten_images: bool. If True, [28, 28, 1]-shaped images are flattened into + [784]-shaped vectors. + use_fake_data: bool. If True, generate a synthetic dataset rather than + reading MNIST in. + + Returns: + examples: Tensor of shape [batch_size, 784] if 'flatten_images' is + True, else [batch_size, 28, 28, 1]. Each row is one example. + Values in [0, 1]. + labels: Tensor of shape [batch_size]. Indices of integer corresponding to + each example. Values in {0...9}. + """ + if use_fake_data: + rng = np.random.RandomState(42) + num_examples = batch_size * 4 + images = rng.rand(num_examples, 28 * 28) + if not flatten_images: + images = np.reshape(images, [num_examples, 28, 28, 1]) + labels = rng.randint(10, size=num_examples) + else: + mnist_data = tf.contrib.learn.datasets.mnist.read_data_sets( + data_dir, reshape=flatten_images) + num_examples = len(mnist_data.train.labels) + images = mnist_data.train.images + labels = mnist_data.train.labels + + dataset = tf.contrib.data.Dataset.from_tensor_slices((np.asarray( + images, dtype=np.float32), np.asarray(labels, dtype=np.int64))) + return (dataset.repeat(num_epochs).shuffle(num_examples).batch(batch_size) + .make_one_shot_iterator().get_next()) diff --git a/tensorflow/contrib/kfac/examples/tests/BUILD b/tensorflow/contrib/kfac/examples/tests/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..ce7da95c124beaed4773d68ce0d0c41f187f7c9d --- /dev/null +++ b/tensorflow/contrib/kfac/examples/tests/BUILD @@ -0,0 +1,64 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_test( + name = "mlp_test", + size = "large", + srcs = ["mlp_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/kfac/examples:mlp", + "//third_party/py/numpy", + ], +) + +py_test( + name = "convnet_test", + size = "large", + srcs = ["convnet_test.py"], + srcs_version = "PY2AND3", + tags = [ + "no_pip", + "notsan", + ], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/kfac", + "//tensorflow/contrib/kfac/examples:convnet", + "//third_party/py/numpy", + ], +) + +py_test( + name = "mnist_test", + srcs = ["mnist_test.py"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/kfac/examples:mnist", + "//third_party/py/numpy", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/kfac/examples/tests/convnet_test.py b/tensorflow/contrib/kfac/examples/tests/convnet_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3c98c54ef6cbd527aa0035e0b6f40be961c6308d --- /dev/null +++ b/tensorflow/contrib/kfac/examples/tests/convnet_test.py @@ -0,0 +1,163 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for convnet.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.kfac import layer_collection as lc +from tensorflow.contrib.kfac.examples import convnet + + +class ConvNetTest(tf.test.TestCase): + + def testConvLayer(self): + with tf.Graph().as_default(): + pre, act, (w, b) = convnet.conv_layer( + layer_id=1, + inputs=tf.zeros([5, 3, 3, 2]), + kernel_size=3, + out_channels=5) + self.assertShapeEqual(np.zeros([5, 3, 3, 5]), pre) + self.assertShapeEqual(np.zeros([5, 3, 3, 5]), act) + self.assertShapeEqual(np.zeros([3, 3, 2, 5]), tf.convert_to_tensor(w)) + self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b)) + self.assertIsInstance(w, tf.Variable) + self.assertIsInstance(b, tf.Variable) + self.assertIn("conv_1", w.op.name) + self.assertIn("conv_1", b.op.name) + + def testMaxPoolLayer(self): + with tf.Graph().as_default(): + act = convnet.max_pool_layer( + layer_id=1, inputs=tf.zeros([5, 6, 6, 2]), kernel_size=5, stride=3) + self.assertShapeEqual(np.zeros([5, 2, 2, 2]), act) + self.assertEqual(act.op.name, "pool_1/pool") + + def testLinearLayer(self): + with tf.Graph().as_default(): + act, (w, b) = convnet.linear_layer( + layer_id=1, inputs=tf.zeros([5, 20]), output_size=5) + self.assertShapeEqual(np.zeros([5, 5]), act) + self.assertShapeEqual(np.zeros([20, 5]), tf.convert_to_tensor(w)) + self.assertShapeEqual(np.zeros([5]), tf.convert_to_tensor(b)) + self.assertIsInstance(w, tf.Variable) + self.assertIsInstance(b, tf.Variable) + self.assertIn("fc_1", w.op.name) + self.assertIn("fc_1", b.op.name) + + def testBuildModel(self): + with tf.Graph().as_default(): + x = tf.placeholder(tf.float32, [None, 6, 6, 3]) + y = tf.placeholder(tf.int64, [None]) + layer_collection = lc.LayerCollection() + loss, accuracy = convnet.build_model( + x, y, num_labels=5, layer_collection=layer_collection) + + # Ensure layers and logits were registered. + self.assertEqual(len(layer_collection.fisher_blocks), 3) + self.assertEqual(len(layer_collection.losses), 1) + + # Ensure inference doesn't crash. + with self.test_session() as sess: + sess.run(tf.global_variables_initializer()) + feed_dict = { + x: np.random.randn(10, 6, 6, 3).astype(np.float32), + y: np.random.randint(5, size=10).astype(np.int64), + } + sess.run([loss, accuracy], feed_dict=feed_dict) + + def _build_toy_problem(self): + """Construct a toy linear regression problem. + + Initial loss should be, + 2.5 = 0.5 * (1^2 + 2^2) + + Returns: + loss: 0-D Tensor representing loss to be minimized. + accuracy: 0-D Tensors representing model accuracy. + layer_collection: LayerCollection instance describing model architecture. + """ + x = np.asarray([[1.], [2.]]).astype(np.float32) + y = np.asarray([1., 2.]).astype(np.float32) + x, y = (tf.contrib.data.Dataset.from_tensor_slices((x, y)) + .repeat(100).batch(2).make_one_shot_iterator().get_next()) + w = tf.get_variable("w", shape=[1, 1], initializer=tf.zeros_initializer()) + y_hat = tf.matmul(x, w) + loss = tf.reduce_mean(0.5 * tf.square(y_hat - y)) + accuracy = loss + + layer_collection = lc.LayerCollection() + layer_collection.register_fully_connected(params=w, inputs=x, outputs=y_hat) + layer_collection.register_normal_predictive_distribution(y_hat) + + return loss, accuracy, layer_collection + + def testMinimizeLossSingleMachine(self): + with tf.Graph().as_default(): + loss, accuracy, layer_collection = self._build_toy_problem() + accuracy_ = convnet.minimize_loss_single_machine(loss, accuracy, + layer_collection) + self.assertLess(accuracy_, 1.0) + + def testMinimizeLossDistributed(self): + with tf.Graph().as_default(): + loss, accuracy, layer_collection = self._build_toy_problem() + accuracy_ = convnet.minimize_loss_distributed( + task_id=0, + num_worker_tasks=1, + num_ps_tasks=0, + master="", + checkpoint_dir=None, + loss=loss, + accuracy=accuracy, + layer_collection=layer_collection) + self.assertLess(accuracy_, 1.0) + + def testTrainMnistSingleMachine(self): + with tf.Graph().as_default(): + # Ensure model training doesn't crash. + # + # Ideally, we should check that accuracy increases as the model converges, + # but there are too few parameters for the model to effectively memorize + # the training set the way an MLP can. + convnet.train_mnist_single_machine( + data_dir=None, num_epochs=1, use_fake_data=True) + + def testTrainMnistMultitower(self): + with tf.Graph().as_default(): + # Ensure model training doesn't crash. + convnet.train_mnist_multitower( + data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True) + + def testTrainMnistDistributed(self): + with tf.Graph().as_default(): + # Ensure model training doesn't crash. + convnet.train_mnist_distributed( + task_id=0, + num_worker_tasks=1, + num_ps_tasks=0, + master="", + data_dir=None, + num_epochs=1, + use_fake_data=True) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/kfac/examples/tests/mlp_test.py b/tensorflow/contrib/kfac/examples/tests/mlp_test.py new file mode 100644 index 0000000000000000000000000000000000000000..34a942d27f64e2583c686c2ba3240bc636ed918b --- /dev/null +++ b/tensorflow/contrib/kfac/examples/tests/mlp_test.py @@ -0,0 +1,58 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for mlp.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import mlp + + +class MlpTest(tf.test.TestCase): + + def testFcLayer(self): + with tf.Graph().as_default(): + pre, act, (w, b) = mlp.fc_layer( + layer_id=1, inputs=tf.zeros([5, 3]), output_size=10) + self.assertShapeEqual(np.zeros([5, 10]), pre) + self.assertShapeEqual(np.zeros([5, 10]), act) + self.assertShapeEqual(np.zeros([3, 10]), tf.convert_to_tensor(w)) + self.assertShapeEqual(np.zeros([10]), tf.convert_to_tensor(b)) + self.assertIsInstance(w, tf.Variable) + self.assertIsInstance(b, tf.Variable) + self.assertIn("fc_1/", w.op.name) + self.assertIn("fc_1/", b.op.name) + + def testTrainMnist(self): + with tf.Graph().as_default(): + # Ensure model training doesn't crash. + # + # Ideally, we should check that accuracy increases as the model converges, + # but that takes a non-trivial amount of compute. + mlp.train_mnist(data_dir=None, num_epochs=1, use_fake_data=True) + + def testTrainMnistMultitower(self): + with tf.Graph().as_default(): + # Ensure model training doesn't crash. + mlp.train_mnist_multitower( + data_dir=None, num_epochs=1, num_towers=2, use_fake_data=True) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/kfac/examples/tests/mnist_test.py b/tensorflow/contrib/kfac/examples/tests/mnist_test.py new file mode 100644 index 0000000000000000000000000000000000000000..92f84623573d3ad3af26b500fccfe533280d0199 --- /dev/null +++ b/tensorflow/contrib/kfac/examples/tests/mnist_test.py @@ -0,0 +1,72 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for mnist.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.kfac.examples import mnist + + +class MnistTest(tf.test.TestCase): + + def testValues(self): + """Ensure values are in their expected range.""" + with tf.Graph().as_default(): + examples, labels = mnist.load_mnist( + data_dir=None, num_epochs=1, batch_size=64, use_fake_data=True) + + with self.test_session() as sess: + examples_, labels_ = sess.run([examples, labels]) + self.assertTrue(np.all((0 <= examples_) & (examples_ < 1))) + self.assertTrue(np.all((0 <= labels_) & (labels_ < 10))) + + def testFlattenedShapes(self): + """Ensure images are flattened into their appropriate shape.""" + with tf.Graph().as_default(): + examples, labels = mnist.load_mnist( + data_dir=None, + num_epochs=1, + batch_size=64, + flatten_images=True, + use_fake_data=True) + + with self.test_session() as sess: + examples_, labels_ = sess.run([examples, labels]) + self.assertEqual(examples_.shape, (64, 784)) + self.assertEqual(labels_.shape, (64,)) + + def testNotFlattenedShapes(self): + """Ensure non-flattened images are their appropriate shape.""" + with tf.Graph().as_default(): + examples, labels = mnist.load_mnist( + data_dir=None, + num_epochs=1, + batch_size=64, + flatten_images=False, + use_fake_data=True) + + with self.test_session() as sess: + examples_, labels_ = sess.run([examples, labels]) + self.assertEqual(examples_.shape, (64, 28, 28, 1)) + self.assertEqual(labels_.shape, (64,)) + + +if __name__ == '__main__': + tf.test.main() diff --git a/tensorflow/contrib/kfac/g3doc/autoencoder.png b/tensorflow/contrib/kfac/g3doc/autoencoder.png new file mode 100644 index 0000000000000000000000000000000000000000..20f93c77034f3355653a6a260cccdad29c080eaf Binary files /dev/null and b/tensorflow/contrib/kfac/g3doc/autoencoder.png differ diff --git a/tensorflow/contrib/kfac/python/kernel_tests/BUILD b/tensorflow/contrib/kfac/python/kernel_tests/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..5d86373a232d55cd281d06cfc0606f4224d8f669 --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/BUILD @@ -0,0 +1,156 @@ +package(default_visibility = ["//visibility:private"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_test( + name = "estimator_test", + srcs = ["estimator_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_estimator", + "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/contrib/kfac/python/ops:utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:variable_scope", + ], +) + +py_test( + name = "fisher_factors_test", + srcs = ["fisher_factors_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_factors", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "fisher_blocks_test", + srcs = ["fisher_blocks_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_blocks", + "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/contrib/kfac/python/ops:utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:state_ops", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "layer_collection_test", + srcs = ["layer_collection_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:fisher_factors", + "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:random_ops", + "//tensorflow/python:random_seed", + "//tensorflow/python:variable_scope", + ], +) + +py_test( + name = "optimizer_test", + srcs = ["optimizer_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:kfac_optimizer", + "//tensorflow/contrib/kfac/python/ops:layer_collection", + "//tensorflow/contrib/kfac/python/ops:loss_functions", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", + ], +) + +py_test( + name = "utils_test", + srcs = ["utils_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:random_seed", + "//third_party/py/numpy", + ], +) + +py_test( + name = "op_queue_test", + srcs = ["op_queue_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:op_queue", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + ], +) + +py_test( + name = "loss_functions_test", + srcs = ["loss_functions_test.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/kfac/python/ops:loss_functions", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//third_party/py/numpy", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b52a7b52a7efd4292ad514c5a744c4da07082142 --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/estimator_test.py @@ -0,0 +1,99 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kfac.python.ops import estimator +from tensorflow.contrib.kfac.python.ops import layer_collection as lc +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + +_ALL_ESTIMATION_MODES = ["gradients", "empirical", "curvature_prop", "exact"] + + +class EstimatorTest(test.TestCase): + + def setUp(self): + self._graph = ops.Graph() + with self._graph.as_default(): + self.layer_collection = lc.LayerCollection() + + self.inputs = random_ops.random_normal((2, 2), dtype=dtypes.float32) + self.weights = variable_scope.get_variable( + "w", shape=(2, 2), dtype=dtypes.float32) + self.bias = variable_scope.get_variable( + "b", initializer=init_ops.zeros_initializer(), shape=(2, 1)) + self.output = math_ops.matmul(self.inputs, self.weights) + self.bias + + # Only register the weights. + self.layer_collection.register_fully_connected( + params=(self.weights,), inputs=self.inputs, outputs=self.output) + + self.outputs = math_ops.tanh(self.output) + self.targets = array_ops.zeros_like(self.outputs) + self.layer_collection.register_categorical_predictive_distribution( + logits=self.outputs, targets=self.targets) + + def testEstimatorInitManualRegistration(self): + with self._graph.as_default(): + # We should be able to build an estimator for only the registered vars. + estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection) + + # Check that we throw an error if we try to build an estimator for vars + # that were not manually registered. + with self.assertRaises(ValueError): + estimator.FisherEstimator([self.weights, self.bias], 0.1, 0.2, + self.layer_collection) + + # Check that we throw an error if we don't include registered variables, + # i.e. self.weights + with self.assertRaises(ValueError): + estimator.FisherEstimator([], 0.1, 0.2, self.layer_collection) + + @test.mock.patch.object(utils.SubGraph, "variable_uses", return_value=42) + def testVariableWrongNumberOfUses(self, mock_uses): + with self.assertRaises(ValueError): + estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection) + + def testInvalidEstimationMode(self): + with self.assertRaises(ValueError): + estimator.FisherEstimator([self.weights], 0.1, 0.2, self.layer_collection, + "not_a_real_mode") + + def testModeListCorrect(self): + with self._graph.as_default(): + est = estimator.FisherEstimator([self.weights], 0.1, 0.2, + self.layer_collection) + self.assertItemsEqual(_ALL_ESTIMATION_MODES, est._gradient_fns.keys()) + + def testAllModesBuild(self): + for mode in _ALL_ESTIMATION_MODES: + with self._graph.as_default(): + estimator.FisherEstimator([self.weights], 0.1, 0.2, + self.layer_collection, mode) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2b5c6cace9cd18f4cc5590ff55a9b39680a381 --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_blocks_test.py @@ -0,0 +1,786 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.fisher_blocks.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb +from tensorflow.contrib.kfac.python.ops import layer_collection as lc +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.platform import test + + +def _make_psd(dim): + """Constructs a PSD matrix of the given dimension.""" + mat = np.ones((dim, dim), dtype=np.float32) + mat[np.arange(dim), np.arange(dim)] = 2. + np.arange(dim) + return array_ops.constant(mat) + + +class FullFBTest(test.TestCase): + + def testFullFBInitSingleTensor(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) + + self.assertAllEqual(params, block.tensors_to_compute_grads()) + + def testFullFBInitTensorTuple(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) + + self.assertAllEqual(params, block.tensors_to_compute_grads()) + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) + + grads = (params[0]**2, math_ops.sqrt(params[1])) + block.instantiate_factors(grads, 0.5) + + def testMultiplyInverseTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) + grads = (params[0]**2, math_ops.sqrt(params[1])) + block.instantiate_factors((grads,), 0.5) + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_inverse_update_ops()) + + vector = array_ops.ones(3,) * 2 + output = block.multiply_inverse(vector) + + self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) + + def testMultiplyInverseNotTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = array_ops.constant([[1.], [2.]]) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) + grads = params**2 + block.instantiate_factors((grads,), 0.5) + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_inverse_update_ops()) + + vector = array_ops.ones(2,) * 2 + output = block.multiply_inverse(vector) + + self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) + + def testMultiplyInverseAgainstExplicit(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.FullFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) + grads = (array_ops.constant([2., 3.]), array_ops.constant(4.)) + damping = 0.5 + block.instantiate_factors((grads,), damping) + + # Make sure our inverse is something other than the identity. + sess.run(state_ops.assign(block._factor._cov, _make_psd(3))) + sess.run(block._factor.make_inverse_update_ops()) + + v_flat = np.array([4., 5., 6.], dtype=np.float32) + vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) + output = block.multiply_inverse(vector) + output_flat = sess.run(utils.tensors_to_column(output)).ravel() + + full = sess.run(block.full_fisher_block()) + explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat) + + self.assertAllClose(output_flat, explicit) + + +class NaiveDiagonalFBTest(test.TestCase): + + def testNaiveDiagonalFBInitSingleTensor(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) + + self.assertAllEqual(params, block.tensors_to_compute_grads()) + + def testNaiveDiagonalFBInitTensorTuple(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) + + self.assertAllEqual(params, block.tensors_to_compute_grads()) + + def testInstantiateFactors(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) + + grads = (params[0]**2, math_ops.sqrt(params[1])) + block.instantiate_factors(grads, 0.5) + + def testMultiplyInverseTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) + grads = (params[0]**2, math_ops.sqrt(params[1])) + block.instantiate_factors((grads,), 0.5) + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_inverse_update_ops()) + + vector = array_ops.ones(3,) * 2 + output = block.multiply_inverse(vector) + + self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) + + def testMultiplyInverseNotTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = array_ops.constant([[1.], [2.]]) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) + grads = params**2 + block.instantiate_factors((grads,), 0.5) + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_inverse_update_ops()) + vector = array_ops.ones(2,) * 2 + output = block.multiply_inverse(vector) + + self.assertAllClose(sess.run(vector * 2 / 3.), sess.run(output)) + + def testMultiplyInverseAgainstExplicit(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = (array_ops.constant([1., 2.]), array_ops.constant(3.)) + block = fb.NaiveDiagonalFB(lc.LayerCollection(), params) + block.register_additional_minibatch(32) + grads = (params[0]**2, math_ops.sqrt(params[1])) + damping = 0.5 + block.instantiate_factors((grads,), damping) + + cov = array_ops.reshape(array_ops.constant([2., 3., 4.]), [-1, 1]) + sess.run(state_ops.assign(block._factor._cov, cov)) + sess.run(block._factor.make_inverse_update_ops()) + + v_flat = np.array([4., 5., 6.], dtype=np.float32) + vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) + output = block.multiply_inverse(vector) + output_flat = sess.run(utils.tensors_to_column(output)).ravel() + + full = sess.run(block.full_fisher_block()) + explicit = np.dot(np.linalg.inv(full + damping * np.eye(3)), v_flat) + + self.assertAllClose(output_flat, explicit) + + +class FullyConnectedDiagonalFB(test.TestCase): + + def setUp(self): + super(FullyConnectedDiagonalFB, self).setUp() + + self.batch_size = 4 + self.input_size = 6 + self.output_size = 3 + + self.inputs = np.random.randn(self.batch_size, self.input_size).astype( + np.float32) + self.outputs = np.zeros([self.batch_size, self.output_size]).astype( + np.float32) + self.output_grads = np.random.randn(self.batch_size, + self.output_size).astype(np.float32) + self.w = np.random.randn(self.input_size, self.output_size).astype( + np.float32) + self.b = np.random.randn(self.output_size).astype(np.float32) + + def fisherApprox(self, has_bias=False): + """Fisher approximation using default inputs.""" + if has_bias: + inputs = np.concatenate( + [self.inputs, np.ones([self.batch_size, 1])], axis=1) + else: + inputs = self.inputs + return self.buildDiagonalFisherApproximation(inputs, self.output_grads) + + def buildDiagonalFisherApproximation(self, inputs, output_grads): + """Builds explicit diagonal Fisher approximation. + + Fisher's diagonal is (d loss / d w)'s elements squared for + d/dw = E[outer(input, output_grad)] + + where the expectation is taken over examples. + + Args: + inputs: np.array of shape [batch_size, input_size]. + output_grads: np.array of shape [batch_size, output_size]. + + Returns: + Diagonal np.array of shape [num_params, num_params] for num_params = + input_size * output_size. + """ + batch_size = inputs.shape[0] + assert output_grads.shape[0] == batch_size + input_size = inputs.shape[1] + output_size = output_grads.shape[1] + fisher_diag = np.zeros((input_size, output_size)) + for i in range(batch_size): + fisher_diag += np.square(np.outer(inputs[i], output_grads[i])) + return np.diag(fisher_diag.flatten()) / batch_size + + def testMultiply(self): + result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct Fisher-vector product. + expected_result = self.fisherApprox().dot(self.w.flatten()) + expected_result = expected_result.reshape( + [self.input_size, self.output_size]) + + self.assertAllClose(expected_result, result) + + def testMultiplyInverse(self): + _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct inverse Fisher-vector product. + expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten()) + expected_result = expected_result.reshape( + [self.input_size, self.output_size]) + + self.assertAllClose(expected_result, result) + + def testRegisterAdditionalMinibatch(self): + """Ensure 1 big minibatch and 2 small minibatches are equivalent.""" + multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( + self.w, [self.inputs], [self.outputs], [self.output_grads]) + multiply_result_small, multiply_inverse_result_small = ( + self.runFisherBlockOps(self.w, + np.split(self.inputs, 2), + np.split(self.outputs, 2), + np.split(self.output_grads, 2))) + + self.assertAllClose(multiply_result_big, multiply_result_small) + self.assertAllClose(multiply_inverse_result_big, + multiply_inverse_result_small) + + def testMultiplyHasBias(self): + result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs], + [self.outputs], [self.output_grads]) + expected_result = self.fisherApprox(True).dot( + np.concatenate([self.w.flatten(), self.b.flatten()])) + expected_result = expected_result.reshape( + [self.input_size + 1, self.output_size]) + expected_result = (expected_result[:-1], expected_result[-1]) + + self.assertEqual(len(result), 2) + self.assertAllClose(expected_result[0], result[0]) + self.assertAllClose(expected_result[1], result[1]) + + def runFisherBlockOps(self, params, inputs, outputs, output_grads): + """Run Ops guaranteed by FisherBlock interface. + + Args: + params: Tensor or 2-tuple of Tensors. Represents weights or weights and + bias of this layer. + inputs: list of Tensors of shape [batch_size, input_size]. Inputs to + layer. + outputs: list of Tensors of shape [batch_size, output_size]. + Preactivations produced by layer. + output_grads: list of Tensors of shape [batch_size, output_size]. + Gradient of loss with respect to 'outputs'. + + Returns: + multiply_result: Result of FisherBlock.multiply(params) + multiply_inverse_result: Result of FisherBlock.multiply_inverse(params) + """ + with ops.Graph().as_default(), self.test_session() as sess: + inputs = as_tensors(inputs) + outputs = as_tensors(outputs) + output_grads = as_tensors(output_grads) + params = as_tensors(params) + + block = fb.FullyConnectedDiagonalFB( + lc.LayerCollection(), has_bias=isinstance(params, (tuple, list))) + for (i, o) in zip(inputs, outputs): + block.register_additional_minibatch(i, o) + + block.instantiate_factors((output_grads,), damping=0.0) + + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_covariance_update_op(0.0)) + multiply_result = sess.run(block.multiply(params)) + multiply_inverse_result = sess.run(block.multiply_inverse(params)) + + return multiply_result, multiply_inverse_result + + +class FullyConnectedKFACBasicFBTest(test.TestCase): + + def testFullyConnectedKFACBasicFBInit(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([1., 2.]) + outputs = array_ops.constant([3., 4.]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection()) + block.register_additional_minibatch(inputs, outputs) + + self.assertAllEqual([outputs], block.tensors_to_compute_grads()) + + def testInstantiateFactorsHasBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=True) + block.register_additional_minibatch(inputs, outputs) + + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + + def testInstantiateFactorsNoBias(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) + block.register_additional_minibatch(inputs, outputs) + + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + + def testMultiplyInverseTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2., 3.], [3., 4., 5.], [5., 6., 7.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) + block.register_additional_minibatch(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + vector = ( + np.arange(2, 6).reshape(2, 2).astype(np.float32), # + np.arange(1, 3).reshape(2, 1).astype(np.float32)) + output = block.multiply_inverse((array_ops.constant(vector[0]), + array_ops.constant(vector[1]))) + + output = sess.run(output) + self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]], + output[0]) + self.assertAllClose([0.343146, 0.686291], output[1]) + + def testMultiplyInverseNotTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + inputs = array_ops.constant([[1., 2.], [3., 4.]]) + outputs = array_ops.constant([[3., 4.], [5., 6.]]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) + block.register_additional_minibatch(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + vector = np.arange(2, 6).reshape(2, 2).astype(np.float32) + output = block.multiply_inverse(array_ops.constant(vector)) + + self.assertAllClose([[0.686291, 1.029437], [1.372583, 1.715729]], + sess.run(output)) + + def testMultiplyInverseAgainstExplicit(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + input_dim, output_dim = 3, 2 + inputs = array_ops.zeros([32, input_dim]) + outputs = array_ops.zeros([32, output_dim]) + params = array_ops.zeros([input_dim, output_dim]) + block = fb.FullyConnectedKFACBasicFB(lc.LayerCollection(), has_bias=False) + block.register_additional_minibatch(inputs, outputs) + grads = outputs**2 + damping = 0. # This test is only valid without damping. + block.instantiate_factors(([grads],), damping) + + sess.run(state_ops.assign(block._input_factor._cov, _make_psd(3))) + sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + v_flat = np.arange(6, dtype=np.float32) + vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) + output = block.multiply_inverse(vector) + output_flat = sess.run(utils.tensors_to_column(output)).ravel() + + full = sess.run(block.full_fisher_block()) + explicit = np.dot(np.linalg.inv(full + damping * np.eye(6)), v_flat) + + self.assertAllClose(output_flat, explicit) + + +class ConvDiagonalFBTest(test.TestCase): + + def setUp(self): + super(ConvDiagonalFBTest, self).setUp() + + self.batch_size = 2 + self.height = 8 + self.width = 4 + self.input_channels = 6 + self.output_channels = 3 + self.kernel_size = 1 + + self.inputs = np.random.randn(self.batch_size, self.height, self.width, + self.input_channels).astype(np.float32) + self.outputs = np.zeros( + [self.batch_size, self.height, self.width, + self.output_channels]).astype(np.float32) + self.output_grads = np.random.randn( + self.batch_size, self.height, self.width, self.output_channels).astype( + np.float32) + self.w = np.random.randn(self.kernel_size, self.kernel_size, + self.input_channels, self.output_channels).astype( + np.float32) + self.b = np.random.randn(self.output_channels).astype(np.float32) + + def fisherApprox(self, has_bias=False): + """Fisher approximation using default inputs.""" + if has_bias: + inputs = np.concatenate( + [self.inputs, + np.ones([self.batch_size, self.height, self.width, 1])], + axis=-1) + else: + inputs = self.inputs + return self.buildDiagonalFisherApproximation(inputs, self.output_grads, + self.kernel_size) + + def buildDiagonalFisherApproximation(self, inputs, output_grads, kernel_size): + r"""Builds explicit diagonal Fisher approximation. + + Fisher's diagonal is (d loss / d w)'s elements squared for + d/dw = E[\sum_{loc} outer(input_{loc}, output_grad_{loc})] + + where the expectation is taken over examples and the sum over (x, y) + locations upon which the convolution is applied. + + Args: + inputs: np.array of shape [batch_size, height, width, input_channels]. + output_grads: np.array of shape [batch_size, height, width, + output_channels]. + kernel_size: int. height and width of kernel. + + Returns: + Diagonal np.array of shape [num_params, num_params] for num_params = + kernel_size^2 * input_channels * output_channels. + """ + batch_size, height, width, input_channels = inputs.shape + assert output_grads.shape[0] == batch_size + assert output_grads.shape[1] == height + assert output_grads.shape[2] == width + output_channels = output_grads.shape[3] + + # If kernel_size == 1, then we don't need to worry about capturing context + # around the pixel upon which a convolution is applied. This makes testing + # easier. + assert kernel_size == 1, "kernel_size != 1 isn't supported." + num_locations = height * width + inputs = np.reshape(inputs, [batch_size, num_locations, input_channels]) + output_grads = np.reshape(output_grads, + [batch_size, num_locations, output_channels]) + + fisher_diag = np.zeros((input_channels, output_channels)) + for i in range(batch_size): + # Each example's approximation is a square(sum-of-outer-products). + example_fisher_diag = np.zeros((input_channels, output_channels)) + for j in range(num_locations): + example_fisher_diag += np.outer(inputs[i, j], output_grads[i, j]) + fisher_diag += np.square(example_fisher_diag) + + # Normalize by batch_size (not num_locations). + return np.diag(fisher_diag.flatten()) / batch_size + + def testMultiply(self): + result, _ = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct Fisher-vector product. + expected_result = self.fisherApprox().dot(self.w.flatten()) + expected_result = expected_result.reshape([ + self.kernel_size, self.kernel_size, self.input_channels, + self.output_channels + ]) + + self.assertAllClose(expected_result, result) + + def testMultiplyInverse(self): + _, result = self.runFisherBlockOps(self.w, [self.inputs], [self.outputs], + [self.output_grads]) + + # Construct inverse Fisher-vector product. + expected_result = np.linalg.inv(self.fisherApprox()).dot(self.w.flatten()) + expected_result = expected_result.reshape([ + self.kernel_size, self.kernel_size, self.input_channels, + self.output_channels + ]) + + self.assertAllClose(expected_result, result, atol=1e-3) + + def testRegisterAdditionalMinibatch(self): + """Ensure 1 big minibatch and 2 small minibatches are equivalent.""" + multiply_result_big, multiply_inverse_result_big = self.runFisherBlockOps( + self.w, [self.inputs], [self.outputs], [self.output_grads]) + multiply_result_small, multiply_inverse_result_small = ( + self.runFisherBlockOps(self.w, + np.split(self.inputs, 2), + np.split(self.outputs, 2), + np.split(self.output_grads, 2))) + + self.assertAllClose(multiply_result_big, multiply_result_small) + self.assertAllClose(multiply_inverse_result_big, + multiply_inverse_result_small) + + def testMultiplyHasBias(self): + result, _ = self.runFisherBlockOps((self.w, self.b), [self.inputs], + [self.outputs], [self.output_grads]) + # Clone 'b' along 'input_channels' dimension. + b_filter = np.tile( + np.reshape(self.b, [1, 1, 1, self.output_channels]), + [self.kernel_size, self.kernel_size, 1, 1]) + params = np.concatenate([self.w, b_filter], axis=2) + expected_result = self.fisherApprox(True).dot(params.flatten()) + + # Extract 'b' from concatenated parameters. + expected_result = expected_result.reshape([ + self.kernel_size, self.kernel_size, self.input_channels + 1, + self.output_channels + ]) + expected_result = (expected_result[:, :, 0:-1, :], np.reshape( + expected_result[:, :, -1, :], [self.output_channels])) + + self.assertEqual(len(result), 2) + self.assertAllClose(expected_result[0], result[0]) + self.assertAllClose(expected_result[1], result[1]) + + def runFisherBlockOps(self, params, inputs, outputs, output_grads): + """Run Ops guaranteed by FisherBlock interface. + + Args: + params: Tensor or 2-tuple of Tensors. Represents weights or weights and + bias of this layer. + inputs: list of Tensors of shape [batch_size, input_size]. Inputs to + layer. + outputs: list of Tensors of shape [batch_size, output_size]. + Preactivations produced by layer. + output_grads: list of Tensors of shape [batch_size, output_size]. + Gradient of loss with respect to 'outputs'. + + Returns: + multiply_result: Result of FisherBlock.multiply(params) + multiply_inverse_result: Result of FisherBlock.multiply_inverse(params) + """ + with ops.Graph().as_default(), self.test_session() as sess: + inputs = as_tensors(inputs) + outputs = as_tensors(outputs) + output_grads = as_tensors(output_grads) + params = as_tensors(params) + + block = fb.ConvDiagonalFB( + lc.LayerCollection(), params, strides=[1, 1, 1, 1], padding='SAME') + for (i, o) in zip(inputs, outputs): + block.register_additional_minibatch(i, o) + + block.instantiate_factors((output_grads,), damping=0.0) + + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._factor.make_covariance_update_op(0.0)) + multiply_result = sess.run(block.multiply(params)) + multiply_inverse_result = sess.run(block.multiply_inverse(params)) + + return multiply_result, multiply_inverse_result + + +class ConvKFCBasicFBTest(test.TestCase): + + def _testConvKFCBasicFBInitParams(self, params): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + if isinstance(params, (list, tuple)): + params = [array_ops.constant(param) for param in params] + else: + params = array_ops.constant(params) + inputs = random_ops.random_normal((2, 2, 2)) + outputs = random_ops.random_normal((2, 2, 2)) + block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, [1, 1, 1], 'SAME') + block.register_additional_minibatch(inputs, outputs) + + self.assertAllEqual([outputs], block.tensors_to_compute_grads()) + + def testConvKFCBasicFBInitParamsParamsTuple(self): + self._testConvKFCBasicFBInitParams([np.array([1., 2.]), np.array(3.)]) + + def testConvKFCBasicFBInitParamsParamsSingle(self): + self._testConvKFCBasicFBInitParams([np.array([1., 2.])]) + + def testMultiplyInverseTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = random_ops.random_normal((2, 2, 2, 2)) + inputs = random_ops.random_normal((2, 2, 2, 2)) + outputs = random_ops.random_normal((2, 2, 2, 2)) + block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), + 'SAME') + block.register_additional_minibatch(inputs, outputs) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + vector = (np.arange(1, 15).reshape(7, 2).astype(np.float32), np.arange( + 2, 4).reshape(2, 1).astype(np.float32)) + output = block.multiply_inverse((array_ops.constant(vector[0]), + array_ops.constant(vector[1]))) + + output = sess.run(output) + self.assertAllClose([0.136455, 0.27291], output[0][0]) + self.assertAllClose([0.27291, 0.409365], output[1]) + + def testMultiplyInverseNotTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = random_ops.random_normal((2, 2, 2, 2)) + inputs = random_ops.random_normal((2, 2, 2, 2)) + outputs = random_ops.random_normal((2, 2, 2, 2)) + block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), + 'SAME') + block.register_additional_minibatch(inputs, outputs) + self.assertFalse(block._has_bias) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + vector = np.arange(1, 17).reshape(8, 2).astype(np.float32) + output = block.multiply_inverse(array_ops.constant(vector)) + + self.assertAllClose([0.136455, 0.27291], sess.run(output)[0]) + + def testMultiplyInverseNotTupleWithBias(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = [random_ops.random_normal((2, 2, 2, 2))] + inputs = random_ops.random_normal((2, 2, 2, 2)) + outputs = random_ops.random_normal((2, 2, 2, 2)) + block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), + 'SAME') + block.register_additional_minibatch(inputs, outputs) + self.assertTrue(block._has_bias) + grads = outputs**2 + block.instantiate_factors(([grads],), 0.5) + + # Make sure our inverse is something other than the identity. + sess.run(tf_variables.global_variables_initializer()) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + vector = np.arange(1, 19).reshape(9, 2).astype(np.float32) + output = block.multiply_inverse(array_ops.constant(vector)) + + self.assertAllClose([0.136455, 0.27291], sess.run(output)[0]) + + def testMultiplyInverseAgainstExplicit(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + params = array_ops.zeros((2, 2, 2, 2)) + inputs = array_ops.zeros((2, 2, 2, 2)) + outputs = array_ops.zeros((2, 2, 2, 2)) + block = fb.ConvKFCBasicFB(lc.LayerCollection(), params, (1, 1, 1, 1), + 'SAME') + block.register_additional_minibatch(inputs, outputs) + grads = outputs**2 + damping = 0. # This test is only valid without damping. + block.instantiate_factors(([grads],), damping) + + sess.run(state_ops.assign(block._input_factor._cov, _make_psd(8))) + sess.run(state_ops.assign(block._output_factor._cov, _make_psd(2))) + sess.run(block._input_factor.make_inverse_update_ops()) + sess.run(block._output_factor.make_inverse_update_ops()) + + v_flat = np.arange(16, dtype=np.float32) + vector = utils.column_to_tensors(params, array_ops.constant(v_flat)) + output = block.multiply_inverse(vector) + output_flat = sess.run(utils.tensors_to_column(output)).ravel() + + full = sess.run(block.full_fisher_block()) + explicit = np.dot(np.linalg.inv(full + damping * np.eye(16)), v_flat) + + self.assertAllClose(output_flat, explicit) + + +def as_tensors(tensor_or_tuple): + """Converts a potentially nested tuple of np.array to Tensors.""" + if isinstance(tensor_or_tuple, (tuple, list)): + return tuple(as_tensors(t) for t in tensor_or_tuple) + return ops.convert_to_tensor(tensor_or_tuple) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fbb3d219139a4bc05253841a89e73645ef37dddd --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py @@ -0,0 +1,455 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.fisher_factors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import numpy.random as npr + +from tensorflow.contrib.kfac.python.ops import fisher_factors as ff +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.platform import test + + +class FisherFactorTestingDummy(ff.FisherFactor): + """Dummy class to test the non-abstract methods on ff.FisherFactor.""" + + @property + def _var_scope(self): + return 'dummy/a_b_c' + + @property + def _cov_shape(self): + raise NotImplementedError + + @property + def _num_sources(self): + return 1 + + def _compute_new_cov(self): + raise NotImplementedError + + def instantiate_covariance(self): + pass + + +class InverseProvidingFactorTestingDummy(ff.InverseProvidingFactor): + """Dummy class to test the non-abstract methods on ff.InverseProvidingFactor. + """ + + def __init__(self, shape): + self._shape = shape + super(InverseProvidingFactorTestingDummy, self).__init__() + + @property + def _var_scope(self): + return 'dummy/a_b_c' + + @property + def _cov_shape(self): + return self._shape + + @property + def _num_sources(self): + return 1 + + def _compute_new_cov(self): + raise NotImplementedError + + def instantiate_covariance(self): + pass + + +class NumericalUtilsTest(test.TestCase): + + def testComputeCovAgainstNumpy(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + npr.seed(0) + random_seed.set_random_seed(200) + + x = npr.randn(100, 3) + cov = ff._compute_cov(array_ops.constant(x)) + np_cov = np.dot(x.T, x) / x.shape[0] + + self.assertAllClose(sess.run(cov), np_cov) + + def testComputeCovAgainstNumpyWithAlternativeNormalizer(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + npr.seed(0) + random_seed.set_random_seed(200) + + normalizer = 10. + x = npr.randn(100, 3) + cov = ff._compute_cov(array_ops.constant(x), normalizer) + np_cov = np.dot(x.T, x) / normalizer + + self.assertAllClose(sess.run(cov), np_cov) + + def testAppendHomog(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + npr.seed(0) + + m, n = 3, 4 + a = npr.randn(m, n) + a_homog = ff._append_homog(array_ops.constant(a)) + np_result = np.hstack([a, np.ones((m, 1))]) + + self.assertAllClose(sess.run(a_homog), np_result) + + +class NameStringUtilFunctionTest(test.TestCase): + + def _make_tensor(self): + x = array_ops.placeholder(dtypes.float64, (3, 1)) + w = array_ops.constant(npr.RandomState(0).randn(3, 3)) + y = math_ops.matmul(w, x) + g = gradients_impl.gradients(y, x)[0] + return g + + def testScopeStringFromParamsSingleTensor(self): + with tf_ops.Graph().as_default(): + g = self._make_tensor() + scope_string = ff.scope_string_from_params(g) + self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string) + + def testScopeStringFromParamsMultipleTensors(self): + with tf_ops.Graph().as_default(): + x = array_ops.constant(1,) + y = array_ops.constant(2,) + scope_string = ff.scope_string_from_params((x, y)) + self.assertEqual('Const_Const_1', scope_string) + + def testScopeStringFromParamsMultipleTypes(self): + with tf_ops.Graph().as_default(): + x = array_ops.constant(1,) + y = array_ops.constant(2,) + scope_string = ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, + (x, y)]) + self.assertEqual('1-2-3_foo_True_4_Const__Const_1', scope_string) + + def testScopeStringFromParamsUnsupportedType(self): + with tf_ops.Graph().as_default(): + x = array_ops.constant(1,) + y = array_ops.constant(2,) + unsupported = 1.2 # Floats are not supported. + with self.assertRaises(ValueError): + ff.scope_string_from_params([[1, 2, 3], 'foo', True, 4, (x, y), + unsupported]) + + def testScopeStringFromName(self): + with tf_ops.Graph().as_default(): + g = self._make_tensor() + scope_string = ff.scope_string_from_name(g) + self.assertEqual('gradients_MatMul_grad_MatMul_1', scope_string) + + def testScalarOrTensorToString(self): + with tf_ops.Graph().as_default(): + self.assertEqual(ff.scalar_or_tensor_to_string(5.), repr(5.)) + + g = self._make_tensor() + scope_string = ff.scope_string_from_name(g) + self.assertEqual(ff.scalar_or_tensor_to_string(g), scope_string) + + +class FisherFactorTest(test.TestCase): + + def testMakeInverseUpdateOps(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + factor = FisherFactorTestingDummy() + + self.assertEqual(0, len(factor.make_inverse_update_ops())) + + +class InverseProvidingFactorTest(test.TestCase): + + def testRegisterDampedInverse(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + shape = [2, 2] + factor = InverseProvidingFactorTestingDummy(shape) + factor_var_scope = 'dummy/a_b_c' + + dampings = 0.1, 1e-1, 0.00001, 1e-5 + + for damping in dampings: + factor.register_damped_inverse(damping) + + self.assertEqual(set(dampings), set(factor._inverses_by_damping.keys())) + inv = factor._inverses_by_damping[dampings[0]] + self.assertEqual(inv, factor._inverses_by_damping[dampings[1]]) + self.assertNotEqual(inv, factor._inverses_by_damping[dampings[2]]) + self.assertEqual(factor._inverses_by_damping[dampings[2]], + factor._inverses_by_damping[dampings[3]]) + factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, + factor_var_scope) + self.assertListEqual([inv, factor._inverses_by_damping[dampings[2]]], + factor_vars) + self.assertEqual(shape, inv.get_shape()) + + def testRegisterMatpower(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + shape = [3, 3] + factor = InverseProvidingFactorTestingDummy(shape) + factor_var_scope = 'dummy/a_b_c' + + factor.register_matpower(1, 0.5) + factor.register_matpower(2, 0.5) + + self.assertEqual( + set([(1, 0.5), (2, 0.5)]), + set(factor._matpower_by_exp_and_damping.keys())) + factor_vars = tf_ops.get_collection(tf_ops.GraphKeys.GLOBAL_VARIABLES, + factor_var_scope) + matpower1 = factor.get_matpower(1, 0.5) + matpower2 = factor.get_matpower(2, 0.5) + self.assertListEqual([matpower1, matpower2], factor_vars) + + self.assertEqual(shape, matpower1.get_shape()) + self.assertEqual(shape, matpower2.get_shape()) + + def testMakeInverseUpdateOps(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + factor = FisherFactorTestingDummy() + + self.assertEqual(0, len(factor.make_inverse_update_ops())) + + def testMakeInverseUpdateOpsManyInversesEigenDecomp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + cov = np.array([[1., 2.], [3., 4.]]) + factor = InverseProvidingFactorTestingDummy(cov.shape) + factor._cov = array_ops.constant(cov, dtype=dtypes.float32) + + for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): + factor.register_damped_inverse(1. / i) + ops = factor.make_inverse_update_ops() + self.assertEqual(ff.EIGENVALUE_DECOMPOSITION_THRESHOLD, len(ops)) + + sess.run(tf_variables.global_variables_initializer()) + new_invs = [] + for i in range(1, ff.EIGENVALUE_DECOMPOSITION_THRESHOLD + 1): + # The inverse op will assign the damped inverse of cov to the inv var. + sess.run(ops[i - 1]) + new_invs.append(sess.run(factor._inverses_by_damping[1. / i])) + # We want to see that the new invs are all different from each other. + for i in range(len(new_invs)): + for j in range(i + 1, len(new_invs)): + # Just check the first element. + self.assertNotEqual(new_invs[i][0][0], new_invs[j][0][0]) + + def testMakeInverseUpdateOpsMatPowerEigenDecomp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + cov = np.array([[6., 2.], [2., 4.]]) + factor = InverseProvidingFactorTestingDummy(cov.shape) + factor._cov = array_ops.constant(cov, dtype=dtypes.float32) + exp = 2 # NOTE(mattjj): must be int to test with np.linalg.matrix_power + damping = 0.5 + + factor.register_matpower(exp, damping) + ops = factor.make_inverse_update_ops() + self.assertEqual(1, len(ops)) + + sess.run(tf_variables.global_variables_initializer()) + sess.run(ops[0]) + matpower = sess.run(factor._matpower_by_exp_and_damping[(exp, damping)]) + matpower_np = np.linalg.matrix_power(cov + np.eye(2) * damping, exp) + self.assertAllClose(matpower, matpower_np) + + def testMakeInverseUpdateOpsNoEigenDecomp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + cov = np.array([[5., 2.], [2., 4.]]) # NOTE(mattjj): must be symmetric + factor = InverseProvidingFactorTestingDummy(cov.shape) + factor._cov = array_ops.constant(cov, dtype=dtypes.float32) + + factor.register_damped_inverse(0) + ops = factor.make_inverse_update_ops() + self.assertEqual(1, len(ops)) + + sess.run(tf_variables.global_variables_initializer()) + # The inverse op will assign the damped inverse of cov to the inv var. + old_inv = sess.run(factor._inverses_by_damping[0]) + self.assertAllClose( + sess.run(ff.inverse_initializer(cov.shape, dtypes.float32)), old_inv) + + sess.run(ops) + new_inv = sess.run(factor._inverses_by_damping[0]) + self.assertAllClose(new_inv, np.linalg.inv(cov)) + + +class FullFactorTest(test.TestCase): + + def testFullFactorInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + factor = ff.FullFactor((tensor,), 32) + self.assertEqual([6, 6], factor.get_cov().get_shape().as_list()) + + def testMakeCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([1., 2.], name='a/b/c') + factor = ff.FullFactor((tensor,), 2) + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[0.75, 0.5], [0.5, 1.5]], new_cov) + + +class NaiveDiagonalFactorTest(test.TestCase): + + def testNaiveDiagonalFactorInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + factor = ff.NaiveDiagonalFactor((tensor,), 32) + self.assertEqual([6, 1], factor.get_cov().get_shape().as_list()) + + def testMakeCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([1., 2.], name='a/b/c') + factor = ff.NaiveDiagonalFactor((tensor,), 2) + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[0.75], [1.5]], new_cov) + + +class FullyConnectedKroneckerFactorTest(test.TestCase): + + def _testFullyConnectedKroneckerFactorInit(self, has_bias, final_shape): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=has_bias) + self.assertEqual(final_shape, factor.get_cov().get_shape().as_list()) + + def testFullyConnectedKroneckerFactorInitNoBias(self): + self._testFullyConnectedKroneckerFactorInit(False, [3, 3]) + + def testFullyConnectedKroneckerFactorInitWithBias(self): + self._testFullyConnectedKroneckerFactorInit(True, [4, 4]) + + def testMakeCovarianceUpdateOpWithBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + factor = ff.FullyConnectedKroneckerFactor((tensor,), has_bias=True) + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5, 1], [3.5, 5.5, 1.5], [1, 1.5, 1]], new_cov) + + def testMakeCovarianceUpdateOpNoBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant([[1., 2.], [3., 4.]], name='a/b/c') + factor = ff.FullyConnectedKroneckerFactor((tensor,)) + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[3, 3.5], [3.5, 5.5]], new_cov) + + +class ConvInputKroneckerFactorTest(test.TestCase): + + def testConvInputKroneckerFactorInitNoBias(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + factor = ff.ConvInputKroneckerFactor( + tensor, (1, 2, 3, 4), 3, 2, has_bias=False) + self.assertEqual([1 * 2 * 3, 1 * 2 * 3], + factor.get_cov().get_shape().as_list()) + + def testConvInputKroneckerFactorInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + factor = ff.ConvInputKroneckerFactor( + tensor, (1, 2, 3, 4), 3, 2, has_bias=True) + self.assertEqual([1 * 2 * 3 + 1, 1 * 2 * 3 + 1], + factor.get_cov().get_shape().as_list()) + + def testMakeCovarianceUpdateOpWithBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant( + np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32) + factor = ff.ConvInputKroneckerFactor( + tensor, (1, 2, 1, 1), [1, 1, 1, 1], 'SAME', has_bias=True) + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[34.375, 37, 3.125], [37, 41, 3.5], [3.125, 3.5, 1]], + new_cov) + + def testMakeCovarianceUpdateOpNoBias(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = array_ops.constant( + np.arange(1., 17.).reshape(2, 2, 2, 2), dtype=dtypes.float32) + factor = ff.ConvInputKroneckerFactor(tensor, (1, 2, 1, 1), [1, 1, 1, 1], + 'SAME') + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[34.375, 37], [37, 41]], new_cov) + + +class ConvOutputKroneckerFactorTest(test.TestCase): + + def testConvOutputKroneckerFactorInit(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3, 4, 5), name='a/b/c') + factor = ff.ConvOutputKroneckerFactor((tensor,)) + self.assertEqual([5, 5], factor.get_cov().get_shape().as_list()) + + def testConvOutputKroneckerFactorInitNotEnoughDims(self): + with tf_ops.Graph().as_default(): + random_seed.set_random_seed(200) + tensor = array_ops.ones((2, 3), name='a/b/c') + with self.assertRaises(IndexError): + ff.ConvOutputKroneckerFactor(tensor) + + def testMakeCovarianceUpdateOp(self): + with tf_ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + tensor = np.arange(1, 17).reshape(2, 2, 2, 2).astype(np.float32) + factor = ff.ConvOutputKroneckerFactor((array_ops.constant(tensor),)) + + sess.run(tf_variables.global_variables_initializer()) + new_cov = sess.run(factor.make_covariance_update_op(.5)) + self.assertAllClose([[43, 46.5], [46.5, 51.5]], new_cov) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py new file mode 100644 index 0000000000000000000000000000000000000000..524e8338fde9bb20586b15c33ba2055e852baa01 --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/layer_collection_test.py @@ -0,0 +1,469 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.layer_collection.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kfac.python.ops import fisher_factors +from tensorflow.contrib.kfac.python.ops import layer_collection +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test + + +class MockFisherBlock(object): + """A fake FisherBlock.""" + + num_registered_minibatches = 2 + + def __init__(self, name='MockFisherBlock'): + self.name = name + + def __eq__(self, other): + return isinstance(other, MockFisherBlock) and other.name == self.name + + def __hash__(self): + return hash(self.name) + + +class LayerParametersDictTest(test.TestCase): + + def testSetItem(self): + """Ensure insertion, contains, retrieval works for supported key types.""" + with ops.Graph().as_default(): + lp_dict = layer_collection.LayerParametersDict() + + x = array_ops.constant(0) + y0 = array_ops.constant(0) + y1 = array_ops.constant(0) + z0 = array_ops.constant(0) + z1 = array_ops.constant(0) + keys = [x, (y0, y1), [z0, z1]] + for key in keys: + lp_dict[key] = key + + for key in keys: + self.assertTrue(key in lp_dict) + self.assertEqual(lp_dict[key], key) + + def testSetItemOverlap(self): + """Ensure insertion fails if key overlaps with existing key.""" + with ops.Graph().as_default(): + lp_dict = layer_collection.LayerParametersDict() + + x = array_ops.constant(0) + y = array_ops.constant(0) + lp_dict[x] = 'value' + + with self.assertRaises(ValueError): + lp_dict[(x, y)] = 'value' + + # Ensure 'y' wasn't inserted. + self.assertTrue(x in lp_dict) + self.assertFalse(y in lp_dict) + + +class LayerCollectionTest(test.TestCase): + + def testLayerCollectionInit(self): + lc = layer_collection.LayerCollection() + self.assertEqual(0, len(lc.get_blocks())) + self.assertEqual(0, len(lc.get_factors())) + self.assertFalse(lc.losses) + + def testRegisterBlocks(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + lc = layer_collection.LayerCollection() + lc.register_fully_connected( + array_ops.constant(1), array_ops.constant(2), array_ops.constant(3)) + lc.register_fully_connected( + array_ops.constant(1), + array_ops.constant(2), + array_ops.constant(3), + approx=layer_collection.APPROX_DIAGONAL_NAME) + lc.register_conv2d( + array_ops.constant(4), [1, 1, 1, 1], 'SAME', + array_ops.ones((1, 1, 1, 1)), array_ops.constant(3)) + lc.register_conv2d( + array_ops.constant(4), [1, 1, 1, 1], 'SAME', + array_ops.ones((1, 1, 1, 1)), array_ops.constant(3), + approx=layer_collection.APPROX_DIAGONAL_NAME) + lc.register_generic( + array_ops.constant(5), 16, approx=layer_collection.APPROX_FULL_NAME) + lc.register_generic( + array_ops.constant(6), + 16, + approx=layer_collection.APPROX_DIAGONAL_NAME) + + self.assertEqual(6, len(lc.get_blocks())) + + def testRegisterBlocksMultipleRegistrations(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + lc = layer_collection.LayerCollection() + key = array_ops.constant(1) + lc.register_fully_connected(key, + array_ops.constant(2), array_ops.constant(3)) + with self.assertRaises(ValueError): + lc.register_generic(key, 16) + + def testRegisterSingleParamNotRegistered(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = { + variable_scope.get_variable('y', initializer=array_ops.constant(1,)): + '1' + } + lc.register_block(x, 'foo') + + def testShouldRegisterSingleParamRegistered(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {x: '1'} + with self.assertRaises(ValueError): + lc.register_block(x, 'foo') + + def testRegisterSingleParamRegisteredInTuple(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {(x, y): '1'} + lc.register_block(x, 'foo') + self.assertEqual(set(['1']), set(lc.get_blocks())) + + def testRegisterTupleParamNotRegistered(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = { + variable_scope.get_variable('z', initializer=array_ops.constant(1,)): + '1' + } + + lc.register_block((x, y), 'foo') + self.assertEqual(set(['1', 'foo']), set(lc.get_blocks())) + + def testRegisterTupleParamRegistered(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {(x, y): '1'} + + with self.assertRaises(ValueError): + lc.register_block((x, y), 'foo') + + def testRegisterTupleParamRegisteredInSuperset(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + z = variable_scope.get_variable('z', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {(x, y, z): '1'} + + lc.register_block((x, y), 'foo') + self.assertEqual(set(['1']), set(lc.get_blocks())) + + def testRegisterTupleParamSomeRegistered(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + z = variable_scope.get_variable('z', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {x: MockFisherBlock('1'), z: MockFisherBlock('2')} + + lc.register_block((x, y), MockFisherBlock('foo')) + self.assertEqual( + set([MockFisherBlock('2'), MockFisherBlock('foo')]), + set(lc.get_blocks())) + + def testRegisterTupleVarSomeRegisteredInOtherTuples(self): + x = variable_scope.get_variable('x', initializer=array_ops.constant(1,)) + y = variable_scope.get_variable('y', initializer=array_ops.constant(1,)) + z = variable_scope.get_variable('z', initializer=array_ops.constant(1,)) + w = variable_scope.get_variable('w', initializer=array_ops.constant(1,)) + lc = layer_collection.LayerCollection() + lc.fisher_blocks = {(x, z): '1', (z, w): '2'} + + with self.assertRaises(ValueError): + lc.register_block((x, y), 'foo') + + def testRegisterCategoricalPredictiveDistribution(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + logits = linalg_ops.eye(2) + + lc = layer_collection.LayerCollection() + lc.register_categorical_predictive_distribution(logits, seed=200) + single_loss = sess.run(lc.total_sampled_loss()) + + lc2 = layer_collection.LayerCollection() + lc2.register_categorical_predictive_distribution(logits, seed=200) + lc2.register_categorical_predictive_distribution(logits, seed=200) + double_loss = sess.run(lc2.total_sampled_loss()) + self.assertAlmostEqual(2 * single_loss, double_loss) + + def testLossFunctionByName(self): + """Ensure loss functions can be identified by name.""" + with ops.Graph().as_default(): + logits = linalg_ops.eye(2) + lc = layer_collection.LayerCollection() + + # Create a new loss function by name. + lc.register_categorical_predictive_distribution(logits, name='loss1') + self.assertEqual(1, len(lc.losses)) + + # Add logits to same loss function. + lc.register_categorical_predictive_distribution( + logits, name='loss1', reuse=True) + self.assertEqual(1, len(lc.losses)) + + # Add another new loss function. + lc.register_categorical_predictive_distribution(logits, name='loss2') + self.assertEqual(2, len(lc.losses)) + + def testLossFunctionWithoutName(self): + """Ensure loss functions get unique names if 'name' not specified.""" + with ops.Graph().as_default(): + logits = linalg_ops.eye(2) + lc = layer_collection.LayerCollection() + + # Create a new loss function with default names. + lc.register_categorical_predictive_distribution(logits) + lc.register_categorical_predictive_distribution(logits) + self.assertEqual(2, len(lc.losses)) + + def testCategoricalPredictiveDistributionMultipleMinibatches(self): + """Ensure multiple minibatches are registered.""" + with ops.Graph().as_default(): + batch_size = 3 + output_size = 2 + logits = array_ops.zeros([batch_size, output_size]) + targets = array_ops.ones([batch_size], dtype=dtypes.int32) + lc = layer_collection.LayerCollection() + + # Create a new loss function. + lc.register_categorical_predictive_distribution( + logits, targets=targets, name='loss1') + + # Can add when reuse=True + lc.register_categorical_predictive_distribution( + logits, targets=targets, name='loss1', reuse=True) + + # Can add when reuse=VARIABLE_SCOPE and reuse=True there. + with variable_scope.variable_scope( + variable_scope.get_variable_scope(), reuse=True): + lc.register_categorical_predictive_distribution( + logits, + targets=targets, + name='loss1', + reuse=layer_collection.VARIABLE_SCOPE) + + # Can't add when reuse=False + with self.assertRaises(KeyError): + lc.register_categorical_predictive_distribution( + logits, targets=targets, name='loss1', reuse=False) + + # Can't add when reuse=VARIABLE_SCOPE and reuse=False there. + with self.assertRaises(KeyError): + lc.register_categorical_predictive_distribution( + logits, + targets=targets, + name='loss1', + reuse=layer_collection.VARIABLE_SCOPE) + + self.assertEqual(len(lc.losses), 1) + loss = lc.losses[0] + + # Three successful registrations. + self.assertEqual(loss.params.shape.as_list(), + [3 * batch_size, output_size]) + self.assertEqual(loss.targets.shape.as_list(), [3 * batch_size]) + + def testRegisterCategoricalPredictiveDistributionBatchSize1(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + logits = random_ops.random_normal((1, 2)) + lc = layer_collection.LayerCollection() + + lc.register_categorical_predictive_distribution(logits, seed=200) + + def testRegisterCategoricalPredictiveDistributionSpecifiedTargets(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + logits = array_ops.constant([[1., 2.], [3., 4.]], dtype=dtypes.float32) + lc = layer_collection.LayerCollection() + targets = array_ops.constant([0, 1], dtype=dtypes.int32) + + lc.register_categorical_predictive_distribution(logits, targets=targets) + single_loss = sess.run(lc.total_loss()) + self.assertAlmostEqual(1.6265233, single_loss) + + def testRegisterNormalPredictiveDistribution(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + predictions = array_ops.constant( + [[1., 2.], [3., 4]], dtype=dtypes.float32) + + lc = layer_collection.LayerCollection() + lc.register_normal_predictive_distribution(predictions, 1., seed=200) + single_loss = sess.run(lc.total_sampled_loss()) + + lc2 = layer_collection.LayerCollection() + lc2.register_normal_predictive_distribution(predictions, 1., seed=200) + lc2.register_normal_predictive_distribution(predictions, 1., seed=200) + double_loss = sess.run(lc2.total_sampled_loss()) + + self.assertAlmostEqual(2 * single_loss, double_loss) + + def testRegisterNormalPredictiveDistributionSpecifiedTargets(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + predictions = array_ops.constant( + [[1., 2.], [3., 4.]], dtype=dtypes.float32) + lc = layer_collection.LayerCollection() + targets = array_ops.constant([[3., 1.], [4., 2.]], dtype=dtypes.float32) + + lc.register_normal_predictive_distribution( + predictions, 2.**2, targets=targets) + single_loss = sess.run(lc.total_loss()) + self.assertAlmostEqual(7.6983433, single_loss) + + def ensureLayerReuseWorks(self, register_fn): + """Ensure the 'reuse' keyword argument function as intended. + + Args: + register_fn: function for registering a layer. Arguments are + layer_collection, reuse, and approx. + """ + # Fails on second if reuse=False. + lc = layer_collection.LayerCollection() + register_fn(lc) + with self.assertRaises(ValueError): + register_fn(lc, reuse=False) + + # Succeeds on second if reuse=True. + lc = layer_collection.LayerCollection() + register_fn(lc) + register_fn(lc, reuse=True) + + # Fails on second if reuse=VARIABLE_SCOPE and no variable reuse. + lc = layer_collection.LayerCollection() + register_fn(lc) + with self.assertRaises(ValueError): + register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE) + + # Succeeds on second if reuse=VARIABLE_SCOPE and variable reuse. + lc = layer_collection.LayerCollection() + register_fn(lc) + with variable_scope.variable_scope( + variable_scope.get_variable_scope(), reuse=True): + register_fn(lc, reuse=layer_collection.VARIABLE_SCOPE) + + # Fails if block type changes. + lc = layer_collection.LayerCollection() + register_fn(lc, approx=layer_collection.APPROX_KRONECKER_NAME) + with self.assertRaises(ValueError): + register_fn(lc, approx=layer_collection.APPROX_DIAGONAL_NAME, reuse=True) + + # Fails if reuse requested but no FisherBlock exists. + lc = layer_collection.LayerCollection() + with self.assertRaises(KeyError): + register_fn(lc, reuse=True) + + def testRegisterFullyConnectedReuse(self): + """Ensure the 'reuse' works with register_fully_connected.""" + with ops.Graph().as_default(): + inputs = array_ops.ones([2, 10]) + outputs = array_ops.zeros([2, 5]) + params = ( + variable_scope.get_variable('w', [10, 5]), # + variable_scope.get_variable('b', [5])) + + def register_fn(lc, **kwargs): + lc.register_fully_connected( + params=params, inputs=inputs, outputs=outputs, **kwargs) + + self.ensureLayerReuseWorks(register_fn) + + def testRegisterConv2dReuse(self): + """Ensure the 'reuse' works with register_conv2d.""" + with ops.Graph().as_default(): + inputs = array_ops.ones([2, 5, 5, 10]) + outputs = array_ops.zeros([2, 5, 5, 3]) + params = ( + variable_scope.get_variable('w', [1, 1, 10, 3]), # + variable_scope.get_variable('b', [3])) + + def register_fn(lc, **kwargs): + lc.register_conv2d( + params=params, + strides=[1, 1, 1, 1], + padding='SAME', + inputs=inputs, + outputs=outputs, + **kwargs) + + self.ensureLayerReuseWorks(register_fn) + + def testMakeOrGetFactor(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + lc = layer_collection.LayerCollection() + key = array_ops.constant(1) + lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) + lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) + lc.make_or_get_factor(fisher_factors.FullFactor, + ((array_ops.constant(2),), 16)) + + self.assertEqual(2, len(lc.get_factors())) + variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertTrue( + all([var.name.startswith('LayerCollection') for var in variables])) + + def testMakeOrGetFactorCustomScope(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + scope = 'Foo' + lc = layer_collection.LayerCollection(name=scope) + key = array_ops.constant(1) + lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) + lc.make_or_get_factor(fisher_factors.FullFactor, ((key,), 16)) + lc.make_or_get_factor(fisher_factors.FullFactor, + ((array_ops.constant(2),), 16)) + + self.assertEqual(2, len(lc.get_factors())) + variables = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + self.assertTrue(all([var.name.startswith(scope) for var in variables])) + + def testGetUseCountMap(self): + """Ensure get_use_count_map() sums 'num_registered_minibatches'.""" + lc = layer_collection.LayerCollection() + lc.fisher_blocks = { + 'a': MockFisherBlock(), + ('a', 'c'): MockFisherBlock(), + ('b', 'c'): MockFisherBlock() + } + use_count_map = lc.get_use_count_map() + self.assertDictEqual({'a': 4, 'b': 2, 'c': 4}, use_count_map) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py new file mode 100644 index 0000000000000000000000000000000000000000..87339cb059802ec8944d5d1ae4557ee34550cd60 --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/loss_functions_test.py @@ -0,0 +1,101 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.loss_functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.kfac.python.ops import loss_functions +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class InsertSliceInZerosTest(test.TestCase): + + def testBadShape(self): + bad_shaped_ones = array_ops.ones(shape=[1, 3]) # n.b. shape[1] != 1 + with self.assertRaises(ValueError): + loss_functions.insert_slice_in_zeros(bad_shaped_ones, 1, 42, 17) + + def test3d(self): + input_tensor = constant_op.constant([[[1, 2]], [[3, 4]]]) + expected_output_array = [[[1, 2], [0, 0]], [[3, 4], [0, 0]]] + op = loss_functions.insert_slice_in_zeros(input_tensor, 1, 2, 0) + with self.test_session() as sess: + actual_output_array = sess.run(op) + self.assertAllEqual(expected_output_array, actual_output_array) + + +class CategoricalLogitsNegativeLogProbLossTest(test.TestCase): + + def testSample(self): + """Ensure samples can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + sample = loss.sample(42) + sample = sess.run(sample) + self.assertEqual(sample.shape, (2,)) + + def testEvaluateOnTargets(self): + """Ensure log probability can be evaluated correctly.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + targets = np.asarray([2, 1]).astype(np.int32) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits), targets=array_ops.constant(targets)) + neg_log_prob = loss.evaluate() + neg_log_prob = sess.run(neg_log_prob) + + # Calculate explicit log probability of targets. + probs = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True) + log_probs = np.log([ + probs[0, targets[0]], # + probs[1, targets[1]] + ]) + expected_log_prob = np.sum(log_probs) + + self.assertAllClose(neg_log_prob, -expected_log_prob) + + def testEvaluateOnSample(self): + """Ensure log probability of a sample can be drawn.""" + with ops.Graph().as_default(), self.test_session() as sess: + logits = np.asarray([ + [0., 0., 0.], # + [1., -1., 0.] + ]).astype(np.float32) + loss = loss_functions.CategoricalLogitsNegativeLogProbLoss( + array_ops.constant(logits)) + neg_log_prob = loss.evaluate_on_sample(42) + + # Simply ensure this doesn't crash. As the output is random, it's + # difficult to say if the output is correct or not... + neg_log_prob = sess.run(neg_log_prob) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py b/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b20a70e4ca3ec2d65058df2ab8a9c11f8303e714 --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/op_queue_test.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.op_queue.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kfac.python.ops import op_queue +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.platform import test + + +class OpQueueTest(test.TestCase): + + def testNextOp(self): + """Ensures all ops get selected eventually.""" + with tf_ops.Graph().as_default(): + ops = [ + math_ops.add(1, 2), + math_ops.subtract(1, 2), + math_ops.reduce_mean([1, 2]), + ] + queue = op_queue.OpQueue(ops, seed=0) + + with self.test_session() as sess: + # Ensure every inv update op gets selected. + selected_ops = set([queue.next_op(sess) for _ in ops]) + self.assertEqual(set(ops), set(selected_ops)) + + # Ensure additional calls don't create any new ops. + selected_ops.add(queue.next_op(sess)) + self.assertEqual(set(ops), set(selected_ops)) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9325aa1b7325fa9cf546d66e6505affa1af7db4d --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/optimizer_test.py @@ -0,0 +1,204 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.kfac.python.ops import layer_collection as lc +from tensorflow.contrib.kfac.python.ops import optimizer +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.platform import test + + +def dummy_layer_collection(): + lcoll = lc.LayerCollection() + dummy = array_ops.constant([1., 2.]) + lcoll.register_categorical_predictive_distribution(logits=dummy) + return lcoll + + +class OptimizerTest(test.TestCase): + + def testOptimizerInitInvalidMomentumRegistration(self): + with self.assertRaises(ValueError): + optimizer.KfacOptimizer( + 0.1, 0.2, 0.3, lc.LayerCollection(), momentum_type='foo') + + def testOptimizerInit(self): + with ops.Graph().as_default(): + layer_collection = lc.LayerCollection() + + inputs = array_ops.ones((2, 1)) * 2 + weights_val = np.ones((1, 1), dtype=np.float32) * 3. + weights = variable_scope.get_variable( + 'w', initializer=array_ops.constant(weights_val)) + bias = variable_scope.get_variable( + 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1)) + output = math_ops.matmul(inputs, weights) + bias + + layer_collection.register_fully_connected((weights, bias), inputs, output) + + logits = math_ops.tanh(output) + targets = array_ops.constant([[0.], [1.]]) + output = math_ops.reduce_mean( + nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets)) + + layer_collection.register_categorical_predictive_distribution(logits) + + optimizer.KfacOptimizer( + 0.1, + 0.2, + 0.3, + layer_collection, + momentum=0.5, + momentum_type='regular') + + def testSquaredFisherNorm(self): + with ops.Graph().as_default(), self.test_session() as sess: + grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None), + (array_ops.constant([[2., 3.], [4., 5.]]), None)] + pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None), + (array_ops.constant([[7., 8.], [9., 10.]]), None)] + opt = optimizer.KfacOptimizer(0.1, 0.2, 0.3, dummy_layer_collection()) + sq_norm = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars) + self.assertAlmostEqual(174., sess.run(sq_norm), places=5) + + def testUpdateClipCoeff(self): + with ops.Graph().as_default(), self.test_session() as sess: + grads_and_vars = [(array_ops.constant([[1., 2.], [3., 4.]]), None), + (array_ops.constant([[2., 3.], [4., 5.]]), None)] + pgrads_and_vars = [(array_ops.constant([[3., 4.], [5., 6.]]), None), + (array_ops.constant([[7., 8.], [9., 10.]]), None)] + lrate = 0.1 + + # Note: without rescaling, the squared Fisher norm of the update + # is 1.74 + + # If the update already satisfies the norm constraint, there should + # be no rescaling. + opt = optimizer.KfacOptimizer( + lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=10.) + coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars) + self.assertAlmostEqual(1., sess.run(coeff), places=5) + + # If the update violates the constraint, it should be rescaled to + # be on the constraint boundary. + opt = optimizer.KfacOptimizer( + lrate, 0.2, 0.3, dummy_layer_collection(), norm_constraint=0.5) + coeff = opt._update_clip_coeff(grads_and_vars, pgrads_and_vars) + sq_norm_pgrad = opt._squared_fisher_norm(grads_and_vars, pgrads_and_vars) + sq_norm_update = lrate**2 * coeff**2 * sq_norm_pgrad + self.assertAlmostEqual(0.5, sess.run(sq_norm_update), places=5) + + def testComputeUpdateStepsRegular(self): + # TODO(olganw): implement this. + pass + + def testComputeUpdateStepsAdam(self): + # TODO(olganw): implement this. + pass + + def testUpdateVelocities(self): + with ops.Graph().as_default(), self.test_session() as sess: + layers = lc.LayerCollection() + layers.register_categorical_predictive_distribution( + array_ops.constant([1.0])) + opt = optimizer.KfacOptimizer( + 0.1, 0.2, 0.3, layers, momentum=0.5, momentum_type='regular') + x = variable_scope.get_variable('x', initializer=array_ops.ones((2, 2))) + y = variable_scope.get_variable( + 'y', initializer=array_ops.ones((2, 2)) * 2) + vec1 = array_ops.ones((2, 2)) * 3 + vec2 = array_ops.ones((2, 2)) * 4 + + model_vars = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + update_op = opt._update_velocities([(vec1, x), (vec2, y)], 0.5) + opt_vars = [ + v for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + if v not in model_vars + ] + + sess.run(tf_variables.global_variables_initializer()) + old_opt_vars = sess.run(opt_vars) + + # Optimizer vars start out at 0. + for opt_var in old_opt_vars: + self.assertAllEqual(sess.run(array_ops.zeros_like(opt_var)), opt_var) + + sess.run(update_op) + new_opt_vars = sess.run(opt_vars) + # After one update, the velocities are equal to the vectors. + for vec, opt_var in zip([vec1, vec2], new_opt_vars): + self.assertAllEqual(sess.run(vec), opt_var) + + sess.run(update_op) + final_opt_vars = sess.run(opt_vars) + for first, second in zip(new_opt_vars, final_opt_vars): + self.assertFalse(np.equal(first, second).all()) + + def testApplyGradients(self): + with ops.Graph().as_default(), self.test_session() as sess: + layer_collection = lc.LayerCollection() + + inputs = array_ops.ones((2, 1)) * 2 + weights_val = np.ones((1, 1), dtype=np.float32) * 3. + weights = variable_scope.get_variable( + 'w', initializer=array_ops.constant(weights_val)) + bias = variable_scope.get_variable( + 'b', initializer=init_ops.zeros_initializer(), shape=(1, 1)) + output = math_ops.matmul(inputs, weights) + bias + + layer_collection.register_fully_connected((weights, bias), inputs, output) + + logits = math_ops.tanh(output) + targets = array_ops.constant([[0.], [1.]]) + output = math_ops.reduce_mean( + nn.softmax_cross_entropy_with_logits(logits=logits, labels=targets)) + + layer_collection.register_categorical_predictive_distribution(logits) + + opt = optimizer.KfacOptimizer( + 0.1, + 0.2, + 0.3, + layer_collection, + momentum=0.5, + momentum_type='regular') + grads_and_vars = opt.compute_gradients(output, [weights, bias]) + all_vars = [grad_and_var[1] for grad_and_var in grads_and_vars] + + op = opt.apply_gradients(grads_and_vars) + + sess.run(tf_variables.global_variables_initializer()) + old_vars = sess.run(all_vars) + sess.run(op) + new_vars = sess.run(all_vars) + + for old_var, new_var in zip(old_vars, new_vars): + self.assertNotEqual(old_var, new_var) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..55fe38e3e9aab2dbd70a45cdc8fa0c208b036db0 --- /dev/null +++ b/tensorflow/contrib/kfac/python/kernel_tests/utils_test.py @@ -0,0 +1,270 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for tf.contrib.kfac.utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import numpy.random as npr + +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.platform import test + + +class SequenceDictTest(test.TestCase): + + def testSequenceDictInit(self): + seq_dict = utils.SequenceDict() + self.assertFalse(seq_dict._dict) + + def testSequenceDictInitWithIterable(self): + reg_dict = {'a': 'foo', 'b': 'bar'} + itr = zip(reg_dict.keys(), reg_dict.values()) + seq_dict = utils.SequenceDict(itr) + self.assertEqual(reg_dict, seq_dict._dict) + + def testGetItemSingleKey(self): + seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'}) + self.assertEqual('foo', seq_dict['a']) + + def testGetItemMultipleKeys(self): + seq_dict = utils.SequenceDict({'a': 'foo', 'b': 'bar'}) + self.assertEqual(['foo', 'bar'], seq_dict[('a', 'b')]) + + def testSetItemSingleKey(self): + seq_dict = utils.SequenceDict() + seq_dict['a'] = 'foo' + self.assertEqual([('a', 'foo')], seq_dict.items()) + + def testSetItemMultipleKeys(self): + seq_dict = utils.SequenceDict() + keys = ('a', 'b', 'c') + values = ('foo', 'bar', 'baz') + seq_dict[keys] = values + self.assertItemsEqual(list(zip(keys, values)), seq_dict.items()) + + +class SubGraphTest(test.TestCase): + + def testBasicGraph(self): + a = array_ops.constant([[1., 2.], [3., 4.]]) + b = array_ops.constant([[5., 6.], [7., 8.]]) + c = a + b + d = a * b + sub_graph = utils.SubGraph((c,)) + self.assertTrue(sub_graph.is_member(a)) + self.assertTrue(sub_graph.is_member(b)) + self.assertTrue(sub_graph.is_member(c)) + self.assertFalse(sub_graph.is_member(d)) + + def testRepeatedAdds(self): + a = array_ops.constant([[1., 2.], [3., 4.]]) + b = array_ops.constant([[5., 6.], [7., 8.]]) + c = a + b + a # note that a appears twice in this graph + sub_graph = utils.SubGraph((c,)) + self.assertTrue(sub_graph.is_member(a)) + self.assertTrue(sub_graph.is_member(b)) + self.assertTrue(sub_graph.is_member(c)) + + def testFilterList(self): + a = array_ops.constant([[1., 2.], [3., 4.]]) + b = array_ops.constant([[5., 6.], [7., 8.]]) + c = a + b + d = a * b + sub_graph = utils.SubGraph((c,)) + input_list = [b, d] + filtered_list = sub_graph.filter_list(input_list) + self.assertEqual(filtered_list, [b]) + + +class UtilsTest(test.TestCase): + + def _fully_connected_layer_params(self): + weights_part = array_ops.constant([[1., 2.], [4., 3.]]) + bias_part = array_ops.constant([1., 2.]) + return (weights_part, bias_part) + + def _conv_layer_params(self): + weights_shape = 2, 2, 3, 4 + biases_shape = weights_shape[-1:] + weights = array_ops.constant(npr.RandomState(0).randn(*weights_shape)) + biases = array_ops.constant(npr.RandomState(1).randn(*biases_shape)) + return (weights, biases) + + def testFullyConnectedLayerParamsTupleToMat2d(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + layer_params = self._fully_connected_layer_params() + output = utils.layer_params_to_mat2d(layer_params) + self.assertListEqual([3, 2], output.get_shape().as_list()) + self.assertAllClose( + sess.run(output), np.array([[1., 2.], [4., 3.], [1., 2.]])) + + def testFullyConnectedLayerParamsTensorToMat2d(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + layer_params = self._fully_connected_layer_params() + output = utils.layer_params_to_mat2d(layer_params[0]) + self.assertListEqual([2, 2], output.get_shape().as_list()) + self.assertAllClose(sess.run(output), np.array([[1., 2.], [4., 3.]])) + + def testConvLayerParamsTupleToMat2d(self): + with ops.Graph().as_default(): + random_seed.set_random_seed(200) + layer_params = self._conv_layer_params() + output = utils.layer_params_to_mat2d(layer_params) + self.assertListEqual([2 * 2 * 3 + 1, 4], output.get_shape().as_list()) + + def testKron(self): + with ops.Graph().as_default(), self.test_session() as sess: + mat1 = np.array([[1., 2.], [3., 4.]]) + mat2 = np.array([[5., 6.], [7., 8.]]) + mat1_tf = array_ops.constant(mat1) + mat2_tf = array_ops.constant(mat2) + ans_tf = sess.run(utils.kronecker_product(mat1_tf, mat2_tf)) + ans_np = np.kron(mat1, mat2) + self.assertAllClose(ans_tf, ans_np) + + def testMat2dToFullyConnectedLayerParamsTuple(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + vector_template = self._fully_connected_layer_params() + mat2d = array_ops.constant([[5., 4.], [3., 2.], [1., 0.]]) + + output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d)) + + self.assertIsInstance(output, tuple) + self.assertEqual(len(output), 2) + a, b = output + self.assertAllClose(a, np.array([[5., 4.], [3., 2.]])) + self.assertAllClose(b, np.array([1., 0.])) + + def testMat2dToFullyConnectedLayerParamsTensor(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + vector_template = self._fully_connected_layer_params()[0] + mat2d = array_ops.constant([[5., 4.], [3., 2.]]) + + output = sess.run(utils.mat2d_to_layer_params(vector_template, mat2d)) + + self.assertAllClose(output, np.array([[5., 4.], [3., 2.]])) + + def testTensorsToColumn(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + + vector = array_ops.constant(np.array([[0., 1.], [2., 3.]])) + output = utils.tensors_to_column(vector) + self.assertListEqual([4, 1], output.get_shape().as_list()) + self.assertAllClose(sess.run(output), np.array([0., 1., 2., 3.])[:, None]) + + vector = self._fully_connected_layer_params() + output = utils.tensors_to_column(vector) + self.assertListEqual([6, 1], output.get_shape().as_list()) + self.assertAllClose( + sess.run(output), np.array([1., 2., 4., 3., 1., 2.])[:, None]) + + vector = list(vector) + vector.append(array_ops.constant([[6.], [7.], [8.], [9.]])) + + output = utils.tensors_to_column(vector) + self.assertListEqual([10, 1], output.get_shape().as_list()) + self.assertAllClose( + sess.run(output), + np.array([1., 2., 4., 3., 1., 2., 6., 7., 8., 9.])[:, None]) + + def testColumnToTensors(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + + vector_template = array_ops.constant(np.array([[0., 1.], [2., 3.]])) + colvec = array_ops.constant(np.arange(4.)[:, None]) + output = sess.run(utils.column_to_tensors(vector_template, colvec)) + self.assertAllClose(output, np.array([[0., 1.], [2., 3.]])) + + vector_template = self._fully_connected_layer_params() + colvec = array_ops.constant(np.arange(6.)[:, None]) + output = sess.run(utils.column_to_tensors(vector_template, colvec)) + + self.assertIsInstance(output, tuple) + self.assertEqual(len(output), 2) + a, b = output + self.assertAllClose(a, np.array([[0., 1.], [2., 3.]])) + self.assertAllClose(b, np.array([4., 5.])) + + vector_template = list(vector_template) + vector_template.append(array_ops.constant([[6.], [7.], [8.], [9.]])) + colvec = array_ops.constant(np.arange(10.)[:, None]) + output = sess.run(utils.column_to_tensors(vector_template, colvec)) + self.assertIsInstance(output, tuple) + self.assertEqual(len(output), 3) + a, b, c = output + self.assertAllClose(a, np.array([[0., 1.], [2., 3.]])) + self.assertAllClose(b, np.array([4., 5.])) + self.assertAllClose(c, np.array([[6.], [7.], [8.], [9.]])) + + def testComputePi(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + left_factor = array_ops.diag([1., 2., 0., 1.]) + right_factor = array_ops.ones([2., 2.]) + + # pi is the sqrt of the left trace norm divided by the right trace norm + pi = utils.compute_pi(left_factor, right_factor) + + pi_val = sess.run(pi) + self.assertEqual(1., pi_val) + + def testPosDefInvCholesky(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + npr.seed(0) + square = lambda x: np.dot(x, x.T) + + size = 3 + x = square(npr.randn(size, size)) + damp = 0.1 + identity = linalg_ops.eye(size, dtype=dtypes.float64) + + tf_inv = utils.posdef_inv_cholesky(array_ops.constant(x), identity, damp) + np_inv = np.linalg.inv(x + damp * np.eye(size)) + self.assertAllClose(sess.run(tf_inv), np_inv) + + def testPosDefInvMatrixInverse(self): + with ops.Graph().as_default(), self.test_session() as sess: + random_seed.set_random_seed(200) + npr.seed(0) + square = lambda x: np.dot(x, x.T) + + size = 3 + x = square(npr.randn(size, size)) + damp = 0.1 + identity = linalg_ops.eye(size, dtype=dtypes.float64) + + tf_inv = utils.posdef_inv_matrix_inverse( + array_ops.constant(x), identity, damp) + np_inv = np.linalg.inv(x + damp * np.eye(size)) + self.assertAllClose(sess.run(tf_inv), np_inv) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..de4b8920b849dbf2117657de6e7c26f94f4d0363 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/BUILD @@ -0,0 +1,248 @@ +package(default_visibility = [ + "//tensorflow/contrib/kfac:__pkg__", + "//tensorflow/contrib/kfac/python/kernel_tests:__pkg__", +]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_library( + name = "fisher_blocks", + srcs = ["fisher_blocks.py"], + srcs_version = "PY2AND3", + deps = [ + ":fisher_factors", + ":utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", + "@six_archive//:six", + ], +) + +py_library( + name = "fisher_blocks_lib", + srcs = ["fisher_blocks_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":fisher_blocks", + "//tensorflow/python:util", + ], +) + +py_library( + name = "fisher_factors", + srcs = ["fisher_factors.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:special_math_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + +py_library( + name = "fisher_factors_lib", + srcs = ["fisher_factors_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":fisher_factors", + "//tensorflow/python:util", + ], +) + +py_library( + name = "loss_functions", + srcs = ["loss_functions.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:tensor_shape", + "//tensorflow/python/ops/distributions", + "@six_archive//:six", + ], +) + +py_library( + name = "loss_functions_lib", + srcs = ["loss_functions_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":loss_functions", + "//tensorflow/python:util", + ], +) + +py_library( + name = "curvature_matrix_vector_products", + srcs = ["curvature_matrix_vector_products.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:gradients", + "//tensorflow/python:math_ops", + "//tensorflow/python:util", + ], +) + +py_library( + name = "curvature_matrix_vector_products_lib", + srcs = ["curvature_matrix_vector_products_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":curvature_matrix_vector_products", + "//tensorflow/python:util", + ], +) + +py_library( + name = "layer_collection", + srcs = ["layer_collection.py"], + srcs_version = "PY2AND3", + deps = [ + ":fisher_blocks", + ":loss_functions", + ":utils", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "@six_archive//:six", + ], +) + +py_library( + name = "layer_collection_lib", + srcs = ["layer_collection_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":layer_collection", + "//tensorflow/python:util", + ], +) + +py_library( + name = "kfac_optimizer", + srcs = [ + "optimizer.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":curvature_matrix_vector_products", + ":fisher_estimator", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//tensorflow/python:variables", + ], +) + +py_library( + name = "kfac_optimizer_lib", + srcs = [ + "optimizer_lib.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":kfac_optimizer", + "//tensorflow/python:util", + ], +) + +py_library( + name = "fisher_estimator", + srcs = [ + "estimator.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:util", + "//third_party/py/numpy", + ], +) + +py_library( + name = "fisher_estimator_lib", + srcs = [ + "estimator_lib.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":fisher_estimator", + "//tensorflow/python:util", + ], +) + +py_library( + name = "utils", + srcs = ["utils.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:gradients", + "//tensorflow/python:linalg_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:random_ops", + "//third_party/py/numpy", + ], +) + +py_library( + name = "utils_lib", + srcs = ["utils_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":utils", + "//tensorflow/python:util", + ], +) + +py_library( + name = "op_queue", + srcs = ["op_queue.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/data/python/ops:dataset_ops", + "//tensorflow/python:framework_ops", + ], +) + +py_library( + name = "op_queue_lib", + srcs = ["op_queue_lib.py"], + srcs_version = "PY2AND3", + deps = [ + ":op_queue", + "//tensorflow/python:util", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py new file mode 100644 index 0000000000000000000000000000000000000000..21b5cde9b931a95110c9a5fd7930a3a4ee74b207 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products.py @@ -0,0 +1,183 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Curvature matrix-vector multiplication.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import math_ops +from tensorflow.python.util import nest + + +class CurvatureMatrixVectorProductComputer(object): + """Class for computing matrix-vector products for Fishers, GGNs and Hessians. + + In other words we compute M*v where M is the matrix, v is the vector, and + * refers to standard matrix/vector multiplication (not element-wise + multiplication). + + The matrices are defined in terms of some differential quantity of the total + loss function with respect to a provided list of tensors ("wrt_tensors"). + For example, the Fisher associated with a log-prob loss w.r.t. the + parameters. + + The 'vecs' argument to each method are lists of tensors that must be the + size as the corresponding ones from "wrt_tensors". They represent + the vector being multiplied. + + "factors" of the matrix M are defined as matrices B such that B*B^T = M. + Methods that multiply by the factor B take a 'loss_inner_vecs' argument + instead of 'vecs', which must be a list of tensors with shapes given by the + corresponding XXX_inner_shapes property. + + Note that matrix-vector products are not normalized by the batch size, nor + are any damping terms added to the results. These things can be easily + applied externally, if desired. + + See for example: www.cs.utoronto.ca/~jmartens/docs/HF_book_chapter.pdf + and https://arxiv.org/abs/1412.1193 for more information about the + generalized Gauss-Newton, Fisher, etc., and how to compute matrix-vector + products. + """ + + def __init__(self, losses, wrt_tensors): + """Create a CurvatureMatrixVectorProductComputer object. + + Args: + losses: A list of LossFunction instances whose sum defines the total loss. + wrt_tensors: A list of Tensors to compute the differential quantities + (defining the matrices) with respect to. See class description for more + info. + """ + self._losses = losses + self._inputs_to_losses = list(loss.inputs for loss in losses) + self._inputs_to_losses_flat = nest.flatten(self._inputs_to_losses) + self._wrt_tensors = wrt_tensors + + @property + def _total_loss(self): + return math_ops.add_n(tuple(loss.evaluate() for loss in self._losses)) + + # Jacobian multiplication functions: + def _multiply_jacobian(self, vecs): + """Multiply vecs by the Jacobian of losses.""" + # We stop gradients at wrt_tensors to produce partial derivatives (which is + # what we want for Jacobians). + jacobian_vecs_flat = utils.fwd_gradients( + self._inputs_to_losses_flat, self._wrt_tensors, grad_xs=vecs, + stop_gradients=self._wrt_tensors) + return nest.pack_sequence_as(self._inputs_to_losses, jacobian_vecs_flat) + + def _multiply_jacobian_transpose(self, loss_vecs): + """Multiply vecs by the transpose Jacobian of losses.""" + loss_vecs_flat = nest.flatten(loss_vecs) + # We stop gradients at wrt_tensors to produce partial derivatives (which is + # what we want for Jacobians). + return gradients_impl.gradients( + self._inputs_to_losses_flat, self._wrt_tensors, grad_ys=loss_vecs_flat, + stop_gradients=self._wrt_tensors) + + # Losses Fisher/Hessian multiplication functions: + def _multiply_loss_fisher(self, loss_vecs): + """Multiply loss_vecs by Fisher of total loss.""" + return tuple( + loss.multiply_fisher(loss_vec) + for loss, loss_vec in zip(self._losses, loss_vecs)) + + def _multiply_loss_fisher_factor(self, loss_inner_vecs): + """Multiply loss_inner_vecs by factor of Fisher of total loss.""" + return tuple( + loss.multiply_fisher_factor(loss_vec) + for loss, loss_vec in zip(self._losses, loss_inner_vecs)) + + def _multiply_loss_fisher_factor_transpose(self, loss_vecs): + """Multiply loss_vecs by transpose factor of Fisher of total loss.""" + return tuple( + loss.multiply_fisher_factor_transpose(loss_vec) + for loss, loss_vec in zip(self._losses, loss_vecs)) + + def _multiply_loss_hessian(self, loss_vecs): + """Multiply loss_vecs by Hessian of total loss.""" + return tuple( + loss.multiply_hessian(loss_vec) + for loss, loss_vec in zip(self._losses, loss_vecs)) + + def _multiply_loss_hessian_factor(self, loss_inner_vecs): + """Multiply loss_inner_vecs by factor of Hessian of total loss.""" + return tuple( + loss.multiply_hessian_factor(loss_vec) + for loss, loss_vec in zip(self._losses, loss_inner_vecs)) + + def _multiply_loss_hessian_factor_transpose(self, loss_vecs): + """Multiply loss_vecs by transpose factor of Hessian of total loss.""" + return tuple( + loss.multiply_hessian_factor_transpose(loss_vec) + for loss, loss_vec in zip(self._losses, loss_vecs)) + + # Matrix-vector product functions: + def multiply_fisher(self, vecs): + """Multiply vecs by Fisher of total loss.""" + jacobian_vecs = self._multiply_jacobian(vecs) + loss_fisher_jacobian_vecs = self._multiply_loss_fisher(jacobian_vecs) + return self._multiply_jacobian_transpose(loss_fisher_jacobian_vecs) + + def multiply_fisher_factor_transpose(self, vecs): + """Multiply vecs by transpose of factor of Fisher of total loss.""" + jacobian_vecs = self._multiply_jacobian(vecs) + return self._multiply_loss_fisher_factor_transpose(jacobian_vecs) + + def multiply_fisher_factor(self, loss_inner_vecs): + """Multiply loss_inner_vecs by factor of Fisher of total loss.""" + fisher_factor_transpose_vecs = self._multiply_loss_fisher_factor_transpose( + loss_inner_vecs) + return self._multiply_jacobian_transpose(fisher_factor_transpose_vecs) + + def multiply_hessian(self, vecs): + """Multiply vecs by Hessian of total loss.""" + return gradients_impl.gradients( + gradients_impl.gradients(self._total_loss, self._wrt_tensors), + self._wrt_tensors, + grad_ys=vecs) + + def multiply_generalized_gauss_newton(self, vecs): + """Multiply vecs by generalized Gauss-Newton of total loss.""" + jacobian_vecs = self._multiply_jacobian(vecs) + loss_hessian_jacobian_vecs = self._multiply_loss_hessian(jacobian_vecs) + return self._multiply_jacobian_transpose(loss_hessian_jacobian_vecs) + + def multiply_generalized_gauss_newton_factor_transpose(self, vecs): + """Multiply vecs by transpose of factor of GGN of total loss.""" + jacobian_vecs = self._multiply_jacobian(vecs) + return self._multiply_loss_hessian_factor_transpose(jacobian_vecs) + + def multiply_generalized_gauss_newton_factor(self, loss_inner_vecs): + """Multiply loss_inner_vecs by factor of GGN of total loss.""" + hessian_factor_transpose_vecs = ( + self._multiply_loss_hessian_factor_transpose(loss_inner_vecs)) + return self._multiply_jacobian_transpose(hessian_factor_transpose_vecs) + + # Shape properties for multiply_XXX_factor methods: + @property + def fisher_factor_inner_shapes(self): + """Shapes required by multiply_fisher_factor.""" + return tuple(loss.fisher_factor_inner_shape for loss in self._losses) + + @property + def generalized_gauss_newton_factor_inner_shapes(self): + """Shapes required by multiply_generalized_gauss_newton_factor.""" + return tuple(loss.hessian_factor_inner_shape for loss in self._losses) diff --git a/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..6e8c6404dcba0970785a2c8358cb4e2356e45b0e --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/curvature_matrix_vector_products_lib.py @@ -0,0 +1,30 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Curvature matrix-vector multiplication.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.curvature_matrix_vector_products import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'CurvatureMatrixVectorProductComputer', +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/estimator.py b/tensorflow/contrib/kfac/python/ops/estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..ce4e776324bbde1b8f214d89daa876032d8a21ff --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/estimator.py @@ -0,0 +1,279 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines the high-level Fisher estimator class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import numpy as np + +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.util import nest + + +class FisherEstimator(object): + """Fisher estimator class supporting various approximations of the Fisher.""" + + def __init__(self, + variables, + cov_ema_decay, + damping, + layer_collection, + estimation_mode="gradients"): + """Create a FisherEstimator object. + + Args: + variables: A list of the variables for which to estimate the Fisher. This + must match the variables registered in layer_collection (if it is not + None). + cov_ema_decay: The decay factor used when calculating the covariance + estimate moving averages. + damping: The damping factor used to stabilize training due to errors in + the local approximation with the Fisher information matrix, and to + regularize the update direction by making it closer to the gradient. + (Higher damping means the update looks more like a standard gradient + update - see Tikhonov regularization.) + layer_collection: The layer collection object, which holds the fisher + blocks, kronecker factors, and losses associated with the + graph. + estimation_mode: The type of estimator to use for the Fishers. Can be + 'gradients', 'empirical', 'curvature_propagation', or 'exact'. + (Default: 'gradients'). 'gradients' is the basic estimation approach + from the original K-FAC paper. 'empirical' computes the 'empirical' + Fisher information matrix (which uses the data's distribution for the + targets, as opposed to the true Fisher which uses the model's + distribution) and requires that each registered loss have specified + targets. 'curvature_propagation' is a method which estimates the + Fisher using self-products of random 1/-1 vectors times "half-factors" + of the Fisher, as described here: https://arxiv.org/abs/1206.6464 . + Finally, 'exact' is the obvious generalization of Curvature + Propagation to compute the exact Fisher (modulo any additional + diagonal or Kronecker approximations) by looping over one-hot vectors + for each coordinate of the output instead of using 1/-1 vectors. It + is more expensive to compute than the other three options by a factor + equal to the output dimension, roughly speaking. + + Raises: + ValueError: If no losses have been registered with layer_collection. + """ + + self._variables = variables + self._damping = damping + self._estimation_mode = estimation_mode + self._layers = layer_collection + self._layers.create_subgraph() + self._check_registration(variables) + self._gradient_fns = { + "gradients": self._get_grads_lists_gradients, + "empirical": self._get_grads_lists_empirical, + "curvature_prop": self._get_grads_lists_curvature_prop, + "exact": self._get_grads_lists_exact + } + setup = self._setup(cov_ema_decay) + self.cov_update_op, self.inv_update_op, self.inv_updates_dict = setup + + @property + def variables(self): + return self._variables + + @property + def damping(self): + return self._damping + + def _apply_transformation(self, vecs_and_vars, transform): + """Applies an block-wise transformation to the corresponding vectors. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + transform: A function of the form f(fb, vec), where vec is the vector + to transform and fb is its corresponding block in the matrix, that + returns the transformed vector. + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + + vecs = utils.SequenceDict((var, vec) for vec, var in vecs_and_vars) + + trans_vecs = utils.SequenceDict() + + for params, fb in self._layers.fisher_blocks.items(): + trans_vecs[params] = transform(fb, vecs[params]) + + return [(trans_vecs[var], var) for _, var in vecs_and_vars] + + def multiply_inverse(self, vecs_and_vars): + """Multiplies the vecs by the corresponding (damped) inverses of the blocks. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + + return self._apply_transformation(vecs_and_vars, + lambda fb, vec: fb.multiply_inverse(vec)) + + def multiply(self, vecs_and_vars): + """Multiplies the vectors by the corresponding (damped) blocks. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + + Returns: + A list of (transformed vector, var) pairs in the same order as + vecs_and_vars. + """ + + return self._apply_transformation(vecs_and_vars, + lambda fb, vec: fb.multiply(vec)) + + def _check_registration(self, variables): + """Checks that all variable uses have been registered properly. + + Args: + variables: List of variables. + + Raises: + ValueError: If any registered variables are not included in the list. + ValueError: If any variable in the list is not registered. + ValueError: If any variable in the list is registered with the wrong + number of "uses" in the subgraph recorded (vs the number of times that + variable is actually used in the subgraph). + """ + # Note that overlapping parameters (i.e. those that share variables) will + # be caught by layer_collection.LayerParametersDict during registration. + + reg_use_map = self._layers.get_use_count_map() + + error_messages = [] + + for var in variables: + total_uses = self._layers.subgraph.variable_uses(var) + reg_uses = reg_use_map[var] + + if reg_uses == 0: + error_messages.append("Variable {} not registered.".format(var)) + elif (not math.isinf(reg_uses)) and reg_uses != total_uses: + error_messages.append( + "Variable {} registered with wrong number of uses ({} " + "registrations vs {} uses).".format(var, reg_uses, total_uses)) + + num_get_vars = len(reg_use_map) + + if num_get_vars > len(variables): + error_messages.append("{} registered variables were not included in list." + .format(num_get_vars - len(variables))) + + if error_messages: + error_messages = [ + "Found the following errors with variable registration:" + ] + error_messages + raise ValueError("\n\t".join(error_messages)) + + def _setup(self, cov_ema_decay): + """Sets up the various operations. + + Args: + cov_ema_decay: The decay factor used when calculating the covariance + estimate moving averages. + + Returns: + A triple (covs_update_op, invs_update_op, inv_updates_dict), where + covs_update_op is the grouped Op to update all the covariance estimates, + invs_update_op is the grouped Op to update all the inverses, and + inv_updates_dict is a dict mapping Op names to individual inverse updates. + + Raises: + ValueError: If estimation_mode was improperly specified at construction. + """ + fisher_blocks_list = self._layers.get_blocks() + tensors_to_compute_grads = [ + fb.tensors_to_compute_grads() for fb in fisher_blocks_list + ] + + try: + grads_lists = self._gradient_fns[self._estimation_mode]( + tensors_to_compute_grads) + except KeyError: + raise ValueError("Unrecognized value {} for estimation_mode.".format( + self._estimation_mode)) + + for grads_list, fb in zip(grads_lists, fisher_blocks_list): + fb.instantiate_factors(grads_list, self.damping) + + cov_updates = [ + factor.make_covariance_update_op(cov_ema_decay) + for factor in self._layers.get_factors() + ] + inv_updates = {op.name: op for op in self._get_all_inverse_update_ops()} + + return control_flow_ops.group(*cov_updates), control_flow_ops.group( + *inv_updates.values()), inv_updates + + def _get_all_inverse_update_ops(self): + for factor in self._layers.get_factors(): + for op in factor.make_inverse_update_ops(): + yield op + + def _get_grads_lists_gradients(self, tensors): + grads_flat = gradients_impl.gradients(self._layers.total_sampled_loss(), + nest.flatten(tensors)) + grads_all = nest.pack_sequence_as(tensors, grads_flat) + return tuple((grad,) for grad in grads_all) + + def _get_grads_lists_empirical(self, tensors): + grads_flat = gradients_impl.gradients(self._layers.total_loss(), + nest.flatten(tensors)) + grads_all = nest.pack_sequence_as(tensors, grads_flat) + return tuple((grad,) for grad in grads_all) + + def _get_transformed_random_signs(self): + transformed_random_signs = [] + for loss in self._layers.losses: + transformed_random_signs.append( + loss.multiply_fisher_factor( + utils.generate_random_signs(loss.fisher_factor_inner_shape))) + return transformed_random_signs + + def _get_grads_lists_curvature_prop(self, tensors): + loss_inputs = list(loss.inputs for loss in self._layers.losses) + transformed_random_signs = self._get_transformed_random_signs() + grads_flat = gradients_impl.gradients( + nest.flatten(loss_inputs), + nest.flatten(tensors), + grad_ys=nest.flatten(transformed_random_signs)) + grads_all = nest.pack_sequence_as(tensors, grads_flat) + return tuple((grad,) for grad in grads_all) + + def _get_grads_lists_exact(self, tensors): + # Loop over all coordinates of all losses. + grads_all = [] + for loss in self._layers.losses: + for index in np.ndindex(*loss.fisher_factor_inner_static_shape[1:]): + transformed_one_hot = loss.multiply_fisher_factor_replicated_one_hot( + index) + grads_flat = gradients_impl.gradients( + loss.inputs, nest.flatten(tensors), grad_ys=transformed_one_hot) + grads_all.append(nest.pack_sequence_as(tensors, grads_flat)) + return zip(*grads_all) diff --git a/tensorflow/contrib/kfac/python/ops/estimator_lib.py b/tensorflow/contrib/kfac/python/ops/estimator_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..33c969650615bf8e439c2f669b4a1efaf2f565ff --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/estimator_lib.py @@ -0,0 +1,30 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Defines the high-level Fisher estimator class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.estimator import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'FisherEstimator', +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..a6fdf01fe7d06a1719aef1f3c329a5587add651a --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks.py @@ -0,0 +1,722 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FisherBlock definitions. + +This library contains classes for estimating blocks in a model's Fisher +Information matrix. Suppose one has a model that parameterizes a posterior +distribution over 'y' given 'x' with parameters 'params', p(y | x, params). Its +Fisher Information matrix is given by, + + F(params) = E[ v(x, y, params) v(x, y, params)^T ] + +where, + + v(x, y, params) = (d / d params) log p(y | x, params) + +and the expectation is taken with respect to the data's distribution for 'x' and +the model's posterior distribution for 'y', + + x ~ p(x) + y ~ p(y | x, params) + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.contrib.kfac.python.ops import fisher_factors +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops + +# For blocks corresponding to convolutional layers, or any type of block where +# the parameters can be thought of as being replicated in time or space, +# we want to adjust the scale of the damping by +# damping /= num_replications ** NORMALIZE_DAMPING_POWER +NORMALIZE_DAMPING_POWER = 1.0 + + +def set_global_constants(normalize_damping_power=None): + """Sets various global constants used by the classes in this module.""" + global NORMALIZE_DAMPING_POWER + + if normalize_damping_power is not None: + NORMALIZE_DAMPING_POWER = normalize_damping_power + + +@six.add_metaclass(abc.ABCMeta) +class FisherBlock(object): + """Abstract base class for objects modeling approximate Fisher matrix blocks. + + Subclasses must implement multiply_inverse(), instantiate_factors(), and + tensors_to_compute_grads() methods. + """ + + def __init__(self, layer_collection): + self._layer_collection = layer_collection + + @abc.abstractmethod + def instantiate_factors(self, grads_list, damping): + """Creates and registers the component factors of this Fisher block. + + Args: + grads_list: A list gradients (each a Tensor or tuple of Tensors) with + respect to the tensors returned by tensors_to_compute_grads() that + are to be used to estimate the block. + damping: The damping factor (float or Tensor). + """ + pass + + @abc.abstractmethod + def multiply_inverse(self, vector): + """Multiplies the vector by the (damped) inverse of the block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + + Returns: + The vector left-multiplied by the (damped) inverse of the block. + """ + pass + + @abc.abstractmethod + def multiply(self, vector): + """Multiplies the vector by the (damped) block. + + Args: + vector: The vector (a Tensor or tuple of Tensors) to be multiplied. + + Returns: + The vector left-multiplied by the (damped) block. + """ + pass + + @abc.abstractmethod + def tensors_to_compute_grads(self): + """Returns the Tensor(s) with respect to which this FisherBlock needs grads. + """ + pass + + @abc.abstractproperty + def num_registered_minibatches(self): + """Number of minibatches registered for this FisherBlock. + + Typically equal to the number of towers in a multi-tower setup. + """ + pass + + +class FullFB(FisherBlock): + """FisherBlock using a full matrix estimate (no approximations). + + FullFB uses a full matrix estimate (no approximations), and should only ever + be used for very low dimensional parameters. + + Note that this uses the naive "square the sum estimator", and so is applicable + to any type of parameter in principle, but has very high variance. + """ + + def __init__(self, layer_collection, params): + """Creates a FullFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters of this layer (Tensor or tuple of Tensors). + """ + self._batch_sizes = [] + self._params = params + + super(FullFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + self._damping = damping + self._factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullFactor, (grads_list, self._batch_size)) + self._factor.register_damped_inverse(damping) + + def multiply_inverse(self, vector): + inverse = self._factor.get_inverse(self._damping) + out_flat = math_ops.matmul(inverse, utils.tensors_to_column(vector)) + return utils.column_to_tensors(vector, out_flat) + + def multiply(self, vector): + vector_flat = utils.tensors_to_column(vector) + out_flat = ( + math_ops.matmul(self._factor.get_cov(), vector_flat) + + self._damping * vector_flat) + return utils.column_to_tensors(vector, out_flat) + + def full_fisher_block(self): + """Explicitly constructs the full Fisher block.""" + return self._factor.get_cov() + + def tensors_to_compute_grads(self): + return self._params + + def register_additional_minibatch(self, batch_size): + """Register an additional minibatch. + + Args: + batch_size: The batch size, used in the covariance estimator. + """ + self._batch_sizes.append(batch_size) + + @property + def num_registered_minibatches(self): + return len(self._batch_sizes) + + @property + def _batch_size(self): + return math_ops.reduce_sum(self._batch_sizes) + + +class NaiveDiagonalFB(FisherBlock): + """FisherBlock using a diagonal matrix approximation. + + This type of approximation is generically applicable but quite primitive. + + Note that this uses the naive "square the sum estimator", and so is applicable + to any type of parameter in principle, but has very high variance. + """ + + def __init__(self, layer_collection, params): + """Creates a NaiveDiagonalFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters of this layer (Tensor or tuple of Tensors). + """ + self._params = params + self._batch_sizes = [] + + super(NaiveDiagonalFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + self._damping = damping + self._factor = self._layer_collection.make_or_get_factor( + fisher_factors.NaiveDiagonalFactor, (grads_list, self._batch_size)) + + def multiply_inverse(self, vector): + vector_flat = utils.tensors_to_column(vector) + out_flat = vector_flat / (self._factor.get_cov() + self._damping) + return utils.column_to_tensors(vector, out_flat) + + def multiply(self, vector): + vector_flat = utils.tensors_to_column(vector) + out_flat = vector_flat * (self._factor.get_cov() + self._damping) + return utils.column_to_tensors(vector, out_flat) + + def full_fisher_block(self): + return array_ops.diag(array_ops.reshape(self._factor.get_cov(), (-1,))) + + def tensors_to_compute_grads(self): + return self._params + + def register_additional_minibatch(self, batch_size): + """Register an additional minibatch. + + Args: + batch_size: The batch size, used in the covariance estimator. + """ + self._batch_sizes.append(batch_size) + + @property + def num_registered_minibatches(self): + return len(self._batch_sizes) + + @property + def _batch_size(self): + return math_ops.reduce_sum(self._batch_sizes) + + +class FullyConnectedDiagonalFB(FisherBlock): + """FisherBlock for fully-connected (dense) layers using a diagonal approx. + + Estimates the Fisher Information matrix's diagonal entries for a fully + connected layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of + squares" estimator. + + Let 'params' be a vector parameterizing a model and 'i' an arbitrary index + into it. We are interested in Fisher(params)[i, i]. This is, + + Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] + = E[ v(x, y, params)[i] ^ 2 ] + + Consider fully connected layer in this model with (unshared) weight matrix + 'w'. For an example 'x' that produces layer inputs 'a' and output + preactivations 's', + + v(x, y, w) = vec( a (d loss / d s)^T ) + + This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding + to the layer's parameters 'w'. + """ + + def __init__(self, layer_collection, has_bias=False): + """Creates a FullyConnectedDiagonalFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + has_bias: Whether the component Kronecker factors have an additive bias. + (Default: False) + """ + self._inputs = [] + self._outputs = [] + self._has_bias = has_bias + + super(FullyConnectedDiagonalFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + inputs = _concat_along_batch_dim(self._inputs) + grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + + self._damping = damping + self._factor = self._layer_collection.make_or_get_factor( + fisher_factors.FullyConnectedDiagonalFactor, + (inputs, grads_list, self._has_bias)) + + def multiply_inverse(self, vector): + """Approximate damped inverse Fisher-vector product. + + Args: + vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape + [input_size, output_size] corresponding to layer's weights. If not, a + 2-tuple of the former and a Tensor of shape [output_size] corresponding + to the layer's bias. + + Returns: + Tensor of the same shape, corresponding to the inverse Fisher-vector + product. + """ + reshaped_vect = utils.layer_params_to_mat2d(vector) + reshaped_out = reshaped_vect / (self._factor.get_cov() + self._damping) + return utils.mat2d_to_layer_params(vector, reshaped_out) + + def multiply(self, vector): + """Approximate damped Fisher-vector product. + + Args: + vector: Tensor or 2-tuple of Tensors. if self._has_bias, Tensor of shape + [input_size, output_size] corresponding to layer's weights. If not, a + 2-tuple of the former and a Tensor of shape [output_size] corresponding + to the layer's bias. + + Returns: + Tensor of the same shape, corresponding to the Fisher-vector product. + """ + reshaped_vect = utils.layer_params_to_mat2d(vector) + reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping) + return utils.mat2d_to_layer_params(vector, reshaped_out) + + def tensors_to_compute_grads(self): + """Tensors to compute derivative of loss with respect to.""" + return self._outputs + + def register_additional_minibatch(self, inputs, outputs): + """Registers an additional minibatch to the FisherBlock. + + Args: + inputs: Tensor of shape [batch_size, input_size]. Inputs to the + matrix-multiply. + outputs: Tensor of shape [batch_size, output_size]. Layer preactivations. + """ + self._inputs.append(inputs) + self._outputs.append(outputs) + + @property + def num_registered_minibatches(self): + result = len(self._inputs) + assert result == len(self._outputs) + return result + + +class ConvDiagonalFB(FisherBlock): + """FisherBlock for convolutional layers using a diagonal approx. + + Estimates the Fisher Information matrix's diagonal entries for a convolutional + layer. Unlike NaiveDiagonalFB this uses the low-variance "sum of squares" + estimator. + + Let 'params' be a vector parameterizing a model and 'i' an arbitrary index + into it. We are interested in Fisher(params)[i, i]. This is, + + Fisher(params)[i, i] = E[ v(x, y, params) v(x, y, params)^T ][i, i] + = E[ v(x, y, params)[i] ^ 2 ] + + Consider a convoluational layer in this model with (unshared) filter matrix + 'w'. For an example image 'x' that produces layer inputs 'a' and output + preactivations 's', + + v(x, y, w) = vec( sum_{loc} a_{loc} (d loss / d s_{loc})^T ) + + where 'loc' is a single (x, y) location in an image. + + This FisherBlock tracks Fisher(params)[i, i] for all indices 'i' corresponding + to the layer's parameters 'w'. + """ + + def __init__(self, layer_collection, params, strides, padding): + """Creates a ConvDiagonalFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters (Tensor or tuple of Tensors) of this layer. If + kernel alone, a Tensor of shape [kernel_height, kernel_width, + in_channels, out_channels]. If kernel and bias, a tuple of 2 elements + containing the previous and a Tensor of shape [out_channels]. + strides: The stride size in this layer (1-D Tensor of length 4). + padding: The padding in this layer (e.g. "SAME"). + """ + self._inputs = [] + self._outputs = [] + self._strides = tuple(strides) if isinstance(strides, list) else strides + self._padding = padding + self._has_bias = isinstance(params, (tuple, list)) + + fltr = params[0] if self._has_bias else params + self._filter_shape = tuple(fltr.shape.as_list()) + + super(ConvDiagonalFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + # Concatenate inputs, grads_list into single Tensors. + inputs = _concat_along_batch_dim(self._inputs) + grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + + # Infer number of locations upon which convolution is applied. + inputs_shape = tuple(inputs.shape.as_list()) + self._num_locations = ( + inputs_shape[1] * inputs_shape[2] // + (self._strides[1] * self._strides[2])) + + if NORMALIZE_DAMPING_POWER: + damping /= self._num_locations ** NORMALIZE_DAMPING_POWER + self._damping = damping + + self._factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvDiagonalFactor, + (inputs, grads_list, self._filter_shape, self._strides, self._padding, + self._has_bias)) + + def multiply_inverse(self, vector): + reshaped_vect = utils.layer_params_to_mat2d(vector) + reshaped_out = reshaped_vect / (self._factor.get_cov() + self._damping) + return utils.mat2d_to_layer_params(vector, reshaped_out) + + def multiply(self, vector): + reshaped_vect = utils.layer_params_to_mat2d(vector) + reshaped_out = reshaped_vect * (self._factor.get_cov() + self._damping) + return utils.mat2d_to_layer_params(vector, reshaped_out) + + def tensors_to_compute_grads(self): + return self._outputs + + def register_additional_minibatch(self, inputs, outputs): + """Registers an additional minibatch to the FisherBlock. + + Args: + inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to + the convolution. + outputs: Tensor of shape [batch_size, height, width, output_size]. Layer + preactivations. + """ + self._inputs.append(inputs) + self._outputs.append(outputs) + + @property + def num_registered_minibatches(self): + return len(self._inputs) + + +class KroneckerProductFB(FisherBlock): + """A base class for FisherBlocks with separate input and output factors. + + The Fisher block is approximated as a Kronecker product of the input and + output factors. + """ + + def _register_damped_input_and_output_inverses(self, damping): + """Registers damped inverses for both the input and output factors. + + Sets the instance members _input_damping and _output_damping. Requires the + instance members _input_factor and _output_factor. + + Args: + damping: The base damping factor (float or Tensor) for the damped inverse. + """ + pi = utils.compute_pi(self._input_factor.get_cov(), + self._output_factor.get_cov()) + + self._input_damping = math_ops.sqrt(damping) * pi + self._output_damping = math_ops.sqrt(damping) / pi + + self._input_factor.register_damped_inverse(self._input_damping) + self._output_factor.register_damped_inverse(self._output_damping) + + @property + def _renorm_coeff(self): + """Kronecker factor multiplier coefficient. + + If this FisherBlock is represented as 'FB = c * kron(left, right)', then + this is 'c'. + + Returns: + 0-D Tensor. + """ + return 1.0 + + def multiply_inverse(self, vector): + left_factor_inv = self._input_factor.get_inverse(self._input_damping) + right_factor_inv = self._output_factor.get_inverse(self._output_damping) + reshaped_vector = utils.layer_params_to_mat2d(vector) + reshaped_out = math_ops.matmul(left_factor_inv, + math_ops.matmul(reshaped_vector, + right_factor_inv)) + if self._renorm_coeff != 1.0: + reshaped_out /= math_ops.cast( + self._renorm_coeff, dtype=reshaped_out.dtype) + return utils.mat2d_to_layer_params(vector, reshaped_out) + + def multiply(self, vector): + left_factor = self._input_factor.get_cov() + right_factor = self._output_factor.get_cov() + reshaped_vector = utils.layer_params_to_mat2d(vector) + reshaped_out = ( + math_ops.matmul(reshaped_vector, right_factor) + + self._output_damping * reshaped_vector) + reshaped_out = ( + math_ops.matmul(left_factor, reshaped_out) + + self._input_damping * reshaped_out) + if self._renorm_coeff != 1.0: + reshaped_out *= math_ops.cast( + self._renorm_coeff, dtype=reshaped_out.dtype) + return utils.mat2d_to_layer_params(vector, reshaped_out) + + def full_fisher_block(self): + """Explicitly constructs the full Fisher block. + + Used for testing purposes. (In general, the result may be very large.) + + Returns: + The full Fisher block. + """ + left_factor = self._input_factor.get_cov() + right_factor = self._output_factor.get_cov() + return self._renorm_coeff * utils.kronecker_product(left_factor, + right_factor) + + +class FullyConnectedKFACBasicFB(KroneckerProductFB): + """K-FAC FisherBlock for fully-connected (dense) layers. + + This uses the Kronecker-factorized approximation from the original + K-FAC paper (https://arxiv.org/abs/1503.05671) + """ + + def __init__(self, layer_collection, has_bias=False): + """Creates a FullyConnectedKFACBasicFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + has_bias: Whether the component Kronecker factors have an additive bias. + (Default: False) + """ + self._inputs = [] + self._outputs = [] + self._has_bias = has_bias + + super(FullyConnectedKFACBasicFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + """Instantiate Kronecker Factors for this FisherBlock. + + Args: + grads_list: List of list of Tensors. grads_list[i][j] is the + gradient of the loss with respect to 'outputs' from source 'i' and + tower 'j'. Each Tensor has shape [tower_minibatch_size, output_size]. + damping: 0-D Tensor or float. 'damping' * identity is approximately added + to this FisherBlock's Fisher approximation. + """ + # TODO(b/68033310): Validate which of, + # (1) summing on a single device (as below), or + # (2) on each device in isolation and aggregating + # is faster. + inputs = _concat_along_batch_dim(self._inputs) + grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + + self._input_factor = self._layer_collection.make_or_get_factor( # + fisher_factors.FullyConnectedKroneckerFactor, # + ((inputs,), self._has_bias)) + self._output_factor = self._layer_collection.make_or_get_factor( # + fisher_factors.FullyConnectedKroneckerFactor, # + (grads_list,)) + self._register_damped_input_and_output_inverses(damping) + + def tensors_to_compute_grads(self): + return self._outputs + + def register_additional_minibatch(self, inputs, outputs): + """Registers an additional minibatch to the FisherBlock. + + Args: + inputs: Tensor of shape [batch_size, input_size]. Inputs to the + matrix-multiply. + outputs: Tensor of shape [batch_size, output_size]. Layer preactivations. + """ + self._inputs.append(inputs) + self._outputs.append(outputs) + + @property + def num_registered_minibatches(self): + return len(self._inputs) + + +class ConvKFCBasicFB(KroneckerProductFB): + """FisherBlock for 2D convolutional layers using the basic KFC approx. + + Estimates the Fisher Information matrix's blog for a convolutional + layer. + + Consider a convoluational layer in this model with (unshared) filter matrix + 'w'. For a minibatch that produces inputs 'a' and output preactivations 's', + this FisherBlock estimates, + + F(w) = #locations * kronecker(E[flat(a) flat(a)^T], + E[flat(ds) flat(ds)^T]) + + where + + ds = (d / ds) log p(y | x, w) + #locations = number of (x, y) locations where 'w' is applied. + + where the expectation is taken over all examples and locations and flat() + concatenates an array's leading dimensions. + + See equation 23 in https://arxiv.org/abs/1602.01407 for details. + """ + + def __init__(self, layer_collection, params, strides, padding): + """Creates a ConvKFCBasicFB block. + + Args: + layer_collection: The collection of all layers in the K-FAC approximate + Fisher information matrix to which this FisherBlock belongs. + params: The parameters (Tensor or tuple of Tensors) of this layer. If + kernel alone, a Tensor of shape [kernel_height, kernel_width, + in_channels, out_channels]. If kernel and bias, a tuple of 2 elements + containing the previous and a Tensor of shape [out_channels]. + strides: The stride size in this layer (1-D Tensor of length 4). + padding: The padding in this layer (1-D of Tensor length 4). + """ + self._inputs = [] + self._outputs = [] + self._strides = tuple(strides) if isinstance(strides, list) else strides + self._padding = padding + self._has_bias = isinstance(params, (tuple, list)) + + fltr = params[0] if self._has_bias else params + self._filter_shape = tuple(fltr.shape.as_list()) + + super(ConvKFCBasicFB, self).__init__(layer_collection) + + def instantiate_factors(self, grads_list, damping): + # TODO(b/68033310): Validate which of, + # (1) summing on a single device (as below), or + # (2) on each device in isolation and aggregating + # is faster. + inputs = _concat_along_batch_dim(self._inputs) + grads_list = tuple(_concat_along_batch_dim(grads) for grads in grads_list) + + # Infer number of locations upon which convolution is applied. + self._num_locations = _num_conv_locations(inputs.shape.as_list(), + self._strides) + + self._input_factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvInputKroneckerFactor, + (inputs, self._filter_shape, self._strides, self._padding, + self._has_bias)) + self._output_factor = self._layer_collection.make_or_get_factor( + fisher_factors.ConvOutputKroneckerFactor, (grads_list,)) + + if NORMALIZE_DAMPING_POWER: + damping /= self._num_locations**NORMALIZE_DAMPING_POWER + self._damping = damping + + self._register_damped_input_and_output_inverses(damping) + + @property + def _renorm_coeff(self): + return self._num_locations + + def tensors_to_compute_grads(self): + return self._outputs + + def register_additional_minibatch(self, inputs, outputs): + """Registers an additional minibatch to the FisherBlock. + + Args: + inputs: Tensor of shape [batch_size, height, width, input_size]. Inputs to + the convolution. + outputs: Tensor of shape [batch_size, height, width, output_size]. Layer + preactivations. + """ + self._inputs.append(inputs) + self._outputs.append(outputs) + + @property + def num_registered_minibatches(self): + return len(self._inputs) + + +def _concat_along_batch_dim(tensor_list): + """Concatenate tensors along batch (first) dimension. + + Args: + tensor_list: list of Tensors or list of tuples of Tensors. + + Returns: + Tensor or tuple of Tensors. + + Raises: + ValueError: If 'tensor_list' is empty. + + """ + if not tensor_list: + raise ValueError( + "Cannot concatenate Tensors if there are no Tensors to concatenate.") + + if isinstance(tensor_list[0], (tuple, list)): + # [(tensor1a, tensor1b), + # (tensor2a, tensor2b), ...] --> (tensor_a, tensor_b) + return tuple( + array_ops.concat(tensors, axis=0) for tensors in zip(*tensor_list)) + else: + # [tensor1, tensor2] --> tensor + return array_ops.concat(tensor_list, axis=0) + + +def _num_conv_locations(input_shape, strides): + """Returns the number of locations a Conv kernel is applied to.""" + return input_shape[1] * input_shape[2] // (strides[1] * strides[2]) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..59389f8d385c18f50914d690cfaa2825ef807ed3 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/fisher_blocks_lib.py @@ -0,0 +1,38 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FisherBlock definitions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.fisher_blocks import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'FisherBlock', + 'FullFB', + 'NaiveDiagonalFB', + 'FullyConnectedDiagonalFB', + 'KroneckerProductFB', + 'FullyConnectedKFACBasicFB', + 'ConvKFCBasicFB', + 'ConvDiagonalFB', + 'set_global_constants', +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py new file mode 100644 index 0000000000000000000000000000000000000000..4e36813369e69de1d6f13ddb00566bda912244f6 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py @@ -0,0 +1,718 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FisherFactor definitions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import numpy as np +import six + +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import special_math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.training import moving_averages + +# Whether to initialize covariance estimators at a zero matrix (or the identity +# matrix). +INIT_COVARIANCES_AT_ZERO = False + +# Whether to zero-debias the moving averages. +ZERO_DEBIAS = False + +# When the number of inverses requested from a FisherFactor exceeds this value, +# the inverses are computed using an eigenvalue decomposition. +EIGENVALUE_DECOMPOSITION_THRESHOLD = 2 + +# Numerical eigenvalues computed from covariance matrix estimates are clipped to +# be at least as large as this value before they are used to compute inverses or +# matrix powers. Must be nonnegative. +EIGENVALUE_CLIPPING_THRESHOLD = 0.0 + + +def set_global_constants(init_covariances_at_zero=None, zero_debias=None, + eigenvalue_decomposition_threshold=None, + eigenvalue_clipping_threshold=None): + """Sets various global constants used by the classes in this module.""" + global INIT_COVARIANCES_AT_ZERO + global ZERO_DEBIAS + global EIGENVALUE_DECOMPOSITION_THRESHOLD + global EIGENVALUE_CLIPPING_THRESHOLD + + if init_covariances_at_zero is not None: + INIT_COVARIANCES_AT_ZERO = init_covariances_at_zero + if zero_debias is not None: + ZERO_DEBIAS = zero_debias + if eigenvalue_decomposition_threshold is not None: + EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold + if eigenvalue_clipping_threshold is not None: + EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold + + +def inverse_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument + return array_ops.diag(array_ops.ones(shape[0], dtype)) + + +def covariance_initializer(shape, dtype, partition_info=None): # pylint: disable=unused-argument + if INIT_COVARIANCES_AT_ZERO: + return array_ops.diag(array_ops.zeros(shape[0], dtype)) + return array_ops.diag(array_ops.ones(shape[0], dtype)) + + +def diagonal_covariance_initializer(shape, dtype, partition_info): # pylint: disable=unused-argument + if INIT_COVARIANCES_AT_ZERO: + return array_ops.zeros(shape, dtype) + return array_ops.ones(shape, dtype) + + +def _compute_cov(tensor, normalizer=None): + """Compute the empirical second moment of the rows of a 2D Tensor. + + This function is meant to be applied to random matrices for which the true row + mean is zero, so that the true second moment equals the true covariance. + + Args: + tensor: A 2D Tensor. + normalizer: optional scalar for the estimator (by default, the normalizer is + the number of rows of tensor). + + Returns: + A square 2D Tensor with as many rows/cols as the number of input columns. + """ + if normalizer is None: + normalizer = array_ops.shape(tensor)[0] + cov = (math_ops.matmul(tensor, tensor, transpose_a=True) / math_ops.cast( + normalizer, tensor.dtype)) + return (cov + array_ops.transpose(cov)) / math_ops.cast(2, cov.dtype) + + +def _append_homog(tensor): + """Appends a homogeneous coordinate to the last dimension of a Tensor. + + Args: + tensor: A Tensor. + + Returns: + A Tensor identical to the input but one larger in the last dimension. The + new entries are filled with ones. + """ + rank = len(tensor.shape.as_list()) + shape = array_ops.concat([array_ops.shape(tensor)[:-1], [1]], axis=0) + ones = array_ops.ones(shape, dtype=tensor.dtype) + return array_ops.concat([tensor, ones], axis=rank-1) + + +def scope_string_from_params(params): + """Builds a variable scope string name from the given parameters. + + Supported parameters are: + * tensors + * booleans + * ints + * strings + * depth-1 tuples/lists of ints + * any depth tuples/lists of tensors + Other parameter types will throw an error. + + Args: + params: A parameter or list of parameters. + + Returns: + A string to use for the variable scope. + + Raises: + ValueError: if params includes an unsupported type. + """ + params = params if isinstance(params, (tuple, list)) else (params,) + + name_parts = [] + for param in params: + if isinstance(param, (tuple, list)): + if all([isinstance(p, int) for p in param]): + name_parts.append("-".join([str(p) for p in param])) + else: + name_parts.append(scope_string_from_name(param)) + elif isinstance(param, (str, int, bool)): + name_parts.append(str(param)) + elif isinstance(param, (tf_ops.Tensor, variables.Variable)): + name_parts.append(scope_string_from_name(param)) + else: + raise ValueError( + "Encountered an unsupported param type {}".format(type(param))) + return "_".join(name_parts) + + +def scope_string_from_name(tensor): + if isinstance(tensor, (tuple, list)): + return "__".join([scope_string_from_name(t) for t in tensor]) + # "gradients/add_4_grad/Reshape:0" -> "gradients_add_4_grad_Reshape" + return tensor.name.split(":")[0].replace("/", "_") + + +def scalar_or_tensor_to_string(val): + return repr(val) if np.isscalar(val) else scope_string_from_name(val) + + +@six.add_metaclass(abc.ABCMeta) +class FisherFactor(object): + """Base class for objects modeling factors of approximate Fisher blocks. + + Note that for blocks that aren't based on approximations, a 'factor' can + be the entire block itself, as is the case for the diagonal and full + representations. + + Subclasses must implement the _compute_new_cov method, and the _var_scope + and _cov_shape properties. + """ + + def __init__(self): + self.instantiate_covariance() + + @abc.abstractproperty + def _var_scope(self): + pass + + @abc.abstractproperty + def _cov_shape(self): + """The shape of the cov matrix.""" + pass + + @abc.abstractproperty + def _num_sources(self): + """The number of things to sum over when computing cov. + + The default make_covariance_update_op function will call _compute_new_cov + with indices ranging from 0 to _num_sources-1. The typical situation is + where the factor wants to sum the statistics it computes over multiple + backpropped "gradients" (typically passed in via "tensors" or + "outputs_grads" arguments). + """ + pass + + @property + def _cov_initializer(self): + return covariance_initializer + + def instantiate_covariance(self): + """Instantiates the covariance Variable as the instance member _cov.""" + with variable_scope.variable_scope(self._var_scope): + self._cov = variable_scope.get_variable( + "cov", + initializer=self._cov_initializer, + shape=self._cov_shape, + trainable=False) + + @abc.abstractmethod + def _compute_new_cov(self, idx=0): + pass + + def make_covariance_update_op(self, ema_decay): + """Constructs and returns the covariance update Op. + + Args: + ema_decay: The exponential moving average decay (float or Tensor). + Returns: + An Op for updating the covariance Variable referenced by _cov. + """ + new_cov = math_ops.add_n( + tuple(self._compute_new_cov(idx) for idx in range(self._num_sources))) + + return moving_averages.assign_moving_average( + self._cov, new_cov, ema_decay, zero_debias=ZERO_DEBIAS) + + def make_inverse_update_ops(self): + """Create and return update ops corresponding to registered computations.""" + return [] + + def get_cov(self): + return self._cov + + +class InverseProvidingFactor(FisherFactor): + """Base class for FisherFactors that maintain inverses, powers, etc of _cov. + + Assumes that the _cov property is a square PSD matrix. + + Subclasses must implement the _compute_new_cov method, and the _var_scope and + _cov_shape properties. + """ + + def __init__(self): + self._inverses_by_damping = {} + self._matpower_by_exp_and_damping = {} + self._eigendecomp = None + + super(InverseProvidingFactor, self).__init__() + + def register_damped_inverse(self, damping): + """Registers a damped inverse needed by a FisherBlock. + + Args: + damping: The damping value (float or Tensor) for this factor. + """ + if damping not in self._inverses_by_damping: + damping_string = scalar_or_tensor_to_string(damping) + with variable_scope.variable_scope(self._var_scope): + inv = variable_scope.get_variable( + "inv_damp{}".format(damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False) + self._inverses_by_damping[damping] = inv + + def register_matpower(self, exp, damping): + """Registers a matrix power needed by a FisherBlock. + + Args: + exp: The exponent (float or Tensor) to raise the matrix to. + damping: The damping value (float or Tensor). + """ + if (exp, damping) not in self._matpower_by_exp_and_damping: + exp_string = scalar_or_tensor_to_string(exp) + damping_string = scalar_or_tensor_to_string(damping) + with variable_scope.variable_scope(self._var_scope): + matpower = variable_scope.get_variable( + "matpower_exp{}_damp{}".format(exp_string, damping_string), + initializer=inverse_initializer, + shape=self._cov_shape, + trainable=False) + self._matpower_by_exp_and_damping[(exp, damping)] = matpower + + def register_eigendecomp(self): + """Registers that an eigendecomposition is needed by a FisherBlock.""" + if not self._eigendecomp: + self._eigendecomp = linalg_ops.self_adjoint_eig(self._cov) + + def make_inverse_update_ops(self): + """Create and return update ops corresponding to registered computations.""" + ops = super(InverseProvidingFactor, self).make_inverse_update_ops() + + num_inverses = len(self._inverses_by_damping) + matrix_power_registered = bool(self._matpower_by_exp_and_damping) + use_eig = (self._eigendecomp or matrix_power_registered or + num_inverses >= EIGENVALUE_DECOMPOSITION_THRESHOLD) + + if use_eig: + self.register_eigendecomp() # ensures self._eigendecomp is set + eigenvalues, eigenvectors = self._eigendecomp # pylint: disable=unpacking-non-sequence + + # The matrix self._cov is positive semidefinite by construction, but the + # numerical eigenvalues could be negative due to numerical errors, so here + # we clip them to be at least EIGENVALUE_CLIPPING_THRESHOLD. + clipped_eigenvalues = math_ops.maximum(eigenvalues, + EIGENVALUE_CLIPPING_THRESHOLD) + + for damping, inv in self._inverses_by_damping.items(): + ops.append( + inv.assign( + math_ops.matmul(eigenvectors / (clipped_eigenvalues + damping), + array_ops.transpose(eigenvectors)))) + + for (exp, damping), matpower in self._matpower_by_exp_and_damping.items(): + ops.append( + matpower.assign( + math_ops.matmul(eigenvectors * (clipped_eigenvalues + damping)** + exp, array_ops.transpose(eigenvectors)))) + else: + for damping, inv in self._inverses_by_damping.items(): + ops.append(inv.assign(utils.posdef_inv(self._cov, damping))) + + return ops + + def get_inverse(self, damping): + return self._inverses_by_damping[damping] + + def get_matpower(self, exp, damping): + return self._matpower_by_exp_and_damping[(exp, damping)] + + def get_eigendecomp(self): + return self._eigendecomp + + +class FullFactor(InverseProvidingFactor): + """FisherFactor for a full matrix representation of the Fisher of a parameter. + + Note that this uses the naive "square the sum estimator", and so is applicable + to any type of parameter in principle, but has very high variance. + """ + + def __init__(self, params_grads, batch_size): + self._batch_size = batch_size + self._orig_params_grads_name = scope_string_from_params( + [params_grads, self._batch_size]) + self._params_grads_flat = tuple( + utils.tensors_to_column(params_grad) for params_grad in params_grads) + super(FullFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_full/" + self._orig_params_grads_name + + @property + def _cov_shape(self): + size = self._params_grads_flat[0].shape[0] + return [size, size] + + @property + def _num_sources(self): + return len(self._params_grads_flat) + + def _compute_new_cov(self, idx=0): + # This will be a very basic rank 1 estimate + return ((self._params_grads_flat[idx] * array_ops.transpose( + self._params_grads_flat[idx])) / math_ops.cast( + self._batch_size, self._params_grads_flat[idx].dtype)) + + +class DiagonalFactor(FisherFactor): + """A base class for FisherFactors that use diagonal approximations.""" + + def __init__(self): + super(DiagonalFactor, self).__init__() + + @property + def _cov_initializer(self): + return diagonal_covariance_initializer + + +class NaiveDiagonalFactor(DiagonalFactor): + """FisherFactor for a diagonal approximation of any type of param's Fisher. + + Note that this uses the naive "square the sum estimator", and so is applicable + to any type of parameter in principle, but has very high variance. + """ + + def __init__(self, params_grads, batch_size): + self._batch_size = batch_size + self._params_grads = tuple( + utils.tensors_to_column(params_grad) for params_grad in params_grads) + self._orig_params_grads_name = scope_string_from_params( + [self._params_grads, self._batch_size]) + super(NaiveDiagonalFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_naivediag/" + self._orig_params_grads_name + + @property + def _cov_shape(self): + return self._params_grads[0].shape + + @property + def _num_sources(self): + return len(self._params_grads) + + def _compute_new_cov(self, idx=0): + return (math_ops.square(self._params_grads[idx]) / math_ops.cast( + self._batch_size, self._params_grads[idx].dtype)) + + +class FullyConnectedDiagonalFactor(DiagonalFactor): + r"""FisherFactor for a diagonal approx of a fully-connected layer's Fisher. + + Given in = [batch_size, input_size] and out_grad = [batch_size, output_size], + approximates the covariance as, + + Cov(in, out) = (1/batch_size) \sum_{i} outer(in[i], out_grad[i]) ** 2.0 + + where the square is taken element-wise. + """ + + # TODO(jamesmartens): add units tests for this class + + def __init__(self, inputs, outputs_grads, has_bias=False): + """Instantiate FullyConnectedDiagonalFactor. + + Args: + inputs: Tensor of shape [batch_size, input_size]. Inputs to fully + connected layer. + outputs_grads: List of Tensors of shape [batch_size, output_size]. + Gradient of loss with respect to layer's preactivations. + has_bias: bool. If True, append '1' to each input. + """ + self._outputs_grads = outputs_grads + self._batch_size = array_ops.shape(inputs)[0] + self._orig_tensors_name = scope_string_from_params((inputs,) + + tuple(outputs_grads)) + + # Note that we precompute the required operations on the inputs since the + # inputs don't change with the 'idx' argument to _compute_new_cov. (Only + # the target entry of _outputs_grads changes with idx.) + if has_bias: + inputs = _append_homog(inputs) + self._squared_inputs = math_ops.square(inputs) + + super(FullyConnectedDiagonalFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_diagfc/" + self._orig_tensors_name + + @property + def _cov_shape(self): + return [self._squared_inputs.shape[1], self._outputs_grads[0].shape[1]] + + @property + def _num_sources(self): + return len(self._outputs_grads) + + def _compute_new_cov(self, idx=0): + # The well-known special formula that uses the fact that the entry-wise + # square of an outer product is the outer-product of the entry-wise squares. + # The gradient is the outer product of the input and the output gradients, + # so we just square both and then take their outer-product. + new_cov = math_ops.matmul( + self._squared_inputs, + math_ops.square(self._outputs_grads[idx]), + transpose_a=True) + new_cov /= math_ops.cast(self._batch_size, new_cov.dtype) + return new_cov + + +class ConvDiagonalFactor(DiagonalFactor): + """FisherFactor for a diagonal approx of a convolutional layer's Fisher.""" + + # TODO(jamesmartens): add units tests for this class + + def __init__(self, inputs, outputs_grads, filter_shape, strides, padding, + has_bias=False): + """Creates a ConvDiagonalFactor object. + + Args: + inputs: Tensor of shape [batch_size, height, width, in_channels]. + Input activations to this layer. + outputs_grads: Tensor of shape [batch_size, height, width, out_channels]. + Per-example gradients to the loss with respect to the layer's output + preactivations. + filter_shape: Tuple of 4 ints: (kernel_height, kernel_width, in_channels, + out_channels). Represents shape of kernel used in this layer. + strides: The stride size in this layer (1-D Tensor of length 4). + padding: The padding in this layer (1-D of Tensor length 4). + has_bias: Python bool. If True, the layer is assumed to have a bias + parameter in addition to its filter parameter. + """ + self._filter_shape = filter_shape + self._has_bias = has_bias + self._outputs_grads = outputs_grads + + self._orig_tensors_name = scope_string_from_name((inputs,) + + tuple(outputs_grads)) + + # Note that we precompute the required operations on the inputs since the + # inputs don't change with the 'idx' argument to _compute_new_cov. (Only + # the target entry of _outputs_grads changes with idx.) + filter_height, filter_width, _, _ = self._filter_shape + patches = array_ops.extract_image_patches( + inputs, + ksizes=[1, filter_height, filter_width, 1], + strides=strides, + rates=[1, 1, 1, 1], + padding=padding) + + if has_bias: + patches = _append_homog(patches) + + self._patches = patches + + super(ConvDiagonalFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_convdiag/" + self._orig_tensors_name + + @property + def _cov_shape(self): + filter_height, filter_width, in_channels, out_channels = self._filter_shape + return [filter_height * filter_width * in_channels + self._has_bias, + out_channels] + + @property + def _num_sources(self): + return len(self._outputs_grads) + + def _compute_new_cov(self, idx=0): + outputs_grad = self._outputs_grads[idx] + batch_size = array_ops.shape(self._patches)[0] + + new_cov = self._convdiag_sum_of_squares(self._patches, outputs_grad) + new_cov /= math_ops.cast(batch_size, new_cov.dtype) + + return new_cov + + def _convdiag_sum_of_squares(self, patches, outputs_grad): + # This computes the sum of the squares of the per-training-case "gradients". + # It does this simply by computing a giant tensor containing all of these, + # doing an entry-wise square, and them summing along the batch dimension. + case_wise_gradients = special_math_ops.einsum("bijk,bijl->bkl", patches, + outputs_grad) + return math_ops.reduce_sum(math_ops.square(case_wise_gradients), axis=0) + + +class FullyConnectedKroneckerFactor(InverseProvidingFactor): + """Kronecker factor for the input or output side of a fully-connected layer. + """ + + def __init__(self, tensors, has_bias=False): + """Instantiate FullyConnectedKroneckerFactor. + + Args: + tensors: List of Tensors of shape [batch_size, n]. Represents either a + layer's inputs or its output's gradients. + has_bias: bool. If True, assume this factor is for the layer's inputs and + append '1' to each row. + """ + # The tensor argument is either a tensor of input activations or a tensor of + # output pre-activation gradients. + self._has_bias = has_bias + self._tensors = tensors + super(FullyConnectedKroneckerFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_fckron/" + scope_string_from_params( + [self._tensors, self._has_bias]) + + @property + def _cov_shape(self): + size = self._tensors[0].shape[1] + self._has_bias + return [size, size] + + @property + def _num_sources(self): + return len(self._tensors) + + def _compute_new_cov(self, idx=0): + tensor = self._tensors[idx] + if self._has_bias: + tensor = _append_homog(tensor) + return _compute_cov(tensor) + + +class ConvInputKroneckerFactor(InverseProvidingFactor): + r"""Kronecker factor for the input side of a convolutional layer. + + Estimates E[ a a^T ] where a is the inputs to a convolutional layer given + example x. Expectation is taken over all examples and locations. + + Equivalent to \Omega in https://arxiv.org/abs/1602.01407 for details. See + Section 3.1 Estimating the factors. + """ + + def __init__(self, inputs, filter_shape, strides, padding, has_bias=False): + """Initializes ConvInputKroneckerFactor. + + Args: + inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs + to layer. + filter_shape: 1-D Tensor of length 4. Contains [kernel_height, + kernel_width, in_channels, out_channels]. + strides: 1-D Tensor of length 4. Contains [batch_stride, height_stride, + width_stride, in_channel_stride]. + padding: str. Padding method for layer. "SAME" or "VALID". + has_bias: bool. If True, append 1 to in_channel. + """ + self._filter_shape = filter_shape + self._strides = strides + self._padding = padding + self._has_bias = has_bias + self._inputs = inputs + super(ConvInputKroneckerFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_convinkron/" + scope_string_from_params([ + self._inputs, self._filter_shape, self._strides, self._padding, + self._has_bias + ]) + + @property + def _cov_shape(self): + filter_height, filter_width, in_channels, _ = self._filter_shape + size = filter_height * filter_width * in_channels + self._has_bias + return [size, size] + + @property + def _num_sources(self): + return 1 + + def _compute_new_cov(self, idx=0): + if idx != 0: + raise ValueError("ConvInputKroneckerFactor only supports idx = 0") + + # TODO(jamesmartens): factor this patches stuff out into a utility function + filter_height, filter_width, in_channels, _ = self._filter_shape + patches = array_ops.extract_image_patches( + self._inputs, + ksizes=[1, filter_height, filter_width, 1], + strides=self._strides, + rates=[1, 1, 1, 1], + padding=self._padding) + + flatten_size = (filter_height * filter_width * in_channels) + patches_flat = array_ops.reshape(patches, [-1, flatten_size]) + + if self._has_bias: + patches_flat = _append_homog(patches_flat) + + return _compute_cov(patches_flat) + + +class ConvOutputKroneckerFactor(InverseProvidingFactor): + r"""Kronecker factor for the output side of a convolutional layer. + + Estimates E[ ds ds^T ] where s is the preactivations of a convolutional layer + given example x and ds = (d / d s) log(p(y|x, w)). Expectation is taken over + all examples and locations. + + Equivalent to \Gamma in https://arxiv.org/abs/1602.01407 for details. See + Section 3.1 Estimating the factors. + """ + + def __init__(self, outputs_grads): + """Initializes ConvOutputKroneckerFactor. + + Args: + outputs_grads: list of Tensors. Each Tensor is of shape + [batch_size, height, width, out_channels]. + """ + self._out_channels = outputs_grads[0].shape.as_list()[3] + self._outputs_grads = outputs_grads + super(ConvOutputKroneckerFactor, self).__init__() + + @property + def _var_scope(self): + return "ff_convoutkron/" + scope_string_from_params(self._outputs_grads) + + @property + def _cov_shape(self): + size = self._out_channels + return [size, size] + + @property + def _num_sources(self): + return len(self._outputs_grads) + + def _compute_new_cov(self, idx=0): + reshaped_tensor = array_ops.reshape(self._outputs_grads[idx], + [-1, self._out_channels]) + return _compute_cov(reshaped_tensor) diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..23ee93cd405bbf719939df89d525c812ee061f8b --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/fisher_factors_lib.py @@ -0,0 +1,46 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""FisherFactor definitions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.fisher_factors import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + "inverse_initializer", + "covariance_initializer", + "diagonal_covariance_initializer", + "scope_string_from_params", + "scope_string_from_name", + "scalar_or_tensor_to_string", + "FisherFactor", + "InverseProvidingFactor", + "FullFactor", + "DiagonalFactor", + "NaiveDiagonalFactor", + "FullyConnectedDiagonalFactor", + "FullyConnectedKroneckerFactor", + "ConvInputKroneckerFactor", + "ConvOutputKroneckerFactor", + "ConvDiagonalFactor", + "set_global_constants", +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection.py b/tensorflow/contrib/kfac/python/ops/layer_collection.py new file mode 100644 index 0000000000000000000000000000000000000000..4eabb59b3e4e59c1c9ad4e3c1102efacb52dd478 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/layer_collection.py @@ -0,0 +1,570 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Registry for layers and their parameters/variables. + +This represents the collection of all layers in the approximate Fisher +information matrix to which a particular FisherBlock may belong. That is, we +might have several layer collections for one TF graph (if we have multiple K-FAC +optimizers being used, for example.) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from collections import defaultdict +from collections import OrderedDict + +import six + +from tensorflow.contrib.kfac.python.ops import fisher_blocks as fb +from tensorflow.contrib.kfac.python.ops import loss_functions as lf +from tensorflow.contrib.kfac.python.ops import utils +from tensorflow.python.framework import ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import nest + + +# Names for various approximations that can be requested for Fisher blocks. +APPROX_KRONECKER_NAME = "kron" +APPROX_DIAGONAL_NAME = "diagonal" +APPROX_FULL_NAME = "full" + +# Possible value for 'reuse' keyword argument. Sets 'reuse' to +# tf.get_variable_scope().reuse. +VARIABLE_SCOPE = "VARIABLE_SCOPE" + +# TODO(jamesmartens): need to add find_canonical_output back into this somewhere + + +class LayerParametersDict(OrderedDict): + """An OrderedDict where keys are Tensors or tuples of Tensors. + + Ensures that no Tensor is associated with two different keys. + """ + + def __init__(self, *args, **kwargs): + self._tensors = set() + super(LayerParametersDict, self).__init__(*args, **kwargs) + + def __setitem__(self, key, value): + key = self._canonicalize_key(key) + tensors = key if isinstance(key, (tuple, list)) else (key,) + key_collisions = self._tensors.intersection(tensors) + if key_collisions: + raise ValueError("Key(s) already present: {}".format(key_collisions)) + self._tensors.update(tensors) + super(LayerParametersDict, self).__setitem__(key, value) + + def __delitem__(self, key): + key = self._canonicalize_key(key) + self._tensors.remove(key) + super(LayerParametersDict, self).__delitem__(key) + + def __getitem__(self, key): + key = self._canonicalize_key(key) + return super(LayerParametersDict, self).__getitem__(key) + + def __contains__(self, key): + key = self._canonicalize_key(key) + return super(LayerParametersDict, self).__contains__(key) + + def _canonicalize_key(self, key): + if isinstance(key, (list, tuple)): + return tuple(key) + return key + + +# TODO(b/68034464): add capability for LayerCollection to be "finalized" +# and do this when it gets used by FisherEstimator / KfacOptimizer. + + +class LayerCollection(object): + """Registry of information about layers and losses. + + Note that you need to create a new one of these for each MatrixEstimator or + KfacOptimizer. + + Attributes: + fisher_blocks: a LayersParamsDict (subclass of OrderedDict) mapping layer + parameters (Tensors or tuples of Tensors) to FisherBlock instances. + fisher_factors: an OrderedDict mapping tuples to FisherFactor instances. + losses: a list of LossFunction objects. The loss to be optimized is their + sum. + """ + + def __init__(self, graph=None, name="LayerCollection"): + self.fisher_blocks = LayerParametersDict() + self.fisher_factors = OrderedDict() + self._graph = graph or ops.get_default_graph() + self._loss_dict = {} # {str: LossFunction} + self._subgraph = None + + with variable_scope.variable_scope(None, default_name=name) as scope: + self._var_scope = scope.name + + @property + def losses(self): + """LossFunctions registered with this LayerCollection.""" + return list(self._loss_dict.values()) + + def register_block(self, layer_key, fisher_block, reuse=VARIABLE_SCOPE): + """Validates and registers the layer_key associated with the fisher_block. + + Validation consists of checking whether the key was already registered or + if any of the elements of layer_key (if it's a tuple) were already + registered as part of another tuple (throws an error if so). If any of the + elements were registered by themselves, or as part of tuples that are + subsets of this layer_key, those registrations are first removed. + + If the layer_key is a subset of an existing registration, registration of + the new, smaller layer_key is skipped. + + e.g. If registrations include {'a': foo, ('b', 'c'): bar}, then + - register_layer('a', baz) -> ValueError + - register_layer(('b', 'c', 'd'), baz) -> + {'a': foo, ('b', 'c', 'd'): baz} + - register_layer('b', baz) -> + {'a': foo, ('b', 'c'): bar} (No change) + - register_layer(('a', 'd'), baz) -> + {('a', 'd'): baz, ('b', 'c'): bar} + - register_layer(('b', 'd'), baz) -> ValueError + + Args: + layer_key: The key to check for in existing registrations and to register + if valid. + fisher_block: The associated fisher block. + reuse: Method to use for inserting new FisherBlocks. One of True, False, + or VARIABLE_SCOPE. + + Raises: + ValueError: If the layer_key was already registered, or if a subset of the + layer_key has already been registered as part of a different tuple. + + Returns: + FisherBlock registered under 'layer_key'. May or may not be the same as + 'fisher_block'. + """ + if reuse is VARIABLE_SCOPE: + reuse = variable_scope.get_variable_scope().reuse + + if reuse is True or (reuse is variable_scope.AUTO_REUSE and + layer_key in self.fisher_blocks): + result = self.fisher_blocks[layer_key] + if type(result) != type(fisher_block): # pylint: disable=unidiomatic-typecheck + raise ValueError( + "Attempted to register FisherBlock of type %s when existing " + "FisherBlock has type %s." % (type(fisher_block), type(result))) + return result + if reuse is False and layer_key in self.fisher_blocks: + raise ValueError("FisherBlock for %s is already in LayerCollection." % + (layer_key,)) + + # Insert fisher_block into self.fisher_blocks. + if layer_key in self.fisher_blocks: + raise ValueError("Duplicate registration: {}".format(layer_key)) + if isinstance(layer_key, (tuple, list)): + return self._register_block_with_sequence_key(layer_key, fisher_block) + else: + return self._register_block_with_nonsequence_key(layer_key, fisher_block) + + def _register_block_with_sequence_key(self, layer_key, fisher_block): + """Validates and registers the layer_key if it's a sequence.""" + # Find all keys that are either supersets or subsets of 'layer_key'. + inclusions = { + fisher_elt + for layer_elt in layer_key for fisher_elt in self.fisher_blocks + if self._equal_or_subset(layer_elt, fisher_elt) + } + + if not inclusions: + self.fisher_blocks[layer_key] = fisher_block + return fisher_block + + result_key = None + for key in inclusions: + fisher_block_key = key if isinstance(key, (tuple, list)) else (key,) + in_existing_only = set(fisher_block_key) - set(layer_key) + in_new_only = set(layer_key) - set(fisher_block_key) + + if in_existing_only and in_new_only: + # Existing and new key have an intersection but neither is a subset of + # the other. This is an error. + raise ValueError( + "Inconsistent registration, expected new key to be a subset or " + "superset of the existing key: existing is {}, new is {}".format( + key, layer_key)) + elif in_existing_only and not in_new_only: + # Existing key is strict superset of new key. Return existing + # FisherBlock. + logging.warning("Graph Registration Warning: tried to register " + "a subset ({}) of an already registered tuple " + "({}), skipping".format(layer_key, fisher_block_key)) + assert result_key is None + result_key = key + elif in_new_only and not in_existing_only: + # Existing key is a strict subset of new key. Replace existing + # FisherBlock with new one. + # + # TODO(b/68715045): This is dangerous. If there are existing + # registrations for a minibatch from elsewhere in the graph, they won't + # be re-registered with this new FisherBlock. The type of FisherBlock + # could also change here. + logging.warning( + "Replacing existing FisherBlock for key {} with new FisherBlock " + "for key {}. {} registered minibatches from the existing " + "FisherBlock will not be migrated.".format( + key, layer_key, + self.fisher_blocks[key].num_registered_minibatches)) + self.fisher_blocks.pop(key) + self.fisher_blocks[layer_key] = fisher_block + assert result_key is None + result_key = layer_key + elif not in_new_only and not in_existing_only: + # Existing and new are identical. Reuse the old FisherBlock. + # + # TODO(b/68715045): This is dangerous. If the new FisherBlock has + # existing registered minibatches, they will not be migrated to the + # existing FisherBlock. + assert result_key is None + result_key = key + else: + raise ValueError("Unexpected layer key conflict: {} vs. {}".format( + layer_key, key)) + + return self.fisher_blocks[result_key] + + def _register_block_with_nonsequence_key(self, layer_key, fisher_block): + """Validates and registers the layer_key if it's not a sequence.""" + inclusions = { + fisher_elt + for fisher_elt in self.fisher_blocks + if self._equal_or_subset(layer_key, fisher_elt) + } + + if not inclusions: + self.fisher_blocks[layer_key] = fisher_block + else: + logging.warning("Graph Registration Warning: tried to register " + "variable ({}) but a containing tuple was already " + "registered ({}), skipping".format(layer_key, inclusions)) + + return fisher_block + + def _equal_or_subset(self, elt1, elt2): + """Checks if the elements are equal or one is contained in the other.""" + return (elt1 == elt2 or (isinstance(elt1, + (tuple, list)) and elt2 in elt1) or + (isinstance(elt2, (tuple, list)) and elt1 in elt2)) + + def get_use_count_map(self): + """Returns a dict of variables to their number of registrations.""" + vars_to_uses = defaultdict(int) + for key, block in six.iteritems(self.fisher_blocks): + key = key if isinstance(key, (tuple, list)) else (key,) + for k in key: + vars_to_uses[k] += block.num_registered_minibatches + return vars_to_uses + + def get_blocks(self): + return self.fisher_blocks.values() + + def get_factors(self): + return self.fisher_factors.values() + + @property + def graph(self): + return self._graph + + @property + def subgraph(self): + return self._subgraph + + def create_subgraph(self): + if not self.losses: + raise ValueError("Must have at least one registered loss.") + inputs_to_losses = nest.flatten(tuple(loss.inputs for loss in self.losses)) + self._subgraph = utils.SubGraph(inputs_to_losses) + + def total_loss(self): + return math_ops.add_n(tuple(loss.evaluate() for loss in self.losses)) + + def total_sampled_loss(self): + return math_ops.add_n( + tuple(loss.evaluate_on_sample() for loss in self.losses)) + + def register_fully_connected(self, + params, + inputs, + outputs, + approx=APPROX_KRONECKER_NAME, + reuse=VARIABLE_SCOPE): + """Registers a fully connnected layer. + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [input_size, output_size]. + Bias should have shape [output_size]. + inputs: Tensor of shape [batch_size, input_size]. Inputs to layer. + outputs: Tensor of shape [batch_size, output_size]. Preactivations + produced by layer. + approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + reuse: bool or str. If True, reuse an existing FisherBlock. If False, + create a new FisherBlock. If VARIABLE_SCOPE, use + tf.get_variable_scope().reuse. + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + approx_to_block_types = { + APPROX_KRONECKER_NAME: fb.FullyConnectedKFACBasicFB, + APPROX_DIAGONAL_NAME: fb.FullyConnectedDiagonalFB, + } + + if approx not in approx_to_block_types: + raise ValueError("Bad value {} for approx.".format(approx)) + + block_type = approx_to_block_types[approx] + has_bias = isinstance(params, (tuple, list)) + + block = self.register_block(params, block_type(self, has_bias), reuse=reuse) + block.register_additional_minibatch(inputs, outputs) + + def register_conv2d(self, + params, + strides, + padding, + inputs, + outputs, + approx=APPROX_KRONECKER_NAME, + reuse=VARIABLE_SCOPE): + """Registers a convolutional layer. + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [kernel_height, + kernel_width, in_channels, out_channels]. Bias should have shape + [out_channels]. + strides: 1-D Tensor of length 4. Strides for convolution kernel. + padding: string. see tf.nn.conv2d for valid values. + inputs: Tensor of shape [batch_size, height, width, in_channels]. Inputs + to layer. + outputs: Tensor of shape [batch_size, height, width, out_channels]. + Preactivations produced by layer. + approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + reuse: bool or str. If True, reuse an existing FisherBlock. If False, + create a new FisherBlock. If VARIABLE_SCOPE, use + tf.get_variable_scope().reuse. + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + approx_to_block_types = { + APPROX_KRONECKER_NAME: fb.ConvKFCBasicFB, + APPROX_DIAGONAL_NAME: fb.ConvDiagonalFB, + } + + if approx not in approx_to_block_types: + raise ValueError("Bad value {} for approx.".format(approx)) + + block_type = approx_to_block_types[approx] + block = self.register_block( + params, block_type(self, params, strides, padding), reuse=reuse) + block.register_additional_minibatch(inputs, outputs) + + def register_generic(self, + params, + batch_size, + approx=APPROX_DIAGONAL_NAME, + reuse=VARIABLE_SCOPE): + """Registers a generic layer. + + Args: + params: Tensor or 2-tuple of Tensors corresponding to weight and bias of + this layer. Weight matrix should have shape [kernel_height, + kernel_width, in_channels, out_channels]. Bias should have shape + [out_channels]. + batch_size: 0-D Tensor. Size of the minibatch. + approx: str. One of APPROX_KRONECKER_NAME or APPROX_DIAGONAL_NAME. + reuse: bool or str. If True, reuse an existing FisherBlock. If False, + create a new FisherBlock. If VARIABLE_SCOPE, use + tf.get_variable_scope().reuse. + + Raises: + ValueError: For improper value to 'approx'. + KeyError: If reuse == True but no FisherBlock found for 'params'. + ValueError: If reuse == True and FisherBlock found but of the wrong type. + """ + approx_to_block_types = { + APPROX_FULL_NAME: fb.FullFB, + APPROX_DIAGONAL_NAME: fb.NaiveDiagonalFB, + } + + if approx not in approx_to_block_types: + raise ValueError("Bad value {} for approx.".format(approx)) + + block_type = approx_to_block_types[approx] + block = self.register_block(params, block_type(self, params), reuse=reuse) + block.register_additional_minibatch(batch_size) + + def register_categorical_predictive_distribution(self, + logits, + seed=None, + targets=None, + name=None, + reuse=VARIABLE_SCOPE): + """Registers a categorical predictive distribution. + + Args: + logits: The logits of the distribution (i.e. its parameters). + seed: The seed for the RNG (for debugging) (Default: None) + targets: (OPTIONAL) The targets for the loss function. Only required if + one wants to call total_loss() instead of total_sampled_loss(). + total_loss() is required, for example, to estimate the + "empirical Fisher" (instead of the true Fisher). + (Default: None) + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) + reuse: (OPTIONAL) bool or str. If True, reuse an existing FisherBlock. + If False, create a new FisherBlock. If VARIABLE_SCOPE, use + tf.get_variable_scope().reuse. + + Raises: + ValueError: If reuse=True and name != None. + ValueError: If reuse=True and seed != None. + KeyError: If reuse=True and no existing LossFunction with 'name' found. + KeyError: If reuse=False and existing LossFunction with 'name' found. + """ + name = name or self._graph.unique_name( + "register_categorical_predictive_distribution") + + if reuse == VARIABLE_SCOPE: + reuse = variable_scope.get_variable_scope().reuse + + if reuse: + if name is None: + raise ValueError( + "If reuse is enabled, loss function's name must be set.") + if seed is not None: + raise ValueError( + "Seed can only be specified at LossFunction instantiation.") + + loss = self._loss_dict.get(name, None) + + if loss is None: + raise KeyError( + "Unable to find loss function named {}. Create a new LossFunction " + "with reuse=False.".format(name)) + + loss.register_additional_minibatch(logits, targets=targets) + else: + if name in self._loss_dict: + raise KeyError( + "Loss function named {} already exists. Set reuse=True to append " + "another minibatch.".format(name)) + loss = lf.CategoricalLogitsNegativeLogProbLoss( + logits, targets=targets, seed=seed) + self._loss_dict[name] = loss + + def register_normal_predictive_distribution(self, + mean, + var=0.5, + seed=None, + targets=None, + name=None): + """Registers a normal predictive distribution. + + Args: + mean: The mean vector defining the distribution. + var: The variance (must be a scalar). Note that the default value of + 0.5 corresponds to a standard squared error loss (target - + prediction)**2. If your squared error loss is of the form + 0.5*(target - prediction)**2 you should use var=1.0. (Default: 0.5) + seed: The seed for the RNG (for debugging) (Default: None) + targets: (OPTIONAL) The targets for the loss function. Only required if + one wants to call total_loss() instead of total_sampled_loss(). + total_loss() is required, for example, to estimate the + "empirical Fisher" (instead of the true Fisher). + (Default: None) + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) + """ + name = name or self._graph.unique_name( + "register_normal_predictive_distribution") + if name in self._loss_dict: + raise NotImplementedError( + "Adding logits to an existing LossFunction not yet supported.") + loss = lf.NormalMeanNegativeLogProbLoss( + mean, var, targets=targets, seed=seed) + self._loss_dict[name] = loss + + def register_multi_bernoulli_predictive_distribution(self, + logits, + seed=None, + targets=None, + name=None): + """Registers a multi-Bernoulli predictive distribution. + + Args: + logits: The logits of the distribution (i.e. its parameters). + seed: The seed for the RNG (for debugging) (Default: None) + targets: (OPTIONAL) The targets for the loss function. Only required if + one wants to call total_loss() instead of total_sampled_loss(). + total_loss() is required, for example, to estimate the + "empirical Fisher" (instead of the true Fisher). + (Default: None) + name: (OPTIONAL) str or None. Unique name for this loss function. If None, + a new name is generated. (Default: None) + """ + name = name or self._graph.unique_name( + "register_multi_bernoulli_predictive_distribution") + if name in self._loss_dict: + raise NotImplementedError( + "Adding logits to an existing LossFunction not yet supported.") + loss = lf.MultiBernoulliNegativeLogProbLoss( + logits, targets=targets, seed=seed) + self._loss_dict[name] = loss + + def make_or_get_factor(self, cls, args): + """Insert 'cls(args)' into 'self.fisher_factors' if not already present. + + Wraps constructor in 'tf.variable_scope()' to ensure variables constructed + in 'cls.__init__' are placed under this LayerCollection's scope. + + Args: + cls: Class that implements FisherFactor. + args: Tuple of arguments to pass into 'cls's constructor. Must be + hashable. + + Returns: + Instance of 'cls' found in self.fisher_factors. + """ + try: + hash(args) + except TypeError: + raise TypeError(( + "Unable to use (cls, args) = ({}, {}) as a key in " + "LayerCollection.fisher_factors. The pair cannot be hashed." + ).format(cls, args)) + + with variable_scope.variable_scope(self._var_scope): + return utils.setdefault(self.fisher_factors, (cls, args), + lambda: cls(*args)) diff --git a/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..d6bf61a210203dd74d4e93b65005f660b1fab4ff --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/layer_collection_lib.py @@ -0,0 +1,41 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Registry for layers and their parameters/variables. + +This represents the collection of all layers in the approximate Fisher +information matrix to which a particular FisherBlock may belong. That is, we +might have several layer collections for one TF graph (if we have multiple K-FAC +optimizers being used, for example.) +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.layer_collection import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + "LayerParametersDict", + "LayerCollection", + "APPROX_KRONECKER_NAME", + "APPROX_DIAGONAL_NAME", + "APPROX_FULL_NAME", + "VARIABLE_SCOPE", +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions.py b/tensorflow/contrib/kfac/python/ops/loss_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..3cfde7f9ababab73980e93ea1dd65be1b559712b --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/loss_functions.py @@ -0,0 +1,757 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Loss functions to be used by LayerCollection.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops.distributions import bernoulli +from tensorflow.python.ops.distributions import categorical +from tensorflow.python.ops.distributions import normal + + +@six.add_metaclass(abc.ABCMeta) +class LossFunction(object): + """Abstract base class for loss functions. + + Note that unlike typical loss functions used in neural networks these are + summed and not averaged across cases in the batch, since this is what the + users of this class (FisherEstimator and MatrixVectorProductComputer) will + be expecting. The implication of this is that you will may want to + normalize things like Fisher-vector products by the batch size when you + use this class. It depends on the use case. + """ + + @abc.abstractproperty + def targets(self): + """The targets being predicted by the model. + + Returns: + None or Tensor of appropriate shape for calling self._evaluate() on. + """ + pass + + @abc.abstractproperty + def inputs(self): + """The inputs to the loss function (excluding the targets).""" + pass + + def evaluate(self): + """Evaluate the loss function on the targets.""" + if self.targets is not None: + # We treat the targets as "constant". It's only the inputs that get + # "back-propped" through. + return self._evaluate(array_ops.stop_gradient(self.targets)) + else: + raise Exception("Cannot evaluate losses with unspecified targets.") + + @abc.abstractmethod + def _evaluate(self, targets): + """Evaluates the log probability of the targets. + + Args: + targets: Tensor that distribution can calculate log_prob() of. + + Returns: + log probability of each target, summed across all targets. + """ + + pass + + @abc.abstractmethod + def multiply_hessian(self, vector): + """Right-multiply a vector by the Hessian. + + Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) + of the loss function with respect to its inputs. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the + 'inputs' property. + + Returns: + The vector right-multiplied by the Hessian. Will be of the same shape(s) + as the 'inputs' property. + """ + pass + + @abc.abstractmethod + def multiply_hessian_factor(self, vector): + """Right-multiply a vector by a factor B of the Hessian. + + Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) + of the loss function with respect to its inputs. Typically this will be + block-diagonal across different cases in the batch, since the loss function + is typically summed across cases. + + Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be of the shape given by the + 'hessian_factor_inner_shape' property. + + Returns: + The vector right-multiplied by B. Will be of the same shape(s) as the + 'inputs' property. + """ + pass + + @abc.abstractmethod + def multiply_hessian_factor_transpose(self, vector): + """Right-multiply a vector by the transpose of a factor B of the Hessian. + + Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) + of the loss function with respect to its inputs. Typically this will be + block-diagonal across different cases in the batch, since the loss function + is typically summed across cases. + + Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the + 'inputs' property. + + Returns: + The vector right-multiplied by B^T. Will be of the shape given by the + 'hessian_factor_inner_shape' property. + """ + pass + + @abc.abstractmethod + def multiply_hessian_factor_replicated_one_hot(self, index): + """Right-multiply a replicated-one-hot vector by a factor B of the Hessian. + + Here the 'Hessian' is the Hessian matrix (i.e. matrix of 2nd-derivatives) + of the loss function with respect to its inputs. Typically this will be + block-diagonal across different cases in the batch, since the loss function + is typically summed across cases. + + A 'replicated-one-hot' vector means a tensor which, for each slice along the + batch dimension (assumed to be dimension 0), is 1.0 in the entry + corresponding to the given index and 0 elsewhere. + + Note that B can be any matrix satisfying B * B^T = H where H is the Hessian, + but will agree with the one used in the other methods of this class. + + Args: + index: A tuple representing in the index of the entry in each slice that + is 1.0. Note that len(index) must be equal to the number of elements + of the 'hessian_factor_inner_shape' tensor minus one. + + Returns: + The vector right-multiplied by B^T. Will be of the same shape(s) as the + 'inputs' property. + """ + pass + + @abc.abstractproperty + def hessian_factor_inner_shape(self): + """The shape of the tensor returned by multiply_hessian_factor.""" + pass + + @abc.abstractproperty + def hessian_factor_inner_static_shape(self): + """Static version of hessian_factor_inner_shape.""" + pass + + +@six.add_metaclass(abc.ABCMeta) +class NegativeLogProbLoss(LossFunction): + """Abstract base class for loss functions that are negative log probs.""" + + def __init__(self, seed=None): + self._default_seed = seed + super(NegativeLogProbLoss, self).__init__() + + @property + def inputs(self): + return self.params + + @abc.abstractproperty + def params(self): + """Parameters to the underlying distribution.""" + pass + + @abc.abstractmethod + def multiply_fisher(self, vector): + """Right-multiply a vector by the Fisher. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the + 'inputs' property. + + Returns: + The vector right-multiplied by the Fisher. Will be of the same shape(s) + as the 'inputs' property. + """ + pass + + @abc.abstractmethod + def multiply_fisher_factor(self, vector): + """Right-multiply a vector by a factor B of the Fisher. + + Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- + product of gradients) with respect to the parameters of the underlying + probability distribtion (whose log-prob defines the loss). Typically this + will be block-diagonal across different cases in the batch, since the + distribution is usually (but not always) conditionally iid across different + cases. + + Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be of the shape given by the + 'fisher_factor_inner_shape' property. + + Returns: + The vector right-multiplied by B. Will be of the same shape(s) as the + 'inputs' property. + """ + pass + + @abc.abstractmethod + def multiply_fisher_factor_transpose(self, vector): + """Right-multiply a vector by the transpose of a factor B of the Fisher. + + Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- + product of gradients) with respect to the parameters of the underlying + probability distribtion (whose log-prob defines the loss). Typically this + will be block-diagonal across different cases in the batch, since the + distribution is usually (but not always) conditionally iid across different + cases. + + Note that B can be any matrix satisfying B * B^T = F where F is the Fisher, + but will agree with the one used in the other methods of this class. + + Args: + vector: The vector to multiply. Must be the same shape(s) as the + 'inputs' property. + + Returns: + The vector right-multiplied by B^T. Will be of the shape given by the + 'fisher_factor_inner_shape' property. + """ + pass + + @abc.abstractmethod + def multiply_fisher_factor_replicated_one_hot(self, index): + """Right-multiply a replicated-one-hot vector by a factor B of the Fisher. + + Here the 'Fisher' is the Fisher information matrix (i.e. expected outer- + product of gradients) with respect to the parameters of the underlying + probability distribtion (whose log-prob defines the loss). Typically this + will be block-diagonal across different cases in the batch, since the + distribution is usually (but not always) conditionally iid across different + cases. + + A 'replicated-one-hot' vector means a tensor which, for each slice along the + batch dimension (assumed to be dimension 0), is 1.0 in the entry + corresponding to the given index and 0 elsewhere. + + Note that B can be any matrix satisfying B * B^T = H where H is the Fisher, + but will agree with the one used in the other methods of this class. + + Args: + index: A tuple representing in the index of the entry in each slice that + is 1.0. Note that len(index) must be equal to the number of elements + of the 'fisher_factor_inner_shape' tensor minus one. + + Returns: + The vector right-multiplied by B. Will be of the same shape(s) as the + 'inputs' property. + """ + pass + + @abc.abstractproperty + def fisher_factor_inner_shape(self): + """The shape of the tensor returned by multiply_fisher_factor.""" + pass + + @abc.abstractproperty + def fisher_factor_inner_static_shape(self): + """Static version of fisher_factor_inner_shape.""" + pass + + @abc.abstractmethod + def sample(self, seed): + """Sample 'targets' from the underlying distribution.""" + pass + + def evaluate_on_sample(self, seed=None): + """Evaluates the log probability on a random sample. + + Args: + seed: int or None. Random seed for this draw from the distribution. + + Returns: + Log probability of sampled targets, summed across examples. + """ + if seed is None: + seed = self._default_seed + # We treat the targets as "constant". It's only the inputs that get + # "back-propped" through. + return self._evaluate(array_ops.stop_gradient(self.sample(seed))) + + +# TODO(jamesmartens): should this just inherit from object to avoid "diamond" +# inheritance, or is there a better way? +class NaturalParamsNegativeLogProbLoss(NegativeLogProbLoss): + """Base class for neg log prob losses whose inputs are 'natural' parameters. + + Note that the Hessian and Fisher for natural parameters of exponential- + family models are the same, hence the purpose of this class. + See here: https://arxiv.org/abs/1412.1193 + + 'Natural parameters' are defined for exponential-family models. See for + example: https://en.wikipedia.org/wiki/Exponential_family + """ + + def multiply_hessian(self, vector): + return self.multiply_fisher(vector) + + def multiply_hessian_factor(self, vector): + return self.multiply_fisher_factor(vector) + + def multiply_hessian_factor_transpose(self, vector): + return self.multiply_fisher_factor_transpose(vector) + + def multiply_hessian_factor_replicated_one_hot(self, index): + return self.multiply_fisher_factor_replicated_one_hot(index) + + @property + def hessian_factor_inner_shape(self): + return self.fisher_factor_inner_shape + + @property + def hessian_factor_inner_static_shape(self): + return self.fisher_factor_inner_shape + + +class DistributionNegativeLogProbLoss(NegativeLogProbLoss): + """Base class for neg log prob losses that use the TF Distribution classes.""" + + def __init__(self, seed=None): + super(DistributionNegativeLogProbLoss, self).__init__(seed=seed) + + @abc.abstractproperty + def dist(self): + """The underlying tf.distributions.Distribution.""" + pass + + def _evaluate(self, targets): + return -math_ops.reduce_sum(self.dist.log_prob(targets)) + + def sample(self, seed): + return self.dist.sample(seed=seed) + + +class NormalMeanNegativeLogProbLoss(DistributionNegativeLogProbLoss, + NaturalParamsNegativeLogProbLoss): + """Neg log prob loss for a normal distribution parameterized by a mean vector. + + + Note that the covariance is treated as a constant 'var' times the identity. + Also note that the Fisher for such a normal distribution with respect the mean + parameter is given by: + + F = (1/var) * I + + See for example https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf. + """ + + def __init__(self, mean, var=0.5, targets=None, seed=None): + self._mean = mean + self._var = var + self._targets = targets + super(NormalMeanNegativeLogProbLoss, self).__init__(seed=seed) + + @property + def targets(self): + return self._targets + + @property + def dist(self): + return normal.Normal(loc=self._mean, scale=math_ops.sqrt(self._var)) + + @property + def params(self): + return self._mean + + def multiply_fisher(self, vector): + return (1. / self._var) * vector + + def multiply_fisher_factor(self, vector): + return self._var**-0.5 * vector + + def multiply_fisher_factor_transpose(self, vector): + return self.multiply_fisher_factor(vector) # it's symmetric in this case + + def multiply_fisher_factor_replicated_one_hot(self, index): + assert len(index) == 1, "Length of index was {}".format(len(index)) + ones_slice = array_ops.expand_dims( + array_ops.ones(array_ops.shape(self._mean)[:1], dtype=self._mean.dtype), + axis=-1) + output_slice = self._var**-0.5 * ones_slice + return insert_slice_in_zeros(output_slice, 1, + int(self._mean.shape[1]), index[0]) + + @property + def fisher_factor_inner_shape(self): + return array_ops.shape(self._mean) + + @property + def fisher_factor_inner_static_shape(self): + return self._mean.shape + + +class NormalMeanVarianceNegativeLogProbLoss(DistributionNegativeLogProbLoss): + """Negative log prob loss for a normal distribution with mean and variance. + + This class parameterizes a multivariate normal distribution with n independent + dimensions. Unlike `NormalMeanNegativeLogProbLoss`, this class does not + assume the variance is held constant. The Fisher Information for n = 1 + is given by, + + F = [[1 / variance, 0], + [ 0, 0.5 / variance^2]] + + where the parameters of the distribution are concatenated into a single + vector as [mean, variance]. For n > 1, the mean parameter vector is + concatenated with the variance parameter vector. + + See https://www.ii.pwr.edu.pl/~tomczak/PDF/[JMT]Fisher_inf.pdf for derivation. + """ + + def __init__(self, mean, variance, targets=None, seed=None): + assert len(mean.shape) == 2, "Expect 2D mean tensor." + assert len(variance.shape) == 2, "Expect 2D variance tensor." + self._mean = mean + self._variance = variance + self._scale = math_ops.sqrt(variance) + self._targets = targets + super(NormalMeanVarianceNegativeLogProbLoss, self).__init__(seed=seed) + + @property + def targets(self): + return self._targets + + @property + def dist(self): + return normal.Normal(loc=self._mean, scale=self._scale) + + @property + def params(self): + return self._mean, self._variance + + def _concat(self, mean, variance): + return array_ops.concat([mean, variance], axis=-1) + + def _split(self, params): + return array_ops.split(params, 2, axis=-1) + + @property + def _fisher_mean(self): + return 1./self._variance + + @property + def _fisher_mean_factor(self): + return 1./self._scale + + @property + def _fisher_var(self): + return 1./(2*math_ops.square(self._variance)) + + @property + def _fisher_var_factor(self): + return 1./(math_ops.sqrt(2.)*self._variance) + + def multiply_fisher(self, vecs): + mean_vec, var_vec = vecs + return (self._fisher_mean * mean_vec, + self._fisher_var * var_vec) + + def multiply_fisher_factor(self, vecs): + mean_vec, var_vec = self._split(vecs) + return (self._fisher_mean_factor * mean_vec, + self._fisher_var_factor * var_vec) + + def multiply_fisher_factor_transpose(self, vecs): + mean_vec, var_vec = vecs + return self._concat(self._fisher_mean_factor * mean_vec, + self._fisher_var_factor * var_vec) + + def multiply_fisher_factor_replicated_one_hot(self, index): + assert len(index) == 1, "Length of index was {}".format(len(index)) + index = index[0] + + if index < int(self._mean.shape[-1]): + # Index corresponds to mean parameter. + mean_slice = self._fisher_mean_factor[:, index] + mean_slice = array_ops.expand_dims(mean_slice, axis=-1) + mean_output = insert_slice_in_zeros(mean_slice, 1, + int(self._mean.shape[1]), index) + var_output = array_ops.zeros_like(mean_output) + else: + index -= int(self._mean.shape[-1]) + # Index corresponds to variance parameter. + var_slice = self._fisher_var_factor[:, index] + var_slice = array_ops.expand_dims(var_slice, axis=-1) + var_output = insert_slice_in_zeros(var_slice, 1, + int(self._variance.shape[1]), index) + mean_output = array_ops.zeros_like(var_output) + + return mean_output, var_output + + @property + def fisher_factor_inner_shape(self): + return array_ops.concat([array_ops.shape(self._mean)[:-1], + 2*array_ops.shape(self._mean)[-1:]], axis=0) + + @property + def fisher_factor_inner_static_shape(self): + shape = self._mean.shape.as_list() + return tensor_shape.TensorShape(shape[-1:] + [2*shape[-1]]) + + def multiply_hessian(self, vector): + raise NotImplementedError() + + def multiply_hessian_factor(self, vector): + raise NotImplementedError() + + def multiply_hessian_factor_transpose(self, vector): + raise NotImplementedError() + + def multiply_hessian_factor_replicated_one_hot(self, index): + raise NotImplementedError() + + @property + def hessian_factor_inner_shape(self): + raise NotImplementedError() + + @property + def hessian_factor_inner_static_shape(self): + raise NotImplementedError() + + +class CategoricalLogitsNegativeLogProbLoss(DistributionNegativeLogProbLoss, + NaturalParamsNegativeLogProbLoss): + """Neg log prob loss for a categorical distribution parameterized by logits. + + + Note that the Fisher (for a single case) of a categorical distribution, with + respect to the natural parameters (i.e. the logits), is given by: + + F = diag(p) - p*p^T + + where p = softmax(logits). F can be factorized as F = B * B^T where + + B = diag(q) - p*q^T + + where q is the entry-wise square root of p. This is easy to verify using the + fact that q^T*q = 1. + """ + + def __init__(self, logits, targets=None, seed=None): + """Instantiates a CategoricalLogitsNegativeLogProbLoss. + + Args: + logits: Tensor of shape [batch_size, output_size]. Parameters for + underlying distribution. + targets: None or Tensor of shape [output_size]. Each elements contains an + index in [0, output_size). + seed: int or None. Default random seed when sampling. + """ + self._logits_components = [] + self._targets_components = [] + self.register_additional_minibatch(logits, targets=targets) + super(CategoricalLogitsNegativeLogProbLoss, self).__init__(seed=seed) + + def register_additional_minibatch(self, logits, targets=None): + """Register an additiona minibatch's worth of parameters. + + Args: + logits: Tensor of shape [batch_size, output_size]. Parameters for + underlying distribution. + targets: None or Tensor of shape [batch_size, output_size]. Each row must + be a one-hot vector. + """ + self._logits_components.append(logits) + self._targets_components.append(targets) + + @property + def _logits(self): + return array_ops.concat(self._logits_components, axis=0) + + @property + def targets(self): + if all(target is None for target in self._targets_components): + return None + return array_ops.concat(self._targets_components, axis=0) + + @property + def dist(self): + return categorical.Categorical(logits=self._logits) + + @property + def _probs(self): + return self.dist.probs + + @property + def _sqrt_probs(self): + return math_ops.sqrt(self._probs) + + @property + def params(self): + return self._logits + + def multiply_fisher(self, vector): + probs = self._probs + return vector * probs - math_ops.reduce_sum(vector * probs, axis=1) * probs + + def multiply_fisher_factor(self, vector): + probs = self._probs + sqrt_probs = self._sqrt_probs + return sqrt_probs * vector - probs * math_ops.reduce_sum( + sqrt_probs * vector, axis=1, keep_dims=True) + + def multiply_fisher_factor_transpose(self, vector): + probs = self._probs + sqrt_probs = self._sqrt_probs + return sqrt_probs * vector - sqrt_probs * math_ops.reduce_sum( + probs * vector, axis=1, keep_dims=True) + + def multiply_fisher_factor_replicated_one_hot(self, index): + assert len(index) == 1, "Length of index was {}".format(len(index)) + probs = self._probs + sqrt_probs = self._sqrt_probs + sqrt_probs_slice = array_ops.expand_dims(sqrt_probs[:, index[0]], -1) + padded_slice = insert_slice_in_zeros(sqrt_probs_slice, 1, + int(sqrt_probs.shape[1]), index[0]) + return padded_slice - probs * sqrt_probs_slice + + @property + def fisher_factor_inner_shape(self): + return array_ops.shape(self._logits) + + @property + def fisher_factor_inner_static_shape(self): + return self._logits.shape + + +class MultiBernoulliNegativeLogProbLoss(DistributionNegativeLogProbLoss, + NaturalParamsNegativeLogProbLoss): + """Neg log prob loss for multiple Bernoulli distributions param'd by logits. + + Represents N independent Bernoulli distributions where N = len(logits). Its + Fisher Information matrix is given by, + + F = diag(p * (1-p)) + p = sigmoid(logits) + + As F is diagonal with positive entries, its factor B is, + + B = diag(sqrt(p * (1-p))) + """ + + def __init__(self, logits, targets=None, seed=None): + self._logits = logits + self._targets = targets + super(MultiBernoulliNegativeLogProbLoss, self).__init__(seed=seed) + + @property + def targets(self): + return self._targets + + @property + def dist(self): + return bernoulli.Bernoulli(logits=self._logits) + + @property + def _probs(self): + return self.dist.probs + + @property + def params(self): + return self._logits + + def multiply_fisher(self, vector): + return self._probs * (1 - self._probs) * vector + + def multiply_fisher_factor(self, vector): + return math_ops.sqrt(self._probs * (1 - self._probs)) * vector + + def multiply_fisher_factor_transpose(self, vector): + return self.multiply_fisher_factor(vector) # it's symmetric in this case + + def multiply_fisher_factor_replicated_one_hot(self, index): + assert len(index) == 1, "Length of index was {}".format(len(index)) + probs_slice = array_ops.expand_dims(self._probs[:, index[0]], -1) + output_slice = math_ops.sqrt(probs_slice * (1 - probs_slice)) + return insert_slice_in_zeros(output_slice, 1, + int(self._logits.shape[1]), index[0]) + + @property + def fisher_factor_inner_shape(self): + return array_ops.shape(self._logits) + + @property + def fisher_factor_inner_static_shape(self): + return self._logits.shape + + +def insert_slice_in_zeros(slice_to_insert, dim, dim_size, position): + """Inserts slice into a larger tensor of zeros. + + Forms a new tensor which is the same shape as slice_to_insert, except that + the dimension given by 'dim' is expanded to the size given by 'dim_size'. + 'position' determines the position (index) at which to insert the slice within + that dimension. + + Assumes slice_to_insert.shape[dim] = 1. + + Args: + slice_to_insert: The slice to insert. + dim: The dimension which to expand with zeros. + dim_size: The new size of the 'dim' dimension. + position: The position of 'slice_to_insert' in the new tensor. + + Returns: + The new tensor. + + Raises: + ValueError: If the slice's shape at the given dim is not 1. + """ + slice_shape = slice_to_insert.shape + if slice_shape[dim] != 1: + raise ValueError("Expected slice_to_insert.shape to have {} dim of 1, but " + "was {}".format(dim, slice_to_insert.shape[dim])) + + before = [0] * int(len(slice_shape)) + after = before[:] + before[dim] = position + after[dim] = dim_size - position - 1 + + return array_ops.pad(slice_to_insert, list(zip(before, after))) diff --git a/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..e9bb4f14e9e24128382832fcdaccdc9b24017046 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/loss_functions_lib.py @@ -0,0 +1,39 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Loss functions to be used by LayerCollection.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.loss_functions import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + "LossFunction", + "NegativeLogProbLoss", + "NaturalParamsNegativeLogProbLoss", + "DistributionNegativeLogProbLoss", + "NormalMeanNegativeLogProbLoss", + "NormalMeanVarianceNegativeLogProbLoss", + "CategoricalLogitsNegativeLogProbLoss", + "MultiBernoulliNegativeLogProbLoss", + "MultiBernoulliNegativeLogProbLoss", + "insert_slice_in_zeros", +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/op_queue.py b/tensorflow/contrib/kfac/python/ops/op_queue.py new file mode 100644 index 0000000000000000000000000000000000000000..831870fca451c585cb1a1dc6b24aad757e2bbaa8 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/op_queue.py @@ -0,0 +1,69 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helper for choosing which op to run next in a distributed setting.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.data.python.ops import dataset_ops +from tensorflow.python.framework import ops as tf_ops + + +class OpQueue(object): + """Class for choosing which Op to run next. + + Constructs an infinitely repeating sequence of Ops in shuffled order. + + In K-FAC, this can be used to distribute inverse update operations among + workers. + """ + + def __init__(self, ops, seed=None): + """Initializes an OpQueue. + + Args: + ops: list of TensorFlow Ops. Ops to be selected from. All workers must + initialize with the same set of ops. + seed: int or None. Random seed used when shuffling order of ops. + """ + self._ops_by_name = {op.name: op for op in ops} + + # Construct a (shuffled) Dataset with Op names. + op_names = tf_ops.convert_to_tensor(list(sorted(op.name for op in ops))) + op_names_dataset = (dataset_ops.Dataset.from_tensor_slices(op_names) + .shuffle(len(ops), seed=seed).repeat()) + self._next_op_name = op_names_dataset.make_one_shot_iterator().get_next() + + @property + def ops(self): + """Ops this OpQueue can return in next_op().""" + return self._ops_by_name.values() + + def next_op(self, sess): + """Chooses which op to run next. + + Note: This call will make a call to sess.run(). + + Args: + sess: tf.Session. + + Returns: + Next Op chosen from 'ops'. + """ + # In Python 3, type(next_op_name) == bytes. Calling bytes.decode('ascii') + # returns a str. + next_op_name = sess.run(self._next_op_name).decode('ascii') + return self._ops_by_name[next_op_name] diff --git a/tensorflow/contrib/kfac/python/ops/op_queue_lib.py b/tensorflow/contrib/kfac/python/ops/op_queue_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..09c9a4ab3337f5887da584eec96f230878d43a92 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/op_queue_lib.py @@ -0,0 +1,30 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helper for choosing which op to run next in a distributed setting.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.op_queue import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + 'OpQueue', +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/optimizer.py b/tensorflow/contrib/kfac/python/ops/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..bfa15e0948c96477d9a79dece985bc4b6dafab6f --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/optimizer.py @@ -0,0 +1,435 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The KFAC optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint disable=long-line +from tensorflow.contrib.kfac.python.ops import curvature_matrix_vector_products as cmvp +from tensorflow.contrib.kfac.python.ops import estimator as est +# pylint enable=long-line + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variables as tf_variables +from tensorflow.python.training import gradient_descent + + +class KfacOptimizer(gradient_descent.GradientDescentOptimizer): + """The KFAC Optimizer (https://arxiv.org/abs/1503.05671).""" + + def __init__( + self, + learning_rate, + cov_ema_decay, + damping, + layer_collection, + momentum=0., + momentum_type="regular", + norm_constraint=None, + name="KFAC",): + """Initializes the KFAC optimizer with the given settings. + + Args: + learning_rate: The base learning rate for the optimizer. Should probably + be set to 1.0 when using momentum_type = 'qmodel', but can still be + set lowered if desired (effectively lowering the trust in the + quadratic model.) + cov_ema_decay: The decay factor used when calculating the covariance + estimate moving averages. + damping: The damping factor used to stabilize training due to errors in + the local approximation with the Fisher information matrix, and to + regularize the update direction by making it closer to the gradient. + (Higher damping means the update looks more like a standard gradient + update - see Tikhonov regularization.) + layer_collection: The layer collection object, which holds the fisher + blocks, kronecker factors, and losses associated with the + graph. The layer_collection cannot be modified after KfacOptimizer's + initialization. + momentum: The momentum value for this optimizer. Only applies when + momentum_type is 'regular' or 'adam'. (Default: 0) + momentum_type: The type of momentum to use in this optimizer, one of + 'regular', 'adam', or 'qmodel'. (Default: 'regular') + norm_constraint: float or Tensor. If specified, the update is scaled down + so that its approximate squared Fisher norm v^T F v is at most the + specified value. May only be used with momentum type 'regular'. + (Default: None) + name: The name for this optimizer. (Default: 'KFAC') + + Raises: + ValueError: If the momentum type is unsupported. + ValueError: If clipping is used with momentum type other than 'regular'. + ValueError: If no losses have been registered with layer_collection. + ValueError: If momentum is non-zero and momentum_type is not 'regular' + or 'adam'. + """ + + # We may consider determining the set of variables some other way, but for + # now it's just all the trainable variables. + variables = tf_variables.trainable_variables() + + self._fisher_est = est.FisherEstimator(variables, cov_ema_decay, damping, + layer_collection) + + momentum_type = momentum_type.lower() + legal_momentum_types = ["regular", "adam", "qmodel"] + + if momentum_type not in legal_momentum_types: + raise ValueError("Unsupported momentum type {}. Must be one of {}." + .format(momentum_type, legal_momentum_types)) + if momentum_type != "regular" and norm_constraint is not None: + raise ValueError("Update clipping is only supported with momentum" + "type 'regular'.") + if momentum_type not in ["regular", "adam"] and momentum != 0: + raise ValueError("Momentum must be unspecified if using a momentum_type " + "other than 'regular' or 'adam'.") + + self._momentum = ops.convert_to_tensor(momentum, name="momentum") + self._momentum_type = momentum_type + self._norm_constraint = norm_constraint + + # this is a bit of a hack + # TODO(duckworthd): Handle this in a better way (e.g. pass it in?) + self._batch_size = array_ops.shape(layer_collection.losses[0].inputs)[0] + self._losses = layer_collection.losses + + self.cov_update_op = self._fisher_est.cov_update_op + self.inv_update_op = self._fisher_est.inv_update_op + self.inv_updates_dict = self._fisher_est.inv_updates_dict + + super(KfacOptimizer, self).__init__(learning_rate, name=name) + + @property + def variables(self): + return self._fisher_est.variables + + @property + def damping(self): + return self._fisher_est.damping + + def minimize(self, *args, **kwargs): + + if "var_list" not in kwargs: + kwargs["var_list"] = tf_variables.trainable_variables() + + if set(kwargs["var_list"]) != set(self.variables): + raise ValueError("var_list doesn't match with set of Fisher-estimating " + "variables.") + + return super(KfacOptimizer, self).minimize(*args, **kwargs) + + def apply_gradients(self, grads_and_vars, *args, **kwargs): + """Applies gradients to variables. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + *args: Additional arguments for super.apply_gradients. + **kwargs: Additional keyword arguments for super.apply_gradients. + + Returns: + An `Operation` that applies the specified gradients. + """ + # In Python 3, grads_and_vars can be a zip() object which can only be + # iterated over once. By converting it to a list, we ensure that it can be + # iterated over more than once. + grads_and_vars = list(grads_and_vars) + + # Compute step. + steps_and_vars = self._compute_update_steps(grads_and_vars) + + # Update trainable variables with this step. + return super(KfacOptimizer, self).apply_gradients(steps_and_vars, *args, + **kwargs) + + def _squared_fisher_norm(self, grads_and_vars, precon_grads_and_vars): + """Computes the squared (approximate) Fisher norm of the updates. + + This is defined as v^T F v, where F is the approximate Fisher matrix + as computed by the estimator, and v = F^{-1} g, where g is the gradient. + This is computed efficiently as v^T g. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + precon_grads_and_vars: List of (preconditioned gradient, variable) pairs. + Must be the result of calling `self._fisher_est.multiply_inverse` + on `grads_and_vars`. + + Returns: + Scalar representing the squared norm. + + Raises: + ValueError: if the two list arguments do not contain the same variables, + in the same order. + """ + for (_, gvar), (_, pgvar) in zip(grads_and_vars, precon_grads_and_vars): + if gvar is not pgvar: + raise ValueError("The variables referenced by the two arguments " + "must match.") + terms = [ + math_ops.reduce_sum(grad * pgrad) + for (grad, _), (pgrad, _) in zip(grads_and_vars, precon_grads_and_vars) + ] + return math_ops.reduce_sum(terms) + + def _update_clip_coeff(self, grads_and_vars, precon_grads_and_vars): + """Computes the scale factor for the update to satisfy the norm constraint. + + Defined as min(1, sqrt(c / r^T F r)), where c is the norm constraint, + F is the approximate Fisher matrix, and r is the update vector, i.e. + -alpha * v, where alpha is the learning rate, and v is the preconditioned + gradient. + + This is based on Section 5 of Ba et al., Distributed Second-Order + Optimization using Kronecker-Factored Approximations. Note that they + absorb the learning rate alpha (which they denote eta_max) into the formula + for the coefficient, while in our implementation, the rescaling is done + before multiplying by alpha. Hence, our formula differs from theirs by a + factor of alpha. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + precon_grads_and_vars: List of (preconditioned gradient, variable) pairs. + Must be the result of calling `self._fisher_est.multiply_inverse` + on `grads_and_vars`. + + Returns: + Scalar representing the coefficient which should be applied to the + preconditioned gradients to satisfy the norm constraint. + """ + sq_norm_grad = self._squared_fisher_norm(grads_and_vars, + precon_grads_and_vars) + sq_norm_up = sq_norm_grad * self._learning_rate**2 + return math_ops.minimum(1., + math_ops.sqrt(self._norm_constraint / sq_norm_up)) + + def _clip_updates(self, grads_and_vars, precon_grads_and_vars): + """Rescales the preconditioned gradients to satisfy the norm constraint. + + Rescales the preconditioned gradients such that the resulting update r + (after multiplying by the learning rate) will satisfy the norm constraint. + This constraint is that r^T F r <= C, where F is the approximate Fisher + matrix, and C is the norm_constraint attribute. See Section 5 of + Ba et al., Distributed Second-Order Optimization using Kronecker-Factored + Approximations. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + precon_grads_and_vars: List of (preconditioned gradient, variable) pairs. + Must be the result of calling `self._fisher_est.multiply_inverse` + on `grads_and_vars`. + + Returns: + List of (rescaled preconditioned gradient, variable) pairs. + """ + coeff = self._update_clip_coeff(grads_and_vars, precon_grads_and_vars) + return [(pgrad * coeff, var) for pgrad, var in precon_grads_and_vars] + + def _compute_qmodel_hyperparams(self, precon_grads, prev_updates, grads, + variables): + """Compute optimal update hyperparameters from the quadratic model. + + More specifically, if L is the loss we minimize a quadratic approximation + of L(theta + d) which we denote by qmodel(d) with + d = alpha*precon_grad + mu*prev_update with respect to alpha and mu, where + + qmodel(d) = (1/2) * d^T * B * d + grad^T*d + L(theta) . + + Unlike in the KL clipping approach we use the non-approximated quadratic + model where the curvature matrix C is the true Fisher on the current + mini-batch (computed without any approximations beyond mini-batch sampling), + with the usual Tikhonov damping/regularization applied, + + C = F + damping * I + + See Section 7 of https://arxiv.org/abs/1503.05671 for a derivation of + the formula. See Appendix C for a discussion of the trick of using + a factorized Fisher matrix to more efficiently compute the required + vector-matrix-vector products. + + Note that the elements of all 4 lists passed to this function must + be in correspondence with each other. + + Args: + precon_grads: List of preconditioned gradients. + prev_updates: List of updates computed at the previous iteration. + grads: List of gradients. + variables: List of variables in the graph that the update will be + applied to. (Note that this function doesn't actually apply the + update.) + + Returns: + (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize the + quadratic model, and + qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0) + = qmodel(alpha*precon_grad + mu*prev_update) - L(theta). + """ + + cmvpc = cmvp.CurvatureMatrixVectorProductComputer(self._losses, variables) + + # compute the matrix-vector products with the transposed Fisher factor + fft_precon_grads = cmvpc.multiply_fisher_factor_transpose(precon_grads) + fft_prev_updates = cmvpc.multiply_fisher_factor_transpose(prev_updates) + + batch_size = math_ops.cast( + self._batch_size, dtype=fft_precon_grads[0].dtype) + + # compute the entries of the 2x2 matrix + m_11 = (_inner_product_list(fft_precon_grads, fft_precon_grads) / batch_size + + self.damping * _inner_product_list(precon_grads, precon_grads)) + + m_21 = (_inner_product_list(fft_prev_updates, fft_precon_grads) / batch_size + + self.damping * _inner_product_list(prev_updates, precon_grads)) + + m_22 = (_inner_product_list(fft_prev_updates, fft_prev_updates) / batch_size + + self.damping * _inner_product_list(prev_updates, prev_updates)) + + def non_zero_prevupd_case(): + r"""Computes optimal (alpha, mu) given non-zero previous update. + + We solve the full 2x2 linear system. See Martens & Grosse (2015), + Section 7, definition of $\alpha^*$ and $\mu^*$. + + Returns: + (alpha, mu, qmodel_change), where alpha and mu are chosen to optimize + the quadratic model, and + qmodel_change = qmodel(alpha*precon_grad + mu*prev_update) - qmodel(0). + """ + m = ops.convert_to_tensor([[m_11, m_21], [m_21, m_22]]) + + c = ops.convert_to_tensor([[_inner_product_list(grads, precon_grads)], + [_inner_product_list(grads, prev_updates)]]) + + sol = _two_by_two_solve(m, c) + alpha = -sol[0] + mu = -sol[1] + qmodel_change = 0.5 * math_ops.reduce_sum(sol * c) + + return alpha, mu, qmodel_change + + def zero_prevupd_case(): + r"""Computes optimal (alpha, mu) given all-zero previous update. + + The linear system reduces to 1x1. See Martens & Grosse (2015), + Section 6.4, definition of $\alpha^*$. + + Returns: + (alpha, 0.0, qmodel_change), where alpha is chosen to optimize the + quadratic model, and + qmodel_change = qmodel(alpha*precon_grad) - qmodel(0) + """ + m = m_11 + c = _inner_product_list(grads, precon_grads) + + alpha = -c / m + mu = 0.0 + qmodel_change = 0.5 * alpha * c + + return alpha, mu, qmodel_change + + return control_flow_ops.cond( + math_ops.equal(m_22, 0.0), zero_prevupd_case, non_zero_prevupd_case) + + def _compute_update_steps(self, grads_and_vars): + """Computes the update steps for the variables given the gradients. + + Args: + grads_and_vars: List of (gradient, variable) pairs. + + Returns: + An 'Operation that computes the update steps for the given variables. + """ + if self._momentum_type == "regular": + # Compute "preconditioned" gradient. + precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars) + + # Apply "KL clipping" if asked for. + if self._norm_constraint is not None: + precon_grads_and_vars = self._clip_updates(grads_and_vars, + precon_grads_and_vars) + + # Update the velocity with this and return it as the step. + return self._update_velocities(precon_grads_and_vars, self._momentum) + + elif self._momentum_type == "adam": + # Update velocity. + velocities_and_vars = self._update_velocities(grads_and_vars, + self._momentum) + # Return "preconditioned" velocity vector as the step. + return self._fisher_est.multiply_inverse(velocities_and_vars) + + elif self._momentum_type == "qmodel": + # Compute "preconditioned" gradient. + precon_grads_and_vars = self._fisher_est.multiply_inverse(grads_and_vars) + + # Extract out singleton lists from the tuple-lists + precon_grads = list( + precon_grad for (precon_grad, _) in precon_grads_and_vars) + grads = list(grad for (grad, _) in grads_and_vars) + variables = list(var for (_, var) in grads_and_vars) + # previous updates are the negative velocities (up to scaling by LR) + prev_updates = list(-self._zeros_slot(var, "velocity", self._name) + for var in variables) + + # Compute optimal velocity update parameters according to quadratic model + alpha, mu, _ = self._compute_qmodel_hyperparams( + precon_grads, prev_updates, grads, variables) + + # Update the velocity with precon_grads according to these params + # and return it as the step. + return self._update_velocities( + precon_grads_and_vars, mu, vec_coeff=-alpha) + + def _update_velocities(self, vecs_and_vars, decay, vec_coeff=1.0): + """Updates the velocities of the variables with the given vectors. + + Args: + vecs_and_vars: List of (vector, variable) pairs. + decay: How much to decay the old velocity by. This is often referred to + as the 'momentum constant'. + vec_coeff: Coefficient to apply to the vectors before adding them to the + velocity. + + Returns: + A list of (velocity, var) indicating the new velocity for each var. + """ + + def _update_velocity(vec, var): + velocity = self._zeros_slot(var, "velocity", self._name) + with ops.colocate_with(velocity): + # NOTE(mattjj): read/modify/write race condition not suitable for async. + + # Compute the new velocity for this variable. + new_velocity = decay * velocity + vec_coeff * vec + + # Save the updated velocity. + return (array_ops.identity(velocity.assign(new_velocity)), var) + + # Go through variable and update its associated part of the velocity vector. + return [_update_velocity(vec, var) for vec, var in vecs_and_vars] + + +def _inner_product_list(list1, list2): + return math_ops.add_n( + [math_ops.reduce_sum(elt1 * elt2) for elt1, elt2 in zip(list1, list2)]) + + +def _two_by_two_solve(m, c): + # it might be better just to crank out the exact formula for 2x2 inverses + return math_ops.matmul(linalg_ops.matrix_inverse(m), c) diff --git a/tensorflow/contrib/kfac/python/ops/optimizer_lib.py b/tensorflow/contrib/kfac/python/ops/optimizer_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..87d1866e06bb0a572033828dd5c2f04b05296039 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/optimizer_lib.py @@ -0,0 +1,30 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The KFAC optimizer.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.optimizer import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + "KfacOptimizer", +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/kfac/python/ops/utils.py b/tensorflow/contrib/kfac/python/ops/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0fd7f5147739f0f46d2ab6a1c284c6dc75f53cc2 --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/utils.py @@ -0,0 +1,287 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients_impl +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import random_ops + + +# Method used for inverting matrices. +POSDEF_INV_METHOD = "cholesky" + + +def set_global_constants(posdef_inv_method=None): + """Sets various global constants used by the classes in this module.""" + global POSDEF_INV_METHOD + + if posdef_inv_method is not None: + POSDEF_INV_METHOD = posdef_inv_method + + +class SequenceDict(object): + """A dict convenience wrapper that allows getting/setting with sequences.""" + + def __init__(self, iterable=None): + self._dict = dict(iterable or []) + + def __getitem__(self, key_or_keys): + if isinstance(key_or_keys, (tuple, list)): + return list(map(self.__getitem__, key_or_keys)) + else: + return self._dict[key_or_keys] + + def __setitem__(self, key_or_keys, val_or_vals): + if isinstance(key_or_keys, (tuple, list)): + for key, value in zip(key_or_keys, val_or_vals): + self[key] = value + else: + self._dict[key_or_keys] = val_or_vals + + def items(self): + return list(self._dict.items()) + + +def setdefault(dct, key, thunk): + """Like dict.setdefault but delays evaluation of the value to be set.""" + if key not in dct: + dct[key] = thunk() + return dct[key] + + +def tensors_to_column(tensors): + """Converts a tensor or list of tensors to a column vector. + + Args: + tensors: A tensor or list of tensors. + + Returns: + The tensors reshaped into vectors and stacked on top of each other. + """ + if isinstance(tensors, (tuple, list)): + return array_ops.concat( + tuple(array_ops.reshape(tensor, [-1, 1]) for tensor in tensors), axis=0) + else: + return array_ops.reshape(tensors, [-1, 1]) + + +def column_to_tensors(tensors_template, colvec): + """Converts a column vector back to the shape of the given template. + + Args: + tensors_template: A tensor or list of tensors. + colvec: A 2d column vector with the same shape as the value of + tensors_to_column(tensors_template). + + Returns: + X, where X is tensor or list of tensors with the properties: + 1) tensors_to_column(X) = colvec + 2) X (or its elements) have the same shape as tensors_template (or its + elements) + """ + if isinstance(tensors_template, (tuple, list)): + offset = 0 + tensors = [] + for tensor_template in tensors_template: + sz = np.prod(tensor_template.shape.as_list(), dtype=np.int32) + tensor = array_ops.reshape(colvec[offset:(offset + sz)], + tensor_template.shape) + tensors.append(tensor) + offset += sz + + tensors = tuple(tensors) + else: + tensors = array_ops.reshape(colvec, tensors_template.shape) + + return tensors + + +def kronecker_product(mat1, mat2): + """Computes the Kronecker product two matrices.""" + m1, n1 = mat1.get_shape().as_list() + mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1]) + m2, n2 = mat2.get_shape().as_list() + mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2]) + return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2]) + + +def layer_params_to_mat2d(vector): + """Converts a vector shaped like layer parameters to a 2D matrix. + + In particular, we reshape the weights/filter component of the vector to be + 2D, flattening all leading (input) dimensions. If there is a bias component, + we concatenate it to the reshaped weights/filter component. + + Args: + vector: A Tensor or pair of Tensors shaped like layer parameters. + + Returns: + A 2D Tensor with the same coefficients and the same output dimension. + """ + if isinstance(vector, (tuple, list)): + w_part, b_part = vector + w_part_reshaped = array_ops.reshape(w_part, + [-1, w_part.shape.as_list()[-1]]) + return array_ops.concat( + (w_part_reshaped, array_ops.reshape(b_part, [1, -1])), axis=0) + else: + return array_ops.reshape(vector, [-1, vector.shape.as_list()[-1]]) + + +def mat2d_to_layer_params(vector_template, mat2d): + """Converts a canonical 2D matrix representation back to a vector. + + Args: + vector_template: A Tensor or pair of Tensors shaped like layer parameters. + mat2d: A 2D Tensor with the same shape as the value of + layer_params_to_mat2d(vector_template). + + Returns: + A Tensor or pair of Tensors with the same coefficients as mat2d and the same + shape as vector_template. + """ + if isinstance(vector_template, (tuple, list)): + w_part, b_part = mat2d[:-1], mat2d[-1] + return array_ops.reshape(w_part, vector_template[0].shape), b_part + else: + return array_ops.reshape(mat2d, vector_template.shape) + + +def compute_pi(left_factor, right_factor): + """Computes the scalar constant pi for Tikhonov regularization/damping. + + pi = sqrt( (trace(A) / dim(A)) / (trace(B) / dim(B)) ) + See section 6.3 of https://arxiv.org/pdf/1503.05671.pdf for details. + + Args: + left_factor: The left Kronecker factor Tensor. + right_factor: The right Kronecker factor Tensor. + + Returns: + The computed scalar constant pi for these Kronecker Factors (as a Tensor). + """ + # Instead of dividing by the dim of the norm, we multiply by the dim of the + # other norm. This works out the same in the ratio. + left_norm = math_ops.trace(left_factor) * right_factor.get_shape().as_list()[ + 0] + right_norm = math_ops.trace(right_factor) * left_factor.get_shape().as_list()[ + 0] + return math_ops.sqrt(left_norm / right_norm) + + +def posdef_inv(tensor, damping): + """Computes the inverse of tensor + damping * identity.""" + identity = linalg_ops.eye(tensor.shape.as_list()[0], dtype=tensor.dtype) + damping = math_ops.cast(damping, dtype=tensor.dtype) + return posdef_inv_funcs[POSDEF_INV_METHOD](tensor, identity, damping) + + +def posdef_inv_matrix_inverse(tensor, identity, damping): + """Computes inverse(tensor + damping * identity) directly.""" + return linalg_ops.matrix_inverse(tensor + damping * identity) + + +def posdef_inv_cholesky(tensor, identity, damping): + """Computes inverse(tensor + damping * identity) with Cholesky.""" + chol = linalg_ops.cholesky(tensor + damping * identity) + return linalg_ops.cholesky_solve(chol, identity) + + +posdef_inv_funcs = { + "matrix_inverse": posdef_inv_matrix_inverse, + "cholesky": posdef_inv_cholesky, +} + + +class SubGraph(object): + """Defines a subgraph given by all the dependencies of a given set of outputs. + """ + + def __init__(self, outputs): + self._members = set() + + self._recurse_add(outputs) + + def _recurse_add(self, nodes): + for node in nodes: + if node in self._members: + continue + self._members.add(node) + + if isinstance(node, ops.Tensor): + self._recurse_add((node.op,)) + elif isinstance(node, ops.Operation): + self._recurse_add(node.inputs) + + def is_member(self, node): + """Check if 'node' is in this subgraph.""" + return node in self._members + + def variable_uses(self, var): + """Computes number of times a variable is used.""" + return len(self._members.intersection(set(var.value().consumers()))) + + def filter_list(self, node_list): + """Filters 'node_list' to nodes in this subgraph.""" + filtered_list = [] + for node in node_list: + if self.is_member(node): + filtered_list.append(node) + return filtered_list + + +def generate_random_signs(shape, dtype=dtypes.float32): + """Generate a random tensor with {-1, +1} entries.""" + ints = random_ops.random_uniform(shape, maxval=2, dtype=dtypes.int32) + return 2 * math_ops.cast(ints, dtype=dtype) - 1 + + +def fwd_gradients(ys, xs, grad_xs=None, stop_gradients=None): + """Compute forward-mode gradients.""" + # See b/37888268. + + # This version of forward-mode autodiff is based on code by Tim Cooijmans + # and handles list arguments and certain special cases such as when the + # ys doesn't depend on one or more of the xs, and when ops.IndexedSlices are + # generated by the first gradients_impl.gradients call. + + us = [array_ops.zeros_like(y) + float("nan") for y in ys] + dydxs = gradients_impl.gradients(ys, xs, grad_ys=us, + stop_gradients=stop_gradients) + + # Deal with strange types that gradients_impl.gradients returns but can't + # deal with. + dydxs = [ + ops.convert_to_tensor(dydx) + if isinstance(dydx, ops.IndexedSlices) else dydx for dydx in dydxs + ] + dydxs = [ + array_ops.zeros_like(x) if dydx is None else dydx + for x, dydx in zip(xs, dydxs) + ] + + dysdx = gradients_impl.gradients(dydxs, us, grad_ys=grad_xs) + + return dysdx diff --git a/tensorflow/contrib/kfac/python/ops/utils_lib.py b/tensorflow/contrib/kfac/python/ops/utils_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..ddbb4485ce6967082f1844c6d798c078f1cc303b --- /dev/null +++ b/tensorflow/contrib/kfac/python/ops/utils_lib.py @@ -0,0 +1,44 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,line-too-long,wildcard-import +from tensorflow.contrib.kfac.python.ops.utils import * +from tensorflow.python.util.all_util import remove_undocumented +# pylint: enable=unused-import,line-too-long,wildcard-import + +_allowed_symbols = [ + "SequenceDict", + "setdefault", + "tensors_to_column", + "column_to_tensors", + "kronecker_product", + "layer_params_to_mat2d", + "mat2d_to_layer_params", + "compute_pi", + "posdef_inv", + "posdef_inv_matrix_inverse", + "posdef_inv_cholesky", + "posdef_inv_funcs", + "SubGraph", + "generate_random_signs", + "fwd_gradients", +] + +remove_undocumented(__name__, allowed_exception_list=_allowed_symbols) diff --git a/tensorflow/contrib/labeled_tensor/BUILD b/tensorflow/contrib/labeled_tensor/BUILD index 4eba29caecbddc408d168158daf8377aedab7bcc..894e6f6946bb59810a9da2d304cc0dd43d25201d 100644 --- a/tensorflow/contrib/labeled_tensor/BUILD +++ b/tensorflow/contrib/labeled_tensor/BUILD @@ -109,9 +109,9 @@ py_test( ":test_util", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:session", ], ) diff --git a/tensorflow/contrib/labeled_tensor/README.md b/tensorflow/contrib/labeled_tensor/README.md index 50c6750fd05f1dd605505011f74e23cb84eaf0b0..adce979e2acd1516d19572e31af1fd7b1c7a225c 100644 --- a/tensorflow/contrib/labeled_tensor/README.md +++ b/tensorflow/contrib/labeled_tensor/README.md @@ -3,6 +3,65 @@ LabeledTensor is a library for adding semantically meaningful dimension and coordinate labels to tensors in Tensorflow. -Maintainers: +LabeledTensor was inspired by [xarray](http://xarray.pydata.org) and +[pandas](http://pandas.pydata.org), projects that adds labels to NumPy array. + +## Data model + +`LabeledTensor` is an immutable object consisting of two components: + +- `tensor`: the `tf.Tensor` object containing the labeled tensor's data. +- `axes`: an OrderedDict-like object with keys given by axis names (e.g., + ``"channel"``) and values given by `Axis` objects. + +`Axis` objects keep track of the size of a dimension and, optionally, coordinate +labels along that axis (e.g., `("red", "green", "blue")`) in the form of a +tuple stored in `Axis.labels`. + +Operations on `LabeledTensors` use, preserve and transform axis names and +labels. + +## Quick start + +Try out the following snippet in a script or Jupyter notebook: + + import tensorflow as tf + + lt = tf.contrib.labeled_tensor + + # Create two LabeledTensors: + raw_image = tf.ones((299, 299, 3)) + axes = ['row', 'column', ('channel', ['red', 'green', 'blue'])] + image = lt.LabeledTensor(raw_image, axes) + assert image.tensor is raw_image + weights = lt.LabeledTensor(tf.constant([0.1, 0.3, 0.6]), + [image.axes['channel']]) + + # Examples of valid operations: + lt.transpose(image, ['column', 'row', 'channel']) + lt.reshape(image, ['row', 'column'], ['pixel']) + lt.concat([image, image], 'row') + lt.reduce_sum(image, ['channel']) + lt.select(image, {'channel': 'red'}) + lt.cast(image / 256.0, tf.uint8) + image * weights + lt.matmul(image[0, :, :], weights) + tf.cos(image) # automatically converts to tf.Tensor + +## Adding a custom op + +LabeledTensor has wrappers for [quite a +few](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/labeled_tensor/__init__.py) +TensorFlow ops. + +To easily add your own, you can use the `define_unary_op`, `define_binary_op` +and `define_reduce_op` functions, e.g., + + log = lt.define_unary_op('log', tf.log) + +## Questions + +Please reach out to the authors: + - Stephan Hoyer (shoyer@google.com, github.com/shoyer) - Eric Christiansen (ericmc@google.com, github.com/emchristiansen) diff --git a/tensorflow/contrib/layers/BUILD b/tensorflow/contrib/layers/BUILD index bbb4fb1f57b54848e538d0cd1fad90ce0b6feab0..2f1f283811b6cb9e8bfb52ab2052afac1de700cb 100644 --- a/tensorflow/contrib/layers/BUILD +++ b/tensorflow/contrib/layers/BUILD @@ -88,17 +88,21 @@ tf_custom_op_py_library( "//tensorflow/python:clip_ops", "//tensorflow/python:common_shapes", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:data_flow_ops", "//tensorflow/python:embedding_ops", "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:init_ops", "//tensorflow/python:layers", + "//tensorflow/python:layers_base", + "//tensorflow/python:linalg_ops", "//tensorflow/python:math_ops", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", "//tensorflow/python:random_ops", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:standard_ops", @@ -109,6 +113,7 @@ tf_custom_op_py_library( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/eager:context", "//tensorflow/python/feature_column", "@six_archive//:six", ], @@ -153,10 +158,10 @@ py_test( deps = [ ":layers_py", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", + "//tensorflow/python:session", "//third_party/py/numpy", ], ) @@ -168,9 +173,9 @@ py_test( srcs_version = "PY2AND3", deps = [ ":layers_py", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:session", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//third_party/py/numpy", @@ -238,6 +243,7 @@ py_test( ":layers_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:lookup_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", @@ -280,9 +286,9 @@ py_test( srcs_version = "PY2AND3", deps = [ ":layers_py", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:session", "//tensorflow/python:variables", ], ) @@ -294,9 +300,9 @@ py_test( srcs_version = "PY2AND3", deps = [ ":layers_py", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:session", "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", "//third_party/py/numpy", diff --git a/tensorflow/contrib/layers/__init__.py b/tensorflow/contrib/layers/__init__.py index d8ab7c2d70d8a7346c04d326f3a51b40a4f900ea..d309ba958ded86afdc1e4bba2ff471a5181cda4e 100644 --- a/tensorflow/contrib/layers/__init__.py +++ b/tensorflow/contrib/layers/__init__.py @@ -47,6 +47,7 @@ See the @{$python/contrib.layers} guide. @@separable_conv2d @@separable_convolution2d @@softmax +@@spatial_softmax @@stack @@unit_norm @@bow_encoder diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py index a5da0289f497b580a86ee3ecf959b86c866c269a..ad4a0b302fb5f65a50359fe3211f61211b2be63e 100644 --- a/tensorflow/contrib/layers/python/layers/layers.py +++ b/tensorflow/contrib/layers/python/layers/layers.py @@ -198,23 +198,23 @@ def avg_pool3d(inputs, return utils.collect_named_outputs(outputs_collections, sc, outputs) -def _fused_batch_norm( - inputs, - decay=0.999, - center=True, - scale=False, - epsilon=0.001, - activation_fn=None, - param_initializers=None, - updates_collections=ops.GraphKeys.UPDATE_OPS, - is_training=True, - reuse=None, - variables_collections=None, - outputs_collections=None, - trainable=True, - data_format=DATA_FORMAT_NHWC, - zero_debias_moving_mean=False, - scope=None): +def _fused_batch_norm(inputs, + decay=0.999, + center=True, + scale=False, + epsilon=0.001, + activation_fn=None, + param_initializers=None, + param_regularizers=None, + updates_collections=ops.GraphKeys.UPDATE_OPS, + is_training=True, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + data_format=DATA_FORMAT_NHWC, + zero_debias_moving_mean=False, + scope=None): """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167. "Batch Normalization: Accelerating Deep Network Training by Reducing @@ -257,6 +257,7 @@ def _fused_batch_norm( maintain a linear activation. param_initializers: Optional initializers for beta, gamma, moving mean and moving variance. + param_regularizers: Optional regularizer for beta and gamma. updates_collections: Collections to collect the update ops for computation. The updates_ops need to be executed with the train_op. If None, a control dependency would be added to make sure the updates are @@ -285,6 +286,7 @@ def _fused_batch_norm( ValueError: If the rank of `inputs` is neither 2 or 4. ValueError: If rank or `C` dimension of `inputs` is undefined. """ + # TODO(reedwm): Add support for fp16 inputs. if data_format not in (DATA_FORMAT_NCHW, DATA_FORMAT_NHWC): raise ValueError('data_format has to be either NCHW or NHWC.') with variable_scope.variable_scope( @@ -323,6 +325,11 @@ def _fused_batch_norm( 'beta') if not param_initializers: param_initializers = {} + if not param_regularizers: + param_regularizers = {} + beta_regularizer = param_regularizers.get('beta') + gamma_regularizer = param_regularizers.get('gamma') + if center: beta_initializer = param_initializers.get('beta', init_ops.zeros_initializer()) @@ -331,6 +338,7 @@ def _fused_batch_norm( shape=params_shape, dtype=dtype, initializer=beta_initializer, + regularizer=beta_regularizer, collections=beta_collections, trainable=trainable_beta) else: @@ -346,6 +354,7 @@ def _fused_batch_norm( shape=params_shape, dtype=dtype, initializer=gamma_initializer, + regularizer=gamma_regularizer, collections=gamma_collections, trainable=trainable) else: @@ -462,7 +471,8 @@ def batch_norm(inputs, scope=None, renorm=False, renorm_clipping=None, - renorm_decay=0.99): + renorm_decay=0.99, + adjustment=None): """Adds a Batch Normalization layer from http://arxiv.org/abs/1502.03167. "Batch Normalization: Accelerating Deep Network Training by Reducing @@ -545,6 +555,17 @@ def batch_norm(inputs, and should be neither too small (which would add noise) nor too large (which would give stale estimates). Note that `decay` is still applied to get the means and variances for inference. + adjustment: A function taking the `Tensor` containing the (dynamic) shape of + the input tensor and returning a pair (scale, bias) to apply to the + normalized values (before gamma and beta), only during training. For + example, + `adjustment = lambda shape: ( + tf.random_uniform(shape[-1:], 0.93, 1.07), + tf.random_uniform(shape[-1:], -0.1, 0.1))` + will scale the normalized value by up to 7% up or down, then shift the + result by up to 0.1 (with independent scaling and bias for each feature + but shared across all examples), and finally apply gamma and/or beta. If + `None`, no adjustment is applied. Returns: A `Tensor` representing the output of the operation. @@ -568,7 +589,10 @@ def batch_norm(inputs, # implementation in normalization_layers.BatchNormalization. inputs = ops.convert_to_tensor(inputs) rank = inputs.get_shape().ndims - possible_to_fuse = batch_weights is None and not renorm and rank in [2, 4] + possible_to_fuse = (batch_weights is None and + not renorm and + rank in [2, 4] and + adjustment is None) if fused and possible_to_fuse and ( zero_debias_moving_mean or rank == 2 or updates_collections is not ops.GraphKeys.UPDATE_OPS): @@ -580,6 +604,7 @@ def batch_norm(inputs, epsilon=epsilon, activation_fn=activation_fn, param_initializers=param_initializers, + param_regularizers=param_regularizers, updates_collections=updates_collections, is_training=is_training, reuse=reuse, @@ -635,6 +660,7 @@ def batch_norm(inputs, renorm=renorm, renorm_clipping=renorm_clipping, renorm_momentum=renorm_decay, + adjustment=adjustment, name=sc.name, _scope=sc, _reuse=reuse, @@ -1250,7 +1276,7 @@ def convolution2d_transpose( # Add variables to collections. _add_variable_to_collections(layer.kernel, variables_collections, 'weights') - if layer.bias: + if layer.bias is not None: _add_variable_to_collections(layer.bias, variables_collections, 'biases') if normalizer_fn is not None: @@ -1359,7 +1385,7 @@ def convolution3d_transpose( # Add variables to collections. _add_variable_to_collections(layer.kernel, variables_collections, 'weights') - if layer.bias: + if layer.bias is not None: _add_variable_to_collections(layer.bias, variables_collections, 'biases') if normalizer_fn is not None: @@ -1731,13 +1757,14 @@ class GDN(base.Layer): trainable=True, name=None, **kwargs): - super(GDN, self).__init__(trainable=trainable, name=name, **kwargs) + super(GDN, self).__init__(trainable=trainable, name=name, + activity_regularizer=activity_regularizer, + **kwargs) self.inverse = inverse self._beta_min = beta_min self._gamma_init = gamma_init self._reparam_offset = reparam_offset self.data_format = data_format - self.activity_regularizer = activity_regularizer self._channel_axis() # trigger ValueError early self.input_spec = base.InputSpec(min_ndim=3, max_ndim=5) @@ -1990,7 +2017,7 @@ def layer_norm(inputs, Given a tensor `inputs` of rank `R`, moments are calculated and normalization is performed over axes `begin_norm_axis ... R - 1`. Scaling and centering, - if requested, is performed over axes `begin_shift_axis .. R - 1`. + if requested, is performed over axes `begin_params_axis .. R - 1`. By default, `begin_norm_axis = 1` and `begin_params_axis = -1`, meaning that normalization is performed over all but the first axis @@ -2504,7 +2531,7 @@ def separable_convolution2d( variables_collections, 'weights') _add_variable_to_collections(layer.pointwise_kernel, variables_collections, 'weights') - if layer.bias: + if layer.bias is not None: _add_variable_to_collections(layer.bias, variables_collections, 'biases') diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py index 1040ad3ca7a4bbd56584f8e2cb8b2a2c8029d418..2837a3172da4758e33ebe5cf97da6d5bc0fb39b9 100644 --- a/tensorflow/contrib/layers/python/layers/layers_test.py +++ b/tensorflow/contrib/layers/python/layers/layers_test.py @@ -1784,29 +1784,41 @@ class BatchNormTest(test.TestCase): def testCreateOpFused(self): self._testCreateOp(True) - def testCreateOpBetaRegularizer(self): + def _testCreateOpBetaRegularizer(self, fused=True): height, width = 3, 3 with self.test_session(): reg = lambda x: 0.1 * math_ops.reduce_sum(x) images = np.random.uniform(size=(5, height, width, 3)).astype('f') - _layers.batch_norm(images, param_regularizers={'beta': reg}) + _layers.batch_norm(images, param_regularizers={'beta': reg}, fused=fused) self.assertEqual( len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 1) beta_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0] self.assertEqual(beta_decay.op.name, 'BatchNorm/beta/Regularizer/mul') - def testCreateOpGammaRegularizer(self): + def testCreateOpBetaRegularizerFused(self): + self._testCreateOpBetaRegularizer(fused=True) + + def testCreateOpBetaRegularizerNonFused(self): + self._testCreateOpBetaRegularizer(fused=False) + + def _testCreateOpGammaRegularizer(self, fused=True): height, width = 3, 3 with self.test_session(): reg = lambda x: 0.1 * math_ops.reduce_sum(x) images = np.random.uniform(size=(5, height, width, 3)).astype('f') _layers.batch_norm( - images, param_regularizers={'gamma': reg}, scale=True) + images, param_regularizers={'gamma': reg}, scale=True, fused=fused) self.assertEqual( len(ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)), 1) gamma_decay = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)[0] self.assertEqual(gamma_decay.op.name, 'BatchNorm/gamma/Regularizer/mul') + def testCreateOpGammaRegularizerFused(self): + self._testCreateOpGammaRegularizer(fused=True) + + def testCreateOpGammaRegularizerNonFused(self): + self._testCreateOpGammaRegularizer(fused=False) + def testCreateVariables(self): height, width = 3, 3 with self.test_session(): @@ -2644,6 +2656,26 @@ class BatchNormTest(test.TestCase): zero_debias_moving_mean=True) sess.run(variables_lib.global_variables_initializer()) + def testAdjustmentCreated(self): + # Tests that the adjustment is appropriately passed to and used by the core + # BN layer. + all_adjustments = [] + def _create_adjustment(shape): + adjustments = [array_ops.ones(shape[-1:]), array_ops.zeros(shape[-1:])] + all_adjustments.extend(adjustments) + return adjustments + depth = 8 + images = array_ops.zeros([10, 5, 5, depth]) + output = _layers.batch_norm( + images, + is_training=True, + adjustment=_create_adjustment) + self.assertListEqual(output.shape.as_list(), images.shape.as_list()) + self.assertEqual(len(all_adjustments), 2) + self.assertListEqual(all_adjustments[0].shape.as_list(), [depth]) + self.assertListEqual(all_adjustments[1].shape.as_list(), [depth]) + + class LayerNormTest(test.TestCase): def testUnknownShape(self): diff --git a/tensorflow/contrib/layers/python/layers/optimizers.py b/tensorflow/contrib/layers/python/layers/optimizers.py index 33db93b9704eb3c81d042e2636f916d5f685ad97..cdceea6fee5bdb5aeb6537ea55d25ccf107def4c 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers.py +++ b/tensorflow/contrib/layers/python/layers/optimizers.py @@ -41,7 +41,7 @@ OPTIMIZER_CLS_NAMES = { "Adagrad": train.AdagradOptimizer, "Adam": train.AdamOptimizer, "Ftrl": train.FtrlOptimizer, - "Momentum": train.MomentumOptimizer, + "Momentum": lambda lr: train.MomentumOptimizer(lr, momentum=0.9), "RMSProp": train.RMSPropOptimizer, "SGD": train.GradientDescentOptimizer, } diff --git a/tensorflow/contrib/layers/python/layers/optimizers_test.py b/tensorflow/contrib/layers/python/layers/optimizers_test.py index 8813a99f1994ade17cca3b1371a17278e434cef9..1ea25bd1a5685eb6f840e621b5739029a660aa0f 100644 --- a/tensorflow/contrib/layers/python/layers/optimizers_test.py +++ b/tensorflow/contrib/layers/python/layers/optimizers_test.py @@ -176,7 +176,7 @@ class OptimizersTest(test.TestCase): session.run(train, feed_dict={x: 5}) var_value, global_step_value = session.run([var, global_step]) # Due to randomness the following number may change if graph is different. - self.assertAlmostEqual(var_value, 8.5591021, 4) + self.assertAlmostEqual(var_value, 9.86912, 4) self.assertEqual(global_step_value, 1) def testGradientNoiseWithClipping(self): @@ -193,7 +193,7 @@ class OptimizersTest(test.TestCase): variables.global_variables_initializer().run() session.run(train, feed_dict={x: 5}) var_value, global_step_value = session.run([var, global_step]) - self.assertAlmostEqual(var_value, 9.0, 4) + self.assertAlmostEqual(var_value, 9.86912, 4) self.assertEqual(global_step_value, 1) def testGradientClip(self): diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index f3949beed04655456b3f0b550f5c757c85899270..2917a30a1770351a2315a8deb696d1841d260ff0 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -55,6 +55,7 @@ py_library( "//tensorflow/python:logging_ops", "//tensorflow/python:lookup_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:metrics", "//tensorflow/python:nn", "//tensorflow/python:parsing_ops", "//tensorflow/python:partitioned_variables", @@ -76,6 +77,7 @@ py_library( "//tensorflow/python:weights_broadcast_ops", "//tensorflow/python/estimator", "//tensorflow/python/estimator:estimator_py", + "//tensorflow/python/estimator:export_export", "//tensorflow/python/estimator:export_output", "//tensorflow/python/estimator:inputs", "//tensorflow/python/estimator:inputs_queues", @@ -85,6 +87,7 @@ py_library( "//tensorflow/python/estimator:run_config", "//tensorflow/python/feature_column", "//tensorflow/python/feature_column:feature_column_py", + "//tensorflow/python/ops/losses", "//tensorflow/python/saved_model:builder", "//tensorflow/python/saved_model:loader", "//tensorflow/python/saved_model:signature_constants", @@ -131,6 +134,7 @@ py_test( "//tensorflow/contrib/learn/python/learn/datasets", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:string_ops", "//tensorflow/python:training", "//tensorflow/python:variables", "//third_party/py/numpy", @@ -155,10 +159,11 @@ py_test( srcs_version = "PY2AND3", deps = [ ":learn", + "//tensorflow/contrib/layers:layers_py", "//tensorflow/core:protos_all_py", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:platform", + "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variables", @@ -198,6 +203,7 @@ py_test( "//tensorflow/contrib/training:training_py", "//tensorflow/python:client_testlib", "//tensorflow/python:platform", + "//tensorflow/python/estimator:run_config", ], ) @@ -216,6 +222,7 @@ py_test( "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:session", "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:training", @@ -278,6 +285,8 @@ py_test( "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", "//tensorflow/python:protos_all_py", + "//tensorflow/python:session", + "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variables", @@ -319,12 +328,12 @@ py_test( "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn/python/learn/datasets", - "//tensorflow/contrib/losses:losses_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", + "//tensorflow/python/ops/losses", "//third_party/py/numpy", ], ) @@ -363,10 +372,10 @@ py_test( deps = [ ":learn", "//tensorflow/core:protos_all_py", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:lookup_ops", + "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", "//tensorflow/python:variables", "//tensorflow/python/ops/losses", @@ -430,7 +439,6 @@ py_test( "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/rnn:rnn_py", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:functional_ops", @@ -439,6 +447,7 @@ py_test( "//tensorflow/python:random_ops", "//tensorflow/python:random_seed", "//tensorflow/python:rnn_cell", + "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", "//tensorflow/python:variables", "//third_party/py/numpy", @@ -575,10 +584,10 @@ py_test( srcs_version = "PY2AND3", deps = [ ":learn", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python/estimator:export_output", "//tensorflow/python/saved_model:signature_constants", @@ -631,9 +640,9 @@ py_test( srcs_version = "PY2AND3", deps = [ ":learn", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:session", "//third_party/py/numpy", ], ) @@ -721,6 +730,7 @@ py_test( "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", + "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:variables", ], @@ -768,13 +778,14 @@ py_test( ":learn", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/session_bundle:exporter", - "//tensorflow/contrib/session_bundle:manifest_proto_py", + "//tensorflow/contrib/session_bundle:manifest_proto_py_pb2", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", + "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", "//tensorflow/python:random_ops", + "//tensorflow/python:session", "//tensorflow/python:training", "//third_party/py/numpy", "@six_archive//:six", @@ -822,12 +833,9 @@ py_test( srcs_version = "PY2AND3", deps = [ ":learn", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", - "//tensorflow/python:platform", - "//tensorflow/python:util", + "//tensorflow/python:dtypes", ], ) @@ -855,7 +863,6 @@ py_binary( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/framework:framework_py", - "//tensorflow/python", # TODO(b/34059704): remove when fixed "//tensorflow/python:platform", ], ) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py index 71a82ccf56f8fe5171c915178868e7bb8da77022..12f9bba531a296a00d17956b8ce32e5d7dead380 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_test.py @@ -348,6 +348,12 @@ class DNNClassifierTest(test.TestCase): for prediction in predictions: self.assertIn(prediction, (0, 1)) + def _assertClassificationPredictions( + self, expected_len, n_classes, predictions): + self.assertEqual(expected_len, len(predictions)) + for prediction in predictions: + self.assertIn(prediction, range(n_classes)) + def _assertProbabilities(self, expected_batch_size, expected_n_classes, probabilities): self.assertEqual(expected_batch_size, len(probabilities)) @@ -732,7 +738,7 @@ class DNNClassifierTest(test.TestCase): self.assertIn('loss', scores) predicted_classes = classifier.predict_classes( input_fn=_input_fn, as_iterable=False) - self._assertBinaryPredictions(3, predicted_classes) + self._assertClassificationPredictions(3, n_classes, predicted_classes) predictions = classifier.predict(input_fn=_input_fn, as_iterable=False) self.assertAllEqual(predicted_classes, predictions) probabilities = classifier.predict_proba( @@ -765,13 +771,14 @@ class DNNClassifierTest(test.TestCase): feature_column.real_valued_column('age') ] + n_classes = 3 classifier = dnn.DNNClassifier( - n_classes=3, + n_classes=n_classes, feature_columns=feature_columns, hidden_units=[3, 3], config=run_config.RunConfig(tf_random_seed=1)) - classifier.fit(input_fn=_input_fn, steps=200) + classifier.fit(input_fn=_input_fn, steps=300) scores = classifier.evaluate(input_fn=_input_fn, steps=1) self._assertInRange(0.0, 1.0, scores['accuracy']) @@ -780,7 +787,7 @@ class DNNClassifierTest(test.TestCase): predicted_classes = list( classifier.predict_classes( input_fn=predict_input_fn, as_iterable=True)) - self.assertListEqual(predicted_classes, [1, 0, 0]) + self._assertClassificationPredictions(3, n_classes, predicted_classes) predictions = list( classifier.predict( input_fn=predict_input_fn, as_iterable=True)) @@ -788,8 +795,7 @@ class DNNClassifierTest(test.TestCase): predicted_proba = list( classifier.predict_proba( input_fn=predict_input_fn, as_iterable=True)) - self.assertAllClose( - predicted_proba, [[0., 1., 0.], [1., 0., 0.], [1., 0., 0.]], atol=0.3) + self._assertProbabilities(3, n_classes, predicted_proba) def testCustomMetrics(self): """Tests custom evaluation metrics.""" @@ -1214,6 +1220,12 @@ class DNNRegressorTest(test.TestCase): scores = regressor.evaluate(input_fn=_input_fn_eval, steps=1) self.assertIn('loss', scores) + def _assertRegressionOutputs( + self, predictions, expected_shape): + predictions_nparray = np.array(predictions) + self.assertAllEqual(expected_shape, predictions_nparray.shape) + self.assertTrue(np.issubdtype(predictions_nparray.dtype, np.float)) + def testPredict_AsIterableFalse(self): """Tests predict method with as_iterable=False.""" labels = [1., 0., 0.2] @@ -1252,7 +1264,7 @@ class DNNRegressorTest(test.TestCase): self.assertIn('loss', scores) predicted_scores = regressor.predict_scores( input_fn=_input_fn, as_iterable=False) - self.assertAllClose(labels, predicted_scores, atol=0.2) + self._assertRegressionOutputs(predicted_scores, [3]) predictions = regressor.predict(input_fn=_input_fn, as_iterable=False) self.assertAllClose(predicted_scores, predictions) @@ -1296,7 +1308,7 @@ class DNNRegressorTest(test.TestCase): predicted_scores = list( regressor.predict_scores( input_fn=predict_input_fn, as_iterable=True)) - self.assertAllClose(labels, predicted_scores, atol=0.2) + self._assertRegressionOutputs(predicted_scores, [3]) predictions = list( regressor.predict(input_fn=predict_input_fn, as_iterable=True)) self.assertAllClose(predicted_scores, predictions) diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py index 1724d7599d09873f969555cc9382c0753eba463f..69440e823ef1ed2d739f28bc14587891f2de80bb 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator.py @@ -639,7 +639,7 @@ class DynamicRnnEstimator(estimator.Estimator): ValueError: `problem_type` is not one of `ProblemType.LINEAR_REGRESSION` or `ProblemType.CLASSIFICATION`. ValueError: `problem_type` is `ProblemType.CLASSIFICATION` but - `num_classes` is not specifieProblemType + `num_classes` is not specified. ValueError: `prediction_type` is not one of `PredictionType.MULTIPLE_VALUE` or `PredictionType.SINGLE_VALUE`. """ diff --git a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py index d518e38fe0f8ea2fc5eb0b96828ed82703345a0d..c9a11f27f16d63362260b87afc44fee9d81e2efd 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dynamic_rnn_estimator_test.py @@ -700,18 +700,18 @@ class DynamicRNNEstimatorLearningTest(test.TestCase): 'Loss should be less than {}; got {}'.format(loss_threshold, loss)) - def testLearnMajority(self): + def DISABLED_testLearnMajority(self): """Test learning the 'majority' function.""" batch_size = 16 sequence_length = 7 - train_steps = 200 + train_steps = 500 eval_steps = 20 cell_type = 'lstm' cell_size = 4 optimizer_type = 'Momentum' learning_rate = 2.0 momentum = 0.9 - accuracy_threshold = 0.9 + accuracy_threshold = 0.6 def get_majority_input_fn(batch_size, sequence_length, seed=None): random_seed.set_random_seed(seed) diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 234d7318502f0c55a2be4f1256ceb340d905d276..788d2d0b1a58fad16712c968593b40de0d3979f0 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -981,6 +981,7 @@ class BaseEstimator( global_step = training_util.create_global_step(g) features, labels = input_fn() self._check_inputs(features, labels) + training_util._get_or_create_global_step_read() # pylint: disable=protected-access model_fn_ops = self._get_train_ops(features, labels) ops.add_to_collection(ops.GraphKeys.LOSSES, model_fn_ops.loss) all_hooks.extend(hooks) diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 719e5da21df57ed778fe6aee3fe57f3b202dfaa2..bc0e6fc0091c9b5419ab526855b404eb4a927e97 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -33,7 +33,6 @@ from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import logging_ops from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics as metrics_lib @@ -120,7 +119,7 @@ class Head(object): update_op = tf.contrib.layers.optimize_loss(optimizer=sync, loss=model_fn_ops.loss, ...) hooks = [sync.make_session_run_hook(is_chief)] - ... upate train_op and hooks in ModelFnOps and return + ... update train_op and hooks in ModelFnOps and return ``` """ __metaclass__ = abc.ABCMeta @@ -635,10 +634,11 @@ def _create_model_fn_ops(features, if (mode != model_fn.ModeKeys.INFER) and (labels is not None): weight_tensor = _weight_tensor(features, weight_column_name) loss, weighted_average_loss = loss_fn(labels, logits, weight_tensor) - # Uses the deprecated API to set the tag explicitly. - # Without it, training and eval losses will show up in different graphs. - logging_ops.scalar_summary( - _summary_key(head_name, mkey.LOSS), weighted_average_loss) + # The name_scope escapism is needed to maintain the same summary tag + # after switching away from the now unsupported API. + with ops.name_scope(""): + summary_loss = array_ops.identity(weighted_average_loss) + summary.scalar(_summary_key(head_name, mkey.LOSS), summary_loss) if mode == model_fn.ModeKeys.TRAIN: if train_op_fn is None: @@ -1484,8 +1484,12 @@ class _LossOnlyHead(Head): loss = self._loss_fn() if isinstance(loss, list): loss = math_ops.add_n(loss) - logging_ops.scalar_summary( - _summary_key(self.head_name, mkey.LOSS), loss) + # The name_scope escapism is needed to maintain the same summary tag + # after switching away from the now unsupported API. + with ops.name_scope(""): + summary_loss = array_ops.identity(loss) + summary.scalar(_summary_key(self.head_name, mkey.LOSS), + summary_loss) if mode == model_fn.ModeKeys.TRAIN: if train_op_fn is None: raise ValueError("train_op_fn can not be None in TRAIN mode") @@ -2029,13 +2033,13 @@ def _streaming_accuracy_at_threshold(predictions, labels, weights, threshold): def _streaming_precision_at_threshold(predictions, labels, weights, threshold): precision_tensor, update_op = metrics_lib.precision_at_thresholds( - labels, predictions, (threshold,),_float_weights_or_none(weights)) + labels, predictions, (threshold,), _float_weights_or_none(weights)) return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op) def _streaming_recall_at_threshold(predictions, labels, weights, threshold): precision_tensor, update_op = metrics_lib.recall_at_thresholds( - labels, predictions, (threshold,),_float_weights_or_none(weights)) + labels, predictions, (threshold,), _float_weights_or_none(weights)) return array_ops.squeeze(precision_tensor), array_ops.squeeze(update_op) diff --git a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py index a92302420f1c293b086942977360d624b9e06db6..992b804f59ecd88fedc2fba10d3079f93c4fe83d 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/kmeans.py +++ b/tensorflow/contrib/learn/python/learn/estimators/kmeans.py @@ -12,7 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Implementation of k-means clustering on top of `Estimator` API.""" +"""Implementation of k-means clustering on top of `Estimator` API. + +This module is deprecated. Please use +@{tf.contrib.factorization.KMeansClustering} instead of +@{tf.contrib.learn.KMeansClustering}. It has a similar interface, but uses the +@{tf.estimator.Estimator} API instead of @{tf.contrib.learn.Estimator}. +""" from __future__ import absolute_import from __future__ import division @@ -29,12 +35,17 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops -from tensorflow.python.summary import summary from tensorflow.python.ops.control_flow_ops import with_dependencies from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import session_run_hook from tensorflow.python.training.session_run_hook import SessionRunArgs +from tensorflow.python.util.deprecation import deprecated + +_USE_TF_CONTRIB_FACTORIZATION = ( + 'Please use tf.contrib.factorization.KMeansClustering instead of' + ' tf.contrib.learn.KMeansClustering. It has a similar interface, but uses' + ' the tf.estimator.Estimator API instead of tf.contrib.learn.Estimator.') class _LossRelativeChangeHook(session_run_hook.SessionRunHook): @@ -153,6 +164,7 @@ class KMeansClustering(estimator.Estimator): ALL_SCORES = 'all_scores' LOSS_OP_NAME = 'kmeans_loss' + @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) def __init__(self, num_clusters, model_dir=None, @@ -204,6 +216,7 @@ class KMeansClustering(estimator.Estimator): model_dir=model_dir, config=config) + @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) def predict_cluster_idx(self, input_fn=None): """Yields predicted cluster indices.""" key = KMeansClustering.CLUSTER_IDX @@ -212,6 +225,7 @@ class KMeansClustering(estimator.Estimator): for result in results: yield result[key] + @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) def score(self, input_fn=None, steps=None): """Predict total sum of distances to nearest clusters. @@ -229,6 +243,7 @@ class KMeansClustering(estimator.Estimator): self.evaluate( input_fn=input_fn, steps=steps)[KMeansClustering.SCORES]) + @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) def transform(self, input_fn=None, as_iterable=False): """Transforms each element to distances to cluster centers. @@ -255,6 +270,7 @@ class KMeansClustering(estimator.Estimator): else: return results + @deprecated(None, _USE_TF_CONTRIB_FACTORIZATION) def clusters(self): """Returns cluster centers.""" return super(KMeansClustering, self).get_variable_value(self.CLUSTERS) diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py index 8be9c72adf1602826fabc650f350b57f72c886be..44e6c7c52dac524a22e9099e33e2aef82f8fe7ba 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py @@ -23,7 +23,6 @@ import collections import six -from tensorflow.contrib import framework as contrib_framework from tensorflow.contrib.framework import get_graph_from_inputs from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import metric_key @@ -32,6 +31,7 @@ from tensorflow.python.estimator import model_fn as core_model_fn_lib from tensorflow.python.estimator.export import export_output as core_export_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging @@ -156,11 +156,11 @@ class ModelFnOps( else: if isinstance(predictions, dict): predictions = { - k: contrib_framework.convert_to_tensor_or_sparse_tensor(v) + k: sparse_tensor.convert_to_tensor_or_sparse_tensor(v) for k, v in six.iteritems(predictions) } else: - predictions = contrib_framework.convert_to_tensor_or_sparse_tensor( + predictions = sparse_tensor.convert_to_tensor_or_sparse_tensor( predictions) # Validate eval_metric_ops diff --git a/tensorflow/contrib/learn/python/learn/experiment.py b/tensorflow/contrib/learn/python/learn/experiment.py index 9b55826e627c5198ba7f88505afb866a0f308553..307db76afe20a7743df16d169270a6f319497eb6 100644 --- a/tensorflow/contrib/learn/python/learn/experiment.py +++ b/tensorflow/contrib/learn/python/learn/experiment.py @@ -149,16 +149,16 @@ class Experiment(object): Args: estimator: Object implementing Estimator interface, which could be a - combination of ${tf.contrib.learn.Trainable} and - ${tf.contrib.learn.Evaluable} (deprecated), or - ${tf.estimator.`Estimator}. + combination of @{tf.contrib.learn.Trainable} and + @{tf.contrib.learn.Evaluable} (deprecated), or + @{tf.estimator.Estimator}. train_input_fn: function, returns features and labels for training. eval_input_fn: function, returns features and labels for evaluation. If `eval_steps` is `None`, this should be configured only to produce for a finite number of batches (generally, 1 epoch over the evaluation data). eval_metrics: `dict` of string, metric function. If `None`, default set is used. This should be `None` if the `estimator` is - ${tf.estimator.Estimator}. If metrics are provided they will be + @{tf.estimator.Estimator}. If metrics are provided they will be *appended* to the default set. train_steps: Perform this many steps of training. `None`, the default, means train forever. diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py index 4c50d40aaa9b3c5d94d0a66d08e8ab6173db427a..db18ebf05d5fb98e28e767be7bcccdf992a56fd8 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder.py @@ -28,13 +28,14 @@ import six from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.platform import tf_logging as logging # pylint: disable=g-multiple-import,g-bad-import-order from .pandas_io import HAS_PANDAS, extract_pandas_data, extract_pandas_matrix, extract_pandas_labels from .dask_io import HAS_DASK, extract_dask_data, extract_dask_labels - # pylint: enable=g-multiple-import,g-bad-import-order @@ -365,8 +366,13 @@ class DataFeeder(object): self.random_state = np.random.RandomState( 42) if random_state is None else random_state - num_samples = list(self._x.values())[0].shape[ - 0] if x_is_dict else self._x.shape[0] + if x_is_dict: + num_samples = list(self._x.values())[0].shape[0] + elif tensor_util.is_tensor(self._x): + num_samples = self._x.shape[0].value # shape will be a Dimension, extract an int + else: + num_samples = self._x.shape[0] + if self._shuffle: self.indices = self.random_state.permutation(num_samples) else: diff --git a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py index eaf6ae4ed72148436c3d1aa3838b516c6025b0aa..82848be7df653dd60219317d28f233767746f544 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py @@ -42,16 +42,6 @@ class DataFeederTest(test.TestCase): with self.assertRaisesRegexp(TypeError, 'annot convert'): data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1) - def test_input_uint32(self): - data = np.matrix([[1, 2], [3, 4]], dtype=np.uint32) - self._assert_raises(data) - self._assert_raises(self._wrap_dict(data)) - - def test_input_uint64(self): - data = np.matrix([[1, 2], [3, 4]], dtype=np.uint64) - self._assert_raises(data) - self._assert_raises(self._wrap_dict(data)) - def _assert_dtype(self, expected_np_dtype, expected_tf_dtype, input_data): feeder = data_feeder.DataFeeder(input_data, None, n_classes=0, batch_size=1) if isinstance(input_data, dict): @@ -87,6 +77,16 @@ class DataFeederTest(test.TestCase): self._assert_dtype(np.int64, dtypes.int64, data) self._assert_dtype(np.int64, dtypes.int64, self._wrap_dict(data)) + def test_input_uint32(self): + data = np.matrix([[1, 2], [3, 4]], dtype=np.uint32) + self._assert_dtype(np.uint32, dtypes.uint32, data) + self._assert_dtype(np.uint32, dtypes.uint32, self._wrap_dict(data)) + + def test_input_uint64(self): + data = np.matrix([[1, 2], [3, 4]], dtype=np.uint64) + self._assert_dtype(np.uint64, dtypes.uint64, data) + self._assert_dtype(np.uint64, dtypes.uint64, self._wrap_dict(data)) + def test_input_uint8(self): data = np.matrix([[1, 2], [3, 4]], dtype=np.uint8) self._assert_dtype(np.uint8, dtypes.uint8, data) diff --git a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py index bdb88b89bb3dba95a229724994874b0a26b1fc3f..4b34fc62849766370979bb2002d42ee03ea7161a 100644 --- a/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py +++ b/tensorflow/contrib/learn/python/learn/learn_io/graph_io.py @@ -442,7 +442,8 @@ def read_keyed_batch_features(file_pattern, feature_queue_capacity=100, num_enqueue_threads=2, parse_fn=None, - name=None): + name=None, + read_batch_size=None): """Adds operations to read, queue, batch and parse `Example` protos. Given file pattern (or list of files), will setup a queue for file names, @@ -482,6 +483,8 @@ def read_keyed_batch_features(file_pattern, parse_fn: Parsing function, takes `Example` Tensor returns parsed representation. If `None`, no parsing is done. name: Name of resulting op. + read_batch_size: An int or scalar `Tensor` specifying the number of + records to read at once. If `None`, defaults to `batch_size`. Returns: Returns tuple of: @@ -493,6 +496,7 @@ def read_keyed_batch_features(file_pattern, """ with ops.name_scope(name, 'read_batch_features', [file_pattern]) as scope: + if read_batch_size is None: read_batch_size = batch_size keys, examples = read_keyed_batch_examples( file_pattern, batch_size, @@ -501,7 +505,7 @@ def read_keyed_batch_features(file_pattern, num_epochs=num_epochs, queue_capacity=queue_capacity, num_threads=reader_num_threads, - read_batch_size=batch_size, + read_batch_size=read_batch_size, parse_fn=parse_fn, name=scope) # Parse the example. @@ -727,7 +731,8 @@ def read_batch_features(file_pattern, reader_num_threads=1, num_enqueue_threads=2, parse_fn=None, - name=None): + name=None, + read_batch_size=None): """Adds operations to read, queue, batch and parse `Example` protos. Given file pattern (or list of files), will setup a queue for file names, @@ -768,6 +773,8 @@ def read_batch_features(file_pattern, parse_fn: Parsing function, takes `Example` Tensor returns parsed representation. If `None`, no parsing is done. name: Name of resulting op. + read_batch_size: An int or scalar `Tensor` specifying the number of + records to read at once. If `None`, defaults to `batch_size`. Returns: A dict of `Tensor` or `SparseTensor` objects for each in `features`. @@ -786,6 +793,7 @@ def read_batch_features(file_pattern, reader_num_threads=reader_num_threads, feature_queue_capacity=feature_queue_capacity, num_enqueue_threads=num_enqueue_threads, + read_batch_size=read_batch_size, parse_fn=parse_fn, name=name) return features diff --git a/tensorflow/contrib/learn/python/learn/learn_runner.py b/tensorflow/contrib/learn/python/learn/learn_runner.py index 9f9740ec492e8b71191aff17f70a007409525ccd..2af723a0d64822e81fa0fbeb106ab812de6ab4e8 100644 --- a/tensorflow/contrib/learn/python/learn/learn_runner.py +++ b/tensorflow/contrib/learn/python/learn/learn_runner.py @@ -165,7 +165,7 @@ def run(experiment_fn, output_dir=None, schedule=None, run_config=None, must be None. 2) It accepts two arguments `run_config` and `hparams`, which should be used to create the `Estimator` (`run_config` passed as `config` to its - constructor; `hparams` used as the hyper-paremeters of the model). + constructor; `hparams` used as the hyper-parameters of the model). It must return an `Experiment`. For this case, `output_dir` must be None. output_dir: Base output directory [Deprecated]. schedule: The name of the method in the `Experiment` to run. diff --git a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py index a60d548391a58533d1121a47d2e16a646b60df82..b2521933e524e7ec24d73d4b5171f33e507dd88c 100644 --- a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py @@ -107,9 +107,8 @@ def build_default_serving_input_fn(features, default_batch_size=None): shape_list[0] = default_batch_size shape = tensor_shape.TensorShape(shape_list) - features_placeholders[name] = array_ops.placeholder(dtype=t.dtype, - shape=shape, - name=t.op.name) + features_placeholders[name] = array_ops.placeholder( + dtype=t.dtype, shape=shape, name=t.op.name) labels = None # these are not known in serving! return InputFnOps(features_placeholders, labels, features_placeholders) return input_fn diff --git a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils_test.py index 72deeb57a461083fcb7b79854a94161cade884ea..e9dc6a687594ac014f970e44fc82c090a0a5a4ee 100644 --- a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils_test.py @@ -22,10 +22,11 @@ from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.platform import test + class InputFnTest(test.TestCase): def test_build_default_serving_input_fn_name(self): - """Test case for #12755""" + """Test case for issue #12755.""" f = { 'feature': array_ops.placeholder( @@ -35,5 +36,6 @@ class InputFnTest(test.TestCase): v = serving_input() self.assertTrue(isinstance(v, input_fn_utils.InputFnOps)) -if __name__ == "__main__": + +if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py index 676e1f2b51c0a0a48b84f4e1d3d8ad9ae2521f9b..49413092a6bae547ddd2cad272b1abb3af1de046 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -50,6 +50,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import signature_constants from tensorflow.python.saved_model import signature_def_utils +from tensorflow.python.training import saver from tensorflow.python.util import compat @@ -108,7 +109,11 @@ def build_standardized_signature_def(input_tensors, output_tensors, classes = _get_classification_classes(output_tensors) scores = _get_classification_scores(output_tensors) if classes is None and scores is None: - (_, classes), = output_tensors.items() + items = list(output_tensors.items()) + if items[0][1].dtype == dtypes.string: + (_, classes), = items + else: + (_, scores), = items return signature_def_utils.classification_signature_def( examples, classes, scores) elif _is_regression_problem(problem_type, input_tensors, output_tensors): @@ -616,7 +621,13 @@ def make_best_model_export_strategy(serving_input_fn, Returns: The string path to the exported directory. """ - + if not checkpoint_path: + # TODO(b/67425018): switch to + # checkpoint_path = estimator.latest_checkpoint() + # as soon as contrib is cleaned up and we can thus be sure that + # estimator is a tf.estimator.Estimator and not a + # tf.contrib.learn.Estimator + checkpoint_path = saver.latest_checkpoint(estimator.model_dir) export_checkpoint_path, export_eval_result = best_model_selector.update( checkpoint_path, eval_result) @@ -629,3 +640,47 @@ def make_best_model_export_strategy(serving_input_fn, return '' return export_strategy.ExportStrategy('best_model', export_fn) + + +# TODO(b/67013778): Revisit this approach when corresponding changes to +# TF Core are finalized. +def extend_export_strategy(base_export_strategy, post_export_fn, + post_export_name): + """Extend ExportStrategy, calling post_export_fn after export. + + Args: + base_export_strategy: An ExportStrategy that can be passed to the Experiment + constructor. + post_export_fn: A user-specified function to call after exporting the + SavedModel. Takes the export directory as an argument, and returns + a string path to a (potentially different) SavedModel. + post_export_name: The directory name under the export base directory where + SavedModels generated by the post_export_fn will be written. + + Returns: + An ExportStrategy that can be passed to the Experiment constructor. + """ + def export_fn(estimator, export_dir_base, checkpoint_path=None): + """Exports the given Estimator as a SavedModel and invokes post_export_fn. + + Args: + estimator: the Estimator to export. + export_dir_base: A string containing a directory to write the exported + graphs and checkpoint. + checkpoint_path: The checkpoint path to export. If None (the default), + the most recent checkpoint found within the model directory is chosen. + + Returns: + The string path to the SavedModel indicated by post_export_fn. + + Raises: + ValueError: If `estimator` is a ${tf.estimator.Estimator} instance + and `default_output_alternative_key` was specified. + """ + export_dir = base_export_strategy.export(estimator, export_dir_base, + checkpoint_path) + if post_export_fn: + export_dir = post_export_fn(export_dir) + return export_dir + + return export_strategy.ExportStrategy(post_export_name, export_fn) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py index 66bca9c0f533dc97c682caf2befd33197eb0a733..27f17b54221ea442baafb382aa3fb034d1bb82e6 100644 --- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py @@ -73,7 +73,7 @@ class SavedModelExportUtilsTest(test.TestCase): def test_build_standardized_signature_def_regression(self): input_tensors = { "input-1": - array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1") + array_ops.placeholder(dtypes.string, 1, name="input-tensor-1") } output_tensors = { "output-1": @@ -86,14 +86,16 @@ class SavedModelExportUtilsTest(test.TestCase): expected_signature_def = meta_graph_pb2.SignatureDef() shape = tensor_shape_pb2.TensorShapeProto( dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) - dtype = types_pb2.DataType.Value("DT_FLOAT") + dtype_float = types_pb2.DataType.Value("DT_FLOAT") + dtype_string = types_pb2.DataType.Value("DT_STRING") expected_signature_def.inputs[signature_constants.REGRESS_INPUTS].CopyFrom( meta_graph_pb2.TensorInfo( - name="input-tensor-1:0", dtype=dtype, tensor_shape=shape)) + name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape)) expected_signature_def.outputs[ signature_constants.REGRESS_OUTPUTS].CopyFrom( - meta_graph_pb2.TensorInfo( - name="output-tensor-1:0", dtype=dtype, tensor_shape=shape)) + meta_graph_pb2.TensorInfo(name="output-tensor-1:0", + dtype=dtype_float, + tensor_shape=shape)) expected_signature_def.method_name = signature_constants.REGRESS_METHOD_NAME self.assertEqual(actual_signature_def, expected_signature_def) @@ -102,7 +104,7 @@ class SavedModelExportUtilsTest(test.TestCase): """Tests classification with one output tensor.""" input_tensors = { "input-1": - array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1") + array_ops.placeholder(dtypes.string, 1, name="input-tensor-1") } output_tensors = { "output-1": @@ -115,11 +117,10 @@ class SavedModelExportUtilsTest(test.TestCase): expected_signature_def = meta_graph_pb2.SignatureDef() shape = tensor_shape_pb2.TensorShapeProto( dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) - dtype_float = types_pb2.DataType.Value("DT_FLOAT") dtype_string = types_pb2.DataType.Value("DT_STRING") expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom( meta_graph_pb2.TensorInfo( - name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) + name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape)) expected_signature_def.outputs[ signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom( meta_graph_pb2.TensorInfo( @@ -135,7 +136,7 @@ class SavedModelExportUtilsTest(test.TestCase): """Tests multiple output tensors that include classes and probabilities.""" input_tensors = { "input-1": - array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1") + array_ops.placeholder(dtypes.string, 1, name="input-tensor-1") } output_tensors = { "classes": @@ -160,7 +161,7 @@ class SavedModelExportUtilsTest(test.TestCase): dtype_string = types_pb2.DataType.Value("DT_STRING") expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom( meta_graph_pb2.TensorInfo( - name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) + name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape)) expected_signature_def.outputs[ signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom( meta_graph_pb2.TensorInfo( @@ -182,7 +183,7 @@ class SavedModelExportUtilsTest(test.TestCase): """Tests multiple output tensors that include classes and scores.""" input_tensors = { "input-1": - array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1") + array_ops.placeholder(dtypes.string, 1, name="input-tensor-1") } output_tensors = { "classes": @@ -206,7 +207,7 @@ class SavedModelExportUtilsTest(test.TestCase): dtype_string = types_pb2.DataType.Value("DT_STRING") expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom( meta_graph_pb2.TensorInfo( - name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) + name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape)) expected_signature_def.outputs[ signature_constants.CLASSIFY_OUTPUT_CLASSES].CopyFrom( meta_graph_pb2.TensorInfo( @@ -228,7 +229,7 @@ class SavedModelExportUtilsTest(test.TestCase): """Tests classification without classes tensor.""" input_tensors = { "input-1": - array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1") + array_ops.placeholder(dtypes.string, 1, name="input-tensor-1") } output_tensors = { "probabilities": @@ -246,9 +247,10 @@ class SavedModelExportUtilsTest(test.TestCase): shape = tensor_shape_pb2.TensorShapeProto( dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) dtype_float = types_pb2.DataType.Value("DT_FLOAT") + dtype_string = types_pb2.DataType.Value("DT_STRING") expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom( meta_graph_pb2.TensorInfo( - name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) + name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape)) expected_signature_def.outputs[ signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom( meta_graph_pb2.TensorInfo( @@ -268,7 +270,7 @@ class SavedModelExportUtilsTest(test.TestCase): """ input_tensors = { "input-1": - array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1") + array_ops.placeholder(dtypes.string, 1, name="input-tensor-1") } output_tensors = { "classes": @@ -289,9 +291,10 @@ class SavedModelExportUtilsTest(test.TestCase): shape = tensor_shape_pb2.TensorShapeProto( dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) dtype_float = types_pb2.DataType.Value("DT_FLOAT") + dtype_string = types_pb2.DataType.Value("DT_STRING") expected_signature_def.inputs[signature_constants.CLASSIFY_INPUTS].CopyFrom( meta_graph_pb2.TensorInfo( - name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) + name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape)) expected_signature_def.outputs[ signature_constants.CLASSIFY_OUTPUT_SCORES].CopyFrom( meta_graph_pb2.TensorInfo( @@ -311,7 +314,7 @@ class SavedModelExportUtilsTest(test.TestCase): """ input_tensors = { "input-1": - array_ops.placeholder(dtypes.float32, 1, name="input-tensor-1") + array_ops.placeholder(dtypes.string, 1, name="input-tensor-1") } output_tensors = { "classes": @@ -330,9 +333,10 @@ class SavedModelExportUtilsTest(test.TestCase): dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) dtype_int64 = types_pb2.DataType.Value("DT_INT64") dtype_float = types_pb2.DataType.Value("DT_FLOAT") + dtype_string = types_pb2.DataType.Value("DT_STRING") expected_signature_def.inputs["input-1"].CopyFrom( meta_graph_pb2.TensorInfo( - name="input-tensor-1:0", dtype=dtype_float, tensor_shape=shape)) + name="input-tensor-1:0", dtype=dtype_string, tensor_shape=shape)) expected_signature_def.outputs["classes"].CopyFrom( meta_graph_pb2.TensorInfo( name="output-tensor-classes:0", @@ -499,13 +503,13 @@ class SavedModelExportUtilsTest(test.TestCase): def test_build_all_signature_defs(self): input_features = constant_op.constant(["10"]) - input_example = constant_op.constant(["11"]) + input_example = constant_op.constant(["input string"]) input_ops = input_fn_utils.InputFnOps({ "features": input_features }, None, {"default input": input_example}) input_alternatives, _ = ( saved_model_export_utils.get_input_alternatives(input_ops)) - output_1 = constant_op.constant(["1"]) + output_1 = constant_op.constant([1.0]) output_2 = constant_op.constant(["2"]) output_3 = constant_op.constant(["3"]) provided_output_alternatives = { @@ -738,6 +742,26 @@ class SavedModelExportUtilsTest(test.TestCase): export_strategy.export(test_estimator, export_dir_base, "fake_ckpt_1", None) + def test_extend_export_strategy(self): + def _base_export_fn(unused_estimator, export_dir_base, + unused_checkpoint_path=None): + return export_dir_base + "/e1" + + def _post_export_fn(orig_path): + return orig_path + "/rewrite" + + base_export_strategy = export_strategy_lib.ExportStrategy( + "Servo", _base_export_fn) + + final_export_strategy = saved_model_export_utils.extend_export_strategy( + base_export_strategy, _post_export_fn, "Servo2") + self.assertEqual(final_export_strategy.name, "Servo2") + + test_estimator = TestEstimator() + final_path = final_export_strategy.export(test_estimator, "/path/to/orig", + "/path/to/checkpoint") + self.assertEqual("/path/to/orig/e1/rewrite", final_path) + def _create_test_export_dir(export_dir_base): export_dir = saved_model_export_utils.get_timestamped_export_dir( diff --git a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py index d4de638338689d2775efe6988af3a058bb128c07..5e7b422e3cc368a22eb94ed470297ae78293c4eb 100644 --- a/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py +++ b/tensorflow/contrib/legacy_seq2seq/python/ops/seq2seq.py @@ -76,7 +76,7 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.util import nest # TODO(ebrevdo): Remove once _linear is fully deprecated. -linear = rnn_cell_impl._linear # pylint: disable=protected-access +Linear = core_rnn_cell._Linear # pylint: disable=protected-access,invalid-name def _extract_argmax_and_embed(embedding, @@ -645,7 +645,7 @@ def attention_decoder(decoder_inputs, query = array_ops.concat(query_list, 1) for a in xrange(num_heads): with variable_scope.variable_scope("Attention_%d" % a): - y = linear(query, attention_vec_size, True) + y = Linear(query, attention_vec_size, True)(query) y = array_ops.reshape(y, [-1, 1, 1, attention_vec_size]) # Attention mask is a softmax of v^T * tanh(...). s = math_ops.reduce_sum(v[a] * math_ops.tanh(hidden_features[a] + y), @@ -679,7 +679,9 @@ def attention_decoder(decoder_inputs, input_size = inp.get_shape().with_rank(2)[1] if input_size.value is None: raise ValueError("Could not infer input size from input: %s" % inp.name) - x = linear([inp] + attns, input_size, True) + + inputs = [inp] + attns + x = Linear(inputs, input_size, True)(inputs) # Run the RNN. cell_output, state = cell(x, state) # Run the attention mechanism. @@ -691,7 +693,8 @@ def attention_decoder(decoder_inputs, attns = attention(state) with variable_scope.variable_scope("AttnOutputProjection"): - output = linear([cell_output] + attns, output_size, True) + inputs = [cell_output] + attns + output = Linear(inputs, output_size, True)(inputs) if loop_function is not None: prev = output outputs.append(output) diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD index 810a3d34eee0a886fcf49ca3209547c9307a6e67..208e7bc69be76680868c766bc99429eea5870c80 100644 --- a/tensorflow/contrib/linalg/BUILD +++ b/tensorflow/contrib/linalg/BUILD @@ -10,25 +10,23 @@ exports_files(["LICENSE"]) package(default_visibility = ["//tensorflow:__subpackages__"]) -load("//tensorflow:tensorflow.bzl", "cuda_py_tests") +load("//tensorflow:tensorflow.bzl", "cuda_py_test") -cuda_py_tests( - name = "linear_operator_test", - size = "small", - srcs = ["python/kernel_tests/linear_operator_test.py"], - additional_deps = [ - ":linalg_py", - "//third_party/py/numpy", +py_library( + name = "linalg_py", + srcs = ["__init__.py"] + glob(["python/ops/*.py"]), + srcs_version = "PY2AND3", + deps = [ "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", + "//tensorflow/python:check_ops", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", + "//tensorflow/python:util", + "//tensorflow/python/ops/linalg", + "@six_archive//:six", ], ) -cuda_py_tests( +cuda_py_test( name = "linear_operator_addition_test", size = "small", srcs = ["python/kernel_tests/linear_operator_addition_test.py"], @@ -45,142 +43,6 @@ cuda_py_tests( ], ) -cuda_py_tests( - name = "linear_operator_composition_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_composition_test.py"], - additional_deps = [ - ":linalg_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], - tags = ["noasan"], # times out b/63678675 -) - -cuda_py_tests( - name = "linear_operator_diag_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_diag_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - ], -) - -cuda_py_tests( - name = "linear_operator_identity_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_identity_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:platform_test", - "//tensorflow/python:random_ops", - ], -) - -cuda_py_tests( - name = "linear_operator_full_matrix_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_full_matrix_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_tests( - name = "linear_operator_tril_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_tril_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], -) - -cuda_py_tests( - name = "linear_operator_udvh_update_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_udvh_update_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:platform_test", - ], - shard_count = 5, -) - -cuda_py_tests( - name = "linear_operator_util_test", - size = "medium", - srcs = ["python/kernel_tests/linear_operator_util_test.py"], - additional_deps = [ - ":linalg_py", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) - -py_library( - name = "linalg_py", - srcs = ["__init__.py"] + glob(["python/ops/*.py"]), - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:common_shapes", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:linalg_ops", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform", - "//tensorflow/python:random_ops", - "//tensorflow/python:random_seed", - "//tensorflow/python:tensor_util", - "//tensorflow/python:util", - "//third_party/py/numpy", - "@six_archive//:six", - ], -) - filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py index 44421a6b7de0344a9a4a172ddc0900a44eb74450..4720692c3384ba1bede1f486c1b1e0e69d10a63a 100644 --- a/tensorflow/contrib/linalg/__init__.py +++ b/tensorflow/contrib/linalg/__init__.py @@ -21,8 +21,8 @@ See the @{$python/contrib.linalg} guide. @@LinearOperatorIdentity @@LinearOperatorScaledIdentity @@LinearOperatorFullMatrix -@@LinearOperatorTriL -@@LinearOperatorUDVHUpdate +@@LinearOperatorLowerTriangular +@@LinearOperatorLowRankUpdate @@LinearOperatorComposition @@add_operators @@ -33,14 +33,14 @@ from __future__ import print_function # pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member -from tensorflow.contrib.linalg.python.ops.linear_operator import * from tensorflow.contrib.linalg.python.ops.linear_operator_addition import * -from tensorflow.contrib.linalg.python.ops.linear_operator_composition import * -from tensorflow.contrib.linalg.python.ops.linear_operator_diag import * -from tensorflow.contrib.linalg.python.ops.linear_operator_full_matrix import * -from tensorflow.contrib.linalg.python.ops.linear_operator_identity import * -from tensorflow.contrib.linalg.python.ops.linear_operator_tril import * -from tensorflow.contrib.linalg.python.ops.linear_operator_udvh_update import * +from tensorflow.python.ops.linalg.linear_operator import * +from tensorflow.python.ops.linalg.linear_operator_composition import * +from tensorflow.python.ops.linalg.linear_operator_diag import * +from tensorflow.python.ops.linalg.linear_operator_full_matrix import * +from tensorflow.python.ops.linalg.linear_operator_identity import * +from tensorflow.python.ops.linalg.linear_operator_low_rank_update import * +from tensorflow.python.ops.linalg.linear_operator_lower_triangular import * # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py index 474648475579d3f840d18ba3dd3291b90755a93a..6a72df6dfd8d8c35211bab42b240b83d77160a02 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py @@ -19,10 +19,10 @@ from __future__ import print_function import numpy as np -from tensorflow.contrib import linalg as linalg_lib from tensorflow.contrib.linalg.python.ops import linear_operator_addition from tensorflow.python.framework import random_seed from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops.linalg import linalg as linalg_lib from tensorflow.python.platform import test linalg = linalg_lib @@ -114,7 +114,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase): def test_diag_tril_diag(self): op1 = linalg.LinearOperatorDiag( [1., 1.], is_non_singular=True, name="diag_a") - op2 = linalg.LinearOperatorTriL( + op2 = linalg.LinearOperatorLowerTriangular( [[2., 0.], [0., 2.]], is_self_adjoint=True, is_non_singular=True, @@ -125,7 +125,7 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase): op_sum = add_operators([op1, op2, op3]) self.assertEqual(1, len(op_sum)) op = op_sum[0] - self.assertTrue(isinstance(op, linalg_lib.LinearOperatorTriL)) + self.assertTrue(isinstance(op, linalg_lib.LinearOperatorLowerTriangular)) self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval()) # The diag operators will be self-adjoint (because real and diagonal). @@ -140,7 +140,8 @@ class LinearOperatorAdditionCorrectnessTest(test.TestCase): op0 = linalg.LinearOperatorFullMatrix( [[-1., -1.], [-1., -1.]], name="matrix") op1 = linalg.LinearOperatorDiag([1., 1.], name="diag_a") - op2 = linalg.LinearOperatorTriL([[2., 0.], [1.5, 2.]], name="tril") + op2 = linalg.LinearOperatorLowerTriangular( + [[2., 0.], [1.5, 2.]], name="tril") op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b") with self.test_session(): op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator") @@ -189,7 +190,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase): def test_tier_1_additions_done_by_tier_1(self): diag1 = linalg.LinearOperatorDiag([1.]) diag2 = linalg.LinearOperatorDiag([1.]) - tril = linalg.LinearOperatorTriL([[1.]]) + tril = linalg.LinearOperatorLowerTriangular([[1.]]) addition_tiers = [ [linear_operator_addition._AddAndReturnDiag()], [linear_operator_addition._AddAndReturnTriL()], @@ -199,12 +200,12 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase): # _BadAdder) was never reached. op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers) self.assertEqual(1, len(op_sum)) - self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorTriL)) + self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular)) def test_tier_1_additions_done_by_tier_1_with_order_flipped(self): diag1 = linalg.LinearOperatorDiag([1.]) diag2 = linalg.LinearOperatorDiag([1.]) - tril = linalg.LinearOperatorTriL([[1.]]) + tril = linalg.LinearOperatorLowerTriangular([[1.]]) addition_tiers = [ [linear_operator_addition._AddAndReturnTriL()], [linear_operator_addition._AddAndReturnDiag()], @@ -216,12 +217,12 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase): # Tier 2 was never used (therefore, _BadAdder didn't raise). op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers) self.assertEqual(1, len(op_sum)) - self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorTriL)) + self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular)) def test_cannot_add_everything_so_return_more_than_one_operator(self): diag1 = linalg.LinearOperatorDiag([1.]) diag2 = linalg.LinearOperatorDiag([2.]) - tril5 = linalg.LinearOperatorTriL([[5.]]) + tril5 = linalg.LinearOperatorLowerTriangular([[5.]]) addition_tiers = [ [linear_operator_addition._AddAndReturnDiag()], ] @@ -237,7 +238,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase): if isinstance(op, linalg.LinearOperatorDiag): found_diag = True self.assertAllClose([[3.]], op.to_dense().eval()) - if isinstance(op, linalg.LinearOperatorTriL): + if isinstance(op, linalg.LinearOperatorLowerTriangular): found_tril = True self.assertAllClose([[5.]], op.to_dense().eval()) self.assertTrue(found_diag and found_tril) @@ -245,7 +246,7 @@ class LinearOperatorOrderOfAdditionTest(test.TestCase): def test_intermediate_tier_is_not_skipped(self): diag1 = linalg.LinearOperatorDiag([1.]) diag2 = linalg.LinearOperatorDiag([1.]) - tril = linalg.LinearOperatorTriL([[1.]]) + tril = linalg.LinearOperatorLowerTriangular([[1.]]) addition_tiers = [ [linear_operator_addition._AddAndReturnDiag()], [_BadAdder()], @@ -369,14 +370,14 @@ class AddAndReturnTriLTest(test.TestCase): def test_diag_plus_tril(self): diag = linalg.LinearOperatorDiag([1., 2.]) - tril = linalg.LinearOperatorTriL([[10., 0.], [30., 0.]]) + tril = linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]]) hints = linear_operator_addition._Hints( is_positive_definite=True, is_non_singular=True) self.assertTrue(self._adder.can_add(diag, diag)) self.assertTrue(self._adder.can_add(diag, tril)) operator = self._adder.add(diag, tril, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorTriL)) + self.assertTrue(isinstance(operator, linalg.LinearOperatorLowerTriangular)) with self.test_session(): self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval()) diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py index 16c4c6e6d67f17d1674b8d1d39f006bc688bc6ce..86130a2c077ce14a7539b281ec809029bc05e071 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py @@ -22,14 +22,14 @@ import abc import six -from tensorflow.contrib.linalg.python.ops import linear_operator -from tensorflow.contrib.linalg.python.ops import linear_operator_diag -from tensorflow.contrib.linalg.python.ops import linear_operator_full_matrix -from tensorflow.contrib.linalg.python.ops import linear_operator_identity -from tensorflow.contrib.linalg.python.ops import linear_operator_tril from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops +from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.ops.linalg import linear_operator_diag +from tensorflow.python.ops.linalg import linear_operator_full_matrix +from tensorflow.python.ops.linalg import linear_operator_identity +from tensorflow.python.ops.linalg import linear_operator_lower_triangular __all__ = [] @@ -347,7 +347,7 @@ class _AddAndReturnTriL(_Adder): else: op_add_to_tensor, op_other = op2, op1 - return linear_operator_tril.LinearOperatorTriL( + return linear_operator_lower_triangular.LinearOperatorLowerTriangular( tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()), is_non_singular=hints.is_non_singular, is_self_adjoint=hints.is_self_adjoint, @@ -397,7 +397,8 @@ def _type(operator): """Returns the type name constant (e.g. _TRIL) for operator.""" if isinstance(operator, linear_operator_diag.LinearOperatorDiag): return _DIAG - if isinstance(operator, linear_operator_tril.LinearOperatorTriL): + if isinstance(operator, + linear_operator_lower_triangular.LinearOperatorLowerTriangular): return _TRIL if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix): return _MATRIX diff --git a/tensorflow/contrib/lookup/BUILD b/tensorflow/contrib/lookup/BUILD index b8455477b0e39b54b6a5419ebd6ad41b2fc07912..b7b5418fe91e496f021b44fc32a33d2a549782e5 100644 --- a/tensorflow/contrib/lookup/BUILD +++ b/tensorflow/contrib/lookup/BUILD @@ -34,12 +34,12 @@ py_test( deps = [ ":lookup_py", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", "//tensorflow/python:lookup_ops", + "//tensorflow/python:session", "//tensorflow/python:sparse_tensor", "//tensorflow/python:training", "//tensorflow/python:variables", diff --git a/tensorflow/contrib/losses/BUILD b/tensorflow/contrib/losses/BUILD index f75b0aa1b3e6606b0c92ae94b15b12781fe8b777..56942115213a762e532971a81da768b53b8537d8 100644 --- a/tensorflow/contrib/losses/BUILD +++ b/tensorflow/contrib/losses/BUILD @@ -15,15 +15,23 @@ py_library( "__init__.py", "python/losses/__init__.py", "python/losses/loss_ops.py", + "python/metric_learning/metric_loss_ops.py", ], srcs_version = "PY2AND3", deps = [ + ":metric_learning_py", "//tensorflow/contrib/framework:framework_py", "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:logging_ops", "//tensorflow/python:math_ops", "//tensorflow/python:nn", "//tensorflow/python:nn_ops", + "//tensorflow/python:script_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:summary", "//tensorflow/python:util", ], ) @@ -50,6 +58,46 @@ py_test( ], ) +py_library( + name = "metric_learning_py", + srcs = [ + "python/metric_learning/__init__.py", + "python/metric_learning/metric_loss_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:logging_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:script_ops", + "//tensorflow/python:sparse_ops", + "//tensorflow/python:sparse_tensor", + "//tensorflow/python:summary", + "//tensorflow/python:util", + ], +) + +py_test( + name = "metric_loss_ops_test", + size = "large", + srcs = [ + "python/metric_learning/metric_loss_ops_test.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":metric_learning_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:sparse_tensor", + "//third_party/py/numpy", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/losses/__init__.py b/tensorflow/contrib/losses/__init__.py index 790bf61367d85b79bae4b153328b229b10721b38..db58647d48f0f6f093ef4b71d1e8a7b79e611184 100644 --- a/tensorflow/contrib/losses/__init__.py +++ b/tensorflow/contrib/losses/__init__.py @@ -22,10 +22,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from tensorflow.contrib.losses.python import metric_learning # pylint: disable=wildcard-import from tensorflow.contrib.losses.python.losses import * # pylint: enable=wildcard-import - from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ @@ -43,5 +43,6 @@ _allowed_symbols = [ 'sigmoid_cross_entropy', 'softmax_cross_entropy', 'sparse_softmax_cross_entropy', + 'metric_learning' ] remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/losses/python/losses/loss_ops.py b/tensorflow/contrib/losses/python/losses/loss_ops.py index 1d2477b8b794240bd348cec7f626be794181ffb4..7c523ad49265aaf32c8d5a8ae04d3e93262a1b55 100644 --- a/tensorflow/contrib/losses/python/losses/loss_ops.py +++ b/tensorflow/contrib/losses/python/losses/loss_ops.py @@ -28,6 +28,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn from tensorflow.python.ops import nn_ops from tensorflow.python.util.deprecation import deprecated +from tensorflow.python.util.deprecation import deprecated_args __all__ = ["absolute_difference", "add_loss", @@ -623,8 +624,9 @@ def mean_pairwise_squared_error( @deprecated("2016-12-30", "Use tf.losses.cosine_distance instead.") +@deprecated_args(None, "dim is deprecated, use axis instead", "dim") def cosine_distance( - predictions, labels=None, dim=None, weights=1.0, scope=None): + predictions, labels=None, axis=None, weights=1.0, scope=None, dim=None): """Adds a cosine-distance loss to the training procedure. Note that the function assumes that `predictions` and `labels` are already @@ -633,10 +635,11 @@ def cosine_distance( Args: predictions: An arbitrary matrix. labels: A `Tensor` whose shape matches 'predictions' - dim: The dimension along which the cosine distance is computed. + axis: The dimension along which the cosine distance is computed. weights: Coefficients for the loss a scalar, a tensor of shape [batch_size] or a tensor whose shape matches `predictions`. scope: The scope for the operations performed in computing the loss. + dim: The old (deprecated) name for `axis`. Returns: A scalar `Tensor` representing the loss value. @@ -645,8 +648,12 @@ def cosine_distance( ValueError: If `predictions` shape doesn't match `labels` shape, or `weights` is `None`. """ - if dim is None: - raise ValueError("`dim` cannot be None.") + if dim is not None: + if axis is not None: + raise ValueError("Cannot specify both 'axis' and 'dim'") + axis = dim + if axis is None and dim is None: + raise ValueError("You must specify 'axis'.") with ops.name_scope(scope, "cosine_distance_loss", [predictions, labels, weights]) as scope: predictions.get_shape().assert_is_compatible_with(labels.get_shape()) @@ -655,5 +662,5 @@ def cosine_distance( labels = math_ops.to_float(labels) radial_diffs = math_ops.multiply(predictions, labels) - losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[dim,]) + losses = 1 - math_ops.reduce_sum(radial_diffs, reduction_indices=[axis,]) return compute_weighted_loss(losses, weights, scope=scope) diff --git a/tensorflow/contrib/bayesflow/python/ops/stochastic_tensor.py b/tensorflow/contrib/losses/python/metric_learning/__init__.py similarity index 63% rename from tensorflow/contrib/bayesflow/python/ops/stochastic_tensor.py rename to tensorflow/contrib/losses/python/metric_learning/__init__.py index 4d39a7918b36240f970aa192b907c3d127441657..4e551d6acafb5c565965503075e8416e01c20a71 100644 --- a/tensorflow/contrib/bayesflow/python/ops/stochastic_tensor.py +++ b/tensorflow/contrib/losses/python/metric_learning/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,37 +12,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Support for creating Stochastic Tensors. +"""Ops for building neural network losses. -See the @{$python/contrib.bayesflow.stochastic_tensor} guide. - -@@BaseStochasticTensor -@@StochasticTensor -@@MeanValue -@@SampleValue -@@value_type -@@get_current_value_type +See @{$python/contrib.losses}. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function -# go/tf-wildcard-import # pylint: disable=wildcard-import -from tensorflow.contrib.bayesflow.python.ops.stochastic_tensor_impl import * +from tensorflow.contrib.losses.python.metric_learning.metric_loss_ops import * # pylint: enable=wildcard-import from tensorflow.python.util.all_util import remove_undocumented - _allowed_symbols = [ - "BaseStochasticTensor", - "StochasticTensor", - "ObservedStochasticTensor", - "MeanValue", - "SampleValue", - "value_type", - "get_current_value_type", + 'contrastive_loss', + 'cluster_loss', + 'lifted_struct_loss', + 'npairs_loss', + 'npairs_loss_multilabel', + 'triplet_semihard_loss', ] - remove_undocumented(__name__, _allowed_symbols) + + diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c3a57ba51bcf0a292490dfaa9e556f6e5811ed66 --- /dev/null +++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py @@ -0,0 +1,1031 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Implements various metric learning losses.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import logging_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import script_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.summary import summary +try: + # pylint: disable=g-import-not-at-top + from sklearn import metrics + HAS_SKLEARN = True +except ImportError: + HAS_SKLEARN = False + + +def pairwise_distance(feature, squared=False): + """Computes the pairwise distance matrix with numerical stability. + + output[i, j] = || feature[i, :] - feature[j, :] ||_2 + + Args: + feature: 2-D Tensor of size [number of data, feature dimension]. + squared: Boolean, whether or not to square the pairwise distances. + + Returns: + pairwise_distances: 2-D Tensor of size [number of data, number of data]. + """ + pairwise_distances_squared = math_ops.add( + math_ops.reduce_sum( + math_ops.square(feature), + axis=[1], + keep_dims=True), + math_ops.reduce_sum( + math_ops.square( + array_ops.transpose(feature)), + axis=[0], + keep_dims=True)) - 2.0 * math_ops.matmul( + feature, array_ops.transpose(feature)) + + # Deal with numerical inaccuracies. Set small negatives to zero. + pairwise_distances_squared = math_ops.maximum(pairwise_distances_squared, 0.0) + # Get the mask where the zero distances are at. + error_mask = math_ops.less_equal(pairwise_distances_squared, 0.0) + + # Optionally take the sqrt. + if squared: + pairwise_distances = pairwise_distances_squared + else: + pairwise_distances = math_ops.sqrt( + pairwise_distances_squared + math_ops.to_float(error_mask) * 1e-16) + + # Undo conditionally adding 1e-16. + pairwise_distances = math_ops.multiply( + pairwise_distances, math_ops.to_float(math_ops.logical_not(error_mask))) + + num_data = array_ops.shape(feature)[0] + # Explicitly set diagonals to zero. + mask_offdiagonals = array_ops.ones_like(pairwise_distances) - array_ops.diag( + array_ops.ones([num_data])) + pairwise_distances = math_ops.multiply(pairwise_distances, mask_offdiagonals) + return pairwise_distances + + +def contrastive_loss(labels, embeddings_anchor, embeddings_positive, + margin=1.0): + """Computes the contrastive loss. + + This loss encourages the embedding to be close to each other for + the samples of the same label and the embedding to be far apart at least + by the margin constant for the samples of different labels. + See: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf + + Args: + labels: 1-D tf.int32 `Tensor` with shape [batch_size] of + binary labels indicating positive vs negative pair. + embeddings_anchor: 2-D float `Tensor` of embedding vectors for the anchor + images. Embeddings should be l2 normalized. + embeddings_positive: 2-D float `Tensor` of embedding vectors for the + positive images. Embeddings should be l2 normalized. + margin: margin term in the loss definition. + + Returns: + contrastive_loss: tf.float32 scalar. + """ + # Get per pair distances + distances = math_ops.sqrt( + math_ops.reduce_sum( + math_ops.square(embeddings_anchor - embeddings_positive), 1)) + + # Add contrastive loss for the siamese network. + # label here is {0,1} for neg, pos. + return math_ops.reduce_mean( + math_ops.to_float(labels) * math_ops.square(distances) + + (1. - math_ops.to_float(labels)) * + math_ops.square(math_ops.maximum(margin - distances, 0.)), + name='contrastive_loss') + + +def masked_maximum(data, mask, dim=1): + """Computes the axis wise maximum over chosen elements. + + Args: + data: 2-D float `Tensor` of size [n, m]. + mask: 2-D Boolean `Tensor` of size [n, m]. + dim: The dimension over which to compute the maximum. + + Returns: + masked_maximums: N-D `Tensor`. + The maximized dimension is of size 1 after the operation. + """ + axis_minimums = math_ops.reduce_min(data, dim, keep_dims=True) + masked_maximums = math_ops.reduce_max( + math_ops.multiply( + data - axis_minimums, mask), dim, keep_dims=True) + axis_minimums + return masked_maximums + + +def masked_minimum(data, mask, dim=1): + """Computes the axis wise minimum over chosen elements. + + Args: + data: 2-D float `Tensor` of size [n, m]. + mask: 2-D Boolean `Tensor` of size [n, m]. + dim: The dimension over which to compute the minimum. + + Returns: + masked_minimums: N-D `Tensor`. + The minimized dimension is of size 1 after the operation. + """ + axis_maximums = math_ops.reduce_max(data, dim, keep_dims=True) + masked_minimums = math_ops.reduce_min( + math_ops.multiply( + data - axis_maximums, mask), dim, keep_dims=True) + axis_maximums + return masked_minimums + + +def triplet_semihard_loss(labels, embeddings, margin=1.0): + """Computes the triplet loss with semi-hard negative mining. + + The loss encourages the positive distances (between a pair of embeddings with + the same labels) to be smaller than the minimum negative distance among + which are at least greater than the positive distance plus the margin constant + (called semi-hard negative) in the mini-batch. If no such negative exists, + uses the largest negative distance instead. + See: https://arxiv.org/abs/1503.03832. + + Args: + labels: 1-D tf.int32 `Tensor` with shape [batch_size] of + multiclass integer labels. + embeddings: 2-D float `Tensor` of embedding vectors. Embeddings should + be l2 normalized. + margin: Float, margin term in the loss definition. + + Returns: + triplet_loss: tf.float32 scalar. + """ + # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor. + lshape = array_ops.shape(labels) + assert lshape.shape == 1 + labels = array_ops.reshape(labels, [lshape[0], 1]) + + # Build pairwise squared distance matrix. + pdist_matrix = pairwise_distance(embeddings, squared=True) + # Build pairwise binary adjacency matrix. + adjacency = math_ops.equal(labels, array_ops.transpose(labels)) + # Invert so we can select negatives only. + adjacency_not = math_ops.logical_not(adjacency) + + batch_size = array_ops.size(labels) + + # Compute the mask. + pdist_matrix_tile = array_ops.tile(pdist_matrix, [batch_size, 1]) + mask = math_ops.logical_and( + array_ops.tile(adjacency_not, [batch_size, 1]), + math_ops.greater( + pdist_matrix_tile, array_ops.reshape( + array_ops.transpose(pdist_matrix), [-1, 1]))) + mask_final = array_ops.reshape( + math_ops.greater( + math_ops.reduce_sum( + math_ops.cast( + mask, dtype=dtypes.float32), 1, keep_dims=True), + 0.0), [batch_size, batch_size]) + mask_final = array_ops.transpose(mask_final) + + adjacency_not = math_ops.cast(adjacency_not, dtype=dtypes.float32) + mask = math_ops.cast(mask, dtype=dtypes.float32) + + # negatives_outside: smallest D_an where D_an > D_ap. + negatives_outside = array_ops.reshape( + masked_minimum(pdist_matrix_tile, mask), [batch_size, batch_size]) + negatives_outside = array_ops.transpose(negatives_outside) + + # negatives_inside: largest D_an. + negatives_inside = array_ops.tile( + masked_maximum(pdist_matrix, adjacency_not), [1, batch_size]) + semi_hard_negatives = array_ops.where( + mask_final, negatives_outside, negatives_inside) + + loss_mat = math_ops.add(margin, pdist_matrix - semi_hard_negatives) + + mask_positives = math_ops.cast( + adjacency, dtype=dtypes.float32) - array_ops.diag( + array_ops.ones([batch_size])) + + # In lifted-struct, the authors multiply 0.5 for upper triangular + # in semihard, they take all positive pairs except the diagonal. + num_positives = math_ops.reduce_sum(mask_positives) + + triplet_loss = math_ops.truediv( + math_ops.reduce_sum( + math_ops.maximum( + math_ops.multiply(loss_mat, mask_positives), 0.0)), + num_positives, + name='triplet_semihard_loss') + + return triplet_loss + + +# pylint: disable=line-too-long +def npairs_loss(labels, embeddings_anchor, embeddings_positive, + reg_lambda=0.002, print_losses=False): + """Computes the npairs loss. + + Npairs loss expects paired data where a pair is composed of samples from the + same labels and each pairs in the minibatch have different labels. The loss + has two components. The first component is the L2 regularizer on the + embedding vectors. The second component is the sum of cross entropy loss + which takes each row of the pair-wise similarity matrix as logits and + the remapped one-hot labels as labels. + + See: http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf + + Args: + labels: 1-D tf.int32 `Tensor` of shape [batch_size/2]. + embeddings_anchor: 2-D Tensor of shape [batch_size/2, embedding_dim] for the + embedding vectors for the anchor images. Embeddings should not be + l2 normalized. + embeddings_positive: 2-D Tensor of shape [batch_size/2, embedding_dim] for the + embedding vectors for the positive images. Embeddings should not be + l2 normalized. + reg_lambda: Float. L2 regularization term on the embedding vectors. + print_losses: Boolean. Option to print the xent and l2loss. + + Returns: + npairs_loss: tf.float32 scalar. + """ + # pylint: enable=line-too-long + # Add the regularizer on the embedding. + reg_anchor = math_ops.reduce_mean( + math_ops.reduce_sum(math_ops.square(embeddings_anchor), 1)) + reg_positive = math_ops.reduce_mean( + math_ops.reduce_sum(math_ops.square(embeddings_positive), 1)) + l2loss = math_ops.multiply( + 0.25 * reg_lambda, reg_anchor + reg_positive, name='l2loss') + + # Get per pair similarities. + similarity_matrix = math_ops.matmul( + embeddings_anchor, embeddings_positive, transpose_a=False, + transpose_b=True) + + # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor. + lshape = array_ops.shape(labels) + assert lshape.shape == 1 + labels = array_ops.reshape(labels, [lshape[0], 1]) + + labels_remapped = math_ops.to_float( + math_ops.equal(labels, array_ops.transpose(labels))) + labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keep_dims=True) + + # Add the softmax loss. + xent_loss = nn.softmax_cross_entropy_with_logits( + logits=similarity_matrix, labels=labels_remapped) + xent_loss = math_ops.reduce_mean(xent_loss, name='xentropy') + + if print_losses: + xent_loss = logging_ops.Print( + xent_loss, ['cross entropy:', xent_loss, 'l2loss:', l2loss]) + + return l2loss + xent_loss + + +def _build_multilabel_adjacency(sparse_labels): + """Builds multilabel adjacency matrix. + + As of March 14th, 2017, there's no op for the dot product between + two sparse tensors in TF. However, there is `sparse_minimum` op which is + equivalent to an AND op between two sparse boolean tensors. + This computes the dot product between two sparse boolean inputs. + + Args: + sparse_labels: List of 1-D boolean sparse tensors. + + Returns: + adjacency_matrix: 2-D dense `Tensor`. + """ + num_pairs = len(sparse_labels) + adjacency_matrix = array_ops.zeros([num_pairs, num_pairs]) + for i in range(num_pairs): + for j in range(num_pairs): + sparse_dot_product = math_ops.to_float( + sparse_ops.sparse_reduce_sum(sparse_ops.sparse_minimum( + sparse_labels[i], sparse_labels[j]))) + sparse_dot_product = array_ops.expand_dims(sparse_dot_product, 0) + sparse_dot_product = array_ops.expand_dims(sparse_dot_product, 1) + one_hot_matrix = array_ops.pad(sparse_dot_product, + [[i, num_pairs-i-1], + [j, num_pairs-j-1]], 'CONSTANT') + adjacency_matrix += one_hot_matrix + + return adjacency_matrix + + +def npairs_loss_multilabel(sparse_labels, embeddings_anchor, + embeddings_positive, reg_lambda=0.002, + print_losses=False): + r"""Computes the npairs loss with multilabel data. + + Npairs loss expects paired data where a pair is composed of samples from the + same labels and each pairs in the minibatch have different labels. The loss + has two components. The first component is the L2 regularizer on the + embedding vectors. The second component is the sum of cross entropy loss + which takes each row of the pair-wise similarity matrix as logits and + the remapped one-hot labels as labels. Here, the similarity is defined by the + dot product between two embedding vectors. S_{i,j} = f(x_i)^T f(x_j) + + To deal with multilabel inputs, we use the count of label intersection + i.e. L_{i,j} = | set_of_labels_for(i) \cap set_of_labels_for(j) | + Then we normalize each rows of the count based label matrix so that each row + sums to one. + + Args: + sparse_labels: List of 1-D Boolean `SparseTensor` of dense_shape + [batch_size/2, num_classes] labels for the anchor-pos pairs. + embeddings_anchor: 2-D `Tensor` of shape [batch_size/2, embedding_dim] for + the embedding vectors for the anchor images. Embeddings should not be + l2 normalized. + embeddings_positive: 2-D `Tensor` of shape [batch_size/2, embedding_dim] for + the embedding vectors for the positive images. Embeddings should not be + l2 normalized. + reg_lambda: Float. L2 regularization term on the embedding vectors. + print_losses: Boolean. Option to print the xent and l2loss. + + Returns: + npairs_loss: tf.float32 scalar. + Raises: + TypeError: When the specified sparse_labels is not a `SparseTensor`. + """ + if False in [isinstance( + l, sparse_tensor.SparseTensor) for l in sparse_labels]: + raise TypeError( + 'sparse_labels must be a list of SparseTensors, but got %s' % str( + sparse_labels)) + + with ops.name_scope('NpairsLossMultiLabel'): + # Add the regularizer on the embedding. + reg_anchor = math_ops.reduce_mean( + math_ops.reduce_sum(math_ops.square(embeddings_anchor), 1)) + reg_positive = math_ops.reduce_mean( + math_ops.reduce_sum(math_ops.square(embeddings_positive), 1)) + l2loss = math_ops.multiply(0.25 * reg_lambda, + reg_anchor + reg_positive, name='l2loss') + + # Get per pair similarities. + similarity_matrix = math_ops.matmul( + embeddings_anchor, embeddings_positive, transpose_a=False, + transpose_b=True) + + # TODO(coreylynch): need to check the sparse values + # TODO(coreylynch): are composed only of 0's and 1's. + + multilabel_adjacency_matrix = _build_multilabel_adjacency(sparse_labels) + labels_remapped = math_ops.to_float(multilabel_adjacency_matrix) + labels_remapped /= math_ops.reduce_sum(labels_remapped, 1, keep_dims=True) + + # Add the softmax loss. + xent_loss = nn.softmax_cross_entropy_with_logits( + logits=similarity_matrix, labels=labels_remapped) + xent_loss = math_ops.reduce_mean(xent_loss, name='xentropy') + + if print_losses: + xent_loss = logging_ops.Print( + xent_loss, ['cross entropy:', xent_loss, 'l2loss:', l2loss]) + + return l2loss + xent_loss + + +def lifted_struct_loss(labels, embeddings, margin=1.0): + """Computes the lifted structured loss. + + The loss encourages the positive distances (between a pair of embeddings + with the same labels) to be smaller than any negative distances (between a + pair of embeddings with different labels) in the mini-batch in a way + that is differentiable with respect to the embedding vectors. + See: https://arxiv.org/abs/1511.06452. + + Args: + labels: 1-D tf.int32 `Tensor` with shape [batch_size] of + multiclass integer labels. + embeddings: 2-D float `Tensor` of embedding vectors. Embeddings should not + be l2 normalized. + margin: Float, margin term in the loss definition. + + Returns: + lifted_loss: tf.float32 scalar. + """ + # Reshape [batch_size] label tensor to a [batch_size, 1] label tensor. + lshape = array_ops.shape(labels) + assert lshape.shape == 1 + labels = array_ops.reshape(labels, [lshape[0], 1]) + + # Build pairwise squared distance matrix. + pairwise_distances = pairwise_distance(embeddings) + + # Build pairwise binary adjacency matrix. + adjacency = math_ops.equal(labels, array_ops.transpose(labels)) + # Invert so we can select negatives only. + adjacency_not = math_ops.logical_not(adjacency) + + batch_size = array_ops.size(labels) + + diff = margin - pairwise_distances + mask = math_ops.cast(adjacency_not, dtype=dtypes.float32) + # Safe maximum: Temporarily shift negative distances + # above zero before taking max. + # this is to take the max only among negatives. + row_minimums = math_ops.reduce_min(diff, 1, keep_dims=True) + row_negative_maximums = math_ops.reduce_max( + math_ops.multiply( + diff - row_minimums, mask), 1, keep_dims=True) + row_minimums + + # Compute the loss. + # Keep track of matrix of maximums where M_ij = max(m_i, m_j) + # where m_i is the max of alpha - negative D_i's. + # This matches the Caffe loss layer implementation at: + # https://github.com/rksltnl/Caffe-Deep-Metric-Learning-CVPR16/blob/0efd7544a9846f58df923c8b992198ba5c355454/src/caffe/layers/lifted_struct_similarity_softmax_layer.cpp # pylint: disable=line-too-long + + max_elements = math_ops.maximum( + row_negative_maximums, array_ops.transpose(row_negative_maximums)) + diff_tiled = array_ops.tile(diff, [batch_size, 1]) + mask_tiled = array_ops.tile(mask, [batch_size, 1]) + max_elements_vect = array_ops.reshape( + array_ops.transpose(max_elements), [-1, 1]) + + loss_exp_left = array_ops.reshape( + math_ops.reduce_sum(math_ops.multiply( + math_ops.exp( + diff_tiled - max_elements_vect), + mask_tiled), 1, keep_dims=True), [batch_size, batch_size]) + + loss_mat = max_elements + math_ops.log( + loss_exp_left + array_ops.transpose(loss_exp_left)) + # Add the positive distance. + loss_mat += pairwise_distances + + mask_positives = math_ops.cast( + adjacency, dtype=dtypes.float32) - array_ops.diag( + array_ops.ones([batch_size])) + + # *0.5 for upper triangular, and another *0.5 for 1/2 factor for loss^2. + num_positives = math_ops.reduce_sum(mask_positives) / 2.0 + + lifted_loss = math_ops.truediv( + 0.25 * math_ops.reduce_sum( + math_ops.square( + math_ops.maximum( + math_ops.multiply(loss_mat, mask_positives), 0.0))), + num_positives, + name='liftedstruct_loss') + return lifted_loss + + +def update_1d_tensor(y, index, value): + """Updates 1d tensor y so that y[index] = value. + + Args: + y: 1-D Tensor. + index: index of y to modify. + value: new value to write at y[index]. + + Returns: + y_mod: 1-D Tensor. Tensor y after the update. + """ + value = array_ops.squeeze(value) + # modify the 1D tensor x at index with value. + # ex) chosen_ids = update_1D_tensor(chosen_ids, cluster_idx, best_medoid) + y_before = array_ops.slice(y, [0], [index]) + y_after = array_ops.slice(y, [index + 1], [-1]) + y_mod = array_ops.concat([y_before, [value], y_after], 0) + return y_mod + + +def get_cluster_assignment(pairwise_distances, centroid_ids): + """Assign data points to the neareset centroids. + + Tensorflow has numerical instability and doesn't always choose + the data point with theoretically zero distance as it's nearest neighbor. + Thus, for each centroid in centroid_ids, explicitly assign + the centroid itself as the nearest centroid. + This is done through the mask tensor and the constraint_vect tensor. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + centroid_ids: 1-D Tensor of centroid indices. + + Returns: + y_fixed: 1-D tensor of cluster assignment. + """ + predictions = math_ops.argmin( + array_ops.gather(pairwise_distances, centroid_ids), dimension=0) + batch_size = array_ops.shape(pairwise_distances)[0] + + # Deal with numerical instability + mask = math_ops.reduce_any(array_ops.one_hot( + centroid_ids, batch_size, True, False, axis=-1, dtype=dtypes.bool), + axis=0) + constraint_one_hot = math_ops.multiply( + array_ops.one_hot(centroid_ids, + batch_size, + array_ops.constant(1, dtype=dtypes.int64), + array_ops.constant(0, dtype=dtypes.int64), + axis=0, + dtype=dtypes.int64), + math_ops.to_int64(math_ops.range(array_ops.shape(centroid_ids)[0]))) + constraint_vect = math_ops.reduce_sum( + array_ops.transpose(constraint_one_hot), axis=0) + + y_fixed = array_ops.where(mask, constraint_vect, predictions) + return y_fixed + + +def compute_facility_energy(pairwise_distances, centroid_ids): + """Compute the average travel distance to the assigned centroid. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + centroid_ids: 1-D Tensor of indices. + + Returns: + facility_energy: dtypes.float32 scalar. + """ + return -1.0 * math_ops.reduce_sum( + math_ops.reduce_min( + array_ops.gather(pairwise_distances, centroid_ids), axis=0)) + + +def compute_clustering_score(labels, predictions, margin_type): + """Computes the clustering score via sklearn.metrics functions. + + There are various ways to compute the clustering score. Intuitively, + we want to measure the agreement of two clustering assignments (labels vs + predictions) ignoring the permutations and output a score from zero to one. + (where the values close to one indicate significant agreement). + This code supports following scoring functions: + nmi: normalized mutual information + ami: adjusted mutual information + ari: adjusted random index + vmeasure: v-measure + const: indicator checking whether the two clusterings are the same. + See http://scikit-learn.org/stable/modules/classes.html#clustering-metrics + for the detailed descriptions. + Args: + labels: 1-D Tensor. ground truth cluster assignment. + predictions: 1-D Tensor. predicted cluster assignment. + margin_type: Type of structured margin to use. Default is nmi. + Returns: + clustering_score: dtypes.float32 scalar. + The possible valid values are from zero to one. + Zero means the worst clustering and one means the perfect clustering. + Raises: + ValueError: margin_type is not recognized. + """ + margin_type_to_func = { + 'nmi': _compute_nmi_score, + 'ami': _compute_ami_score, + 'ari': _compute_ari_score, + 'vmeasure': _compute_vmeasure_score, + 'const': _compute_zeroone_score + } + + if margin_type not in margin_type_to_func: + raise ValueError('Unrecognized margin_type: %s' % margin_type) + clustering_score_fn = margin_type_to_func[margin_type] + return array_ops.squeeze(clustering_score_fn(labels, predictions)) + + +def _compute_nmi_score(labels, predictions): + return math_ops.to_float( + script_ops.py_func( + metrics.normalized_mutual_info_score, [labels, predictions], + [dtypes.float64], + name='nmi')) + + +def _compute_ami_score(labels, predictions): + ami_score = math_ops.to_float( + script_ops.py_func( + metrics.adjusted_mutual_info_score, [labels, predictions], + [dtypes.float64], + name='ami')) + return math_ops.maximum(0.0, ami_score) + + +def _compute_ari_score(labels, predictions): + ari_score = math_ops.to_float( + script_ops.py_func( + metrics.adjusted_rand_score, [labels, predictions], [dtypes.float64], + name='ari')) + # ari score can go below 0 + # http://scikit-learn.org/stable/modules/clustering.html#adjusted-rand-score + return math_ops.maximum(0.0, ari_score) + + +def _compute_vmeasure_score(labels, predictions): + vmeasure_score = math_ops.to_float( + script_ops.py_func( + metrics.v_measure_score, [labels, predictions], [dtypes.float64], + name='vmeasure')) + return math_ops.maximum(0.0, vmeasure_score) + + +def _compute_zeroone_score(labels, predictions): + zeroone_score = math_ops.to_float( + math_ops.equal( + math_ops.reduce_sum( + math_ops.to_int32(math_ops.equal(labels, predictions))), + array_ops.shape(labels)[0])) + return zeroone_score + + +def _find_loss_augmented_facility_idx(pairwise_distances, labels, chosen_ids, + candidate_ids, margin_multiplier, + margin_type): + """Find the next centroid that maximizes the loss augmented inference. + + This function is a subroutine called from compute_augmented_facility_locations + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + labels: 1-D Tensor of ground truth cluster assignment. + chosen_ids: 1-D Tensor of current centroid indices. + candidate_ids: 1-D Tensor of candidate indices. + margin_multiplier: multiplication constant. + margin_type: Type of structured margin to use. Default is nmi. + + Returns: + integer index. + """ + num_candidates = array_ops.shape(candidate_ids)[0] + + pairwise_distances_chosen = array_ops.gather(pairwise_distances, chosen_ids) + pairwise_distances_candidate = array_ops.gather( + pairwise_distances, candidate_ids) + pairwise_distances_chosen_tile = array_ops.tile( + pairwise_distances_chosen, [1, num_candidates]) + + candidate_scores = -1.0 * math_ops.reduce_sum( + array_ops.reshape( + math_ops.reduce_min( + array_ops.concat([ + pairwise_distances_chosen_tile, + array_ops.reshape(pairwise_distances_candidate, [1, -1]) + ], 0), + axis=0, + keep_dims=True), [num_candidates, -1]), + axis=1) + + nmi_scores = array_ops.zeros([num_candidates]) + iteration = array_ops.constant(0) + + def func_cond(iteration, nmi_scores): + del nmi_scores # Unused in func_cond() + return iteration < num_candidates + + def func_body(iteration, nmi_scores): + predictions = get_cluster_assignment( + pairwise_distances, + array_ops.concat([chosen_ids, [candidate_ids[iteration]]], 0)) + nmi_score_i = compute_clustering_score(labels, predictions, margin_type) + pad_before = array_ops.zeros([iteration]) + pad_after = array_ops.zeros([num_candidates - 1 - iteration]) + # return 1 - NMI score as the structured loss. + # because NMI is higher the better [0,1]. + return iteration + 1, nmi_scores + array_ops.concat( + [pad_before, [1.0 - nmi_score_i], pad_after], 0) + + _, nmi_scores = control_flow_ops.while_loop( + func_cond, func_body, [iteration, nmi_scores]) + + candidate_scores = math_ops.add( + candidate_scores, margin_multiplier * nmi_scores) + + argmax_index = math_ops.to_int32( + math_ops.argmax(candidate_scores, dimension=0)) + + return candidate_ids[argmax_index] + + +def compute_augmented_facility_locations(pairwise_distances, labels, all_ids, + margin_multiplier, margin_type): + """Computes the centroid locations. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + labels: 1-D Tensor of ground truth cluster assignment. + all_ids: 1-D Tensor of all data indices. + margin_multiplier: multiplication constant. + margin_type: Type of structured margin to use. Default is nmi. + + Returns: + chosen_ids: 1-D Tensor of chosen centroid indices. + """ + + def func_cond_augmented(iteration, chosen_ids): + del chosen_ids # Unused argument in func_cond_augmented. + return iteration < num_classes + + def func_body_augmented(iteration, chosen_ids): + # find a new facility location to add + # based on the clustering score and the NMI score + candidate_ids = array_ops.setdiff1d(all_ids, chosen_ids)[0] + new_chosen_idx = _find_loss_augmented_facility_idx(pairwise_distances, + labels, chosen_ids, + candidate_ids, + margin_multiplier, + margin_type) + chosen_ids = array_ops.concat([chosen_ids, [new_chosen_idx]], 0) + return iteration + 1, chosen_ids + + num_classes = array_ops.size(array_ops.unique(labels)[0]) + chosen_ids = array_ops.constant(0, dtype=dtypes.int32, shape=[0]) + + # num_classes get determined at run time based on the sampled batch. + iteration = array_ops.constant(0) + + _, chosen_ids = control_flow_ops.while_loop( + func_cond_augmented, + func_body_augmented, [iteration, chosen_ids], + shape_invariants=[iteration.get_shape(), tensor_shape.TensorShape( + [None])]) + return chosen_ids + + +def update_medoid_per_cluster(pairwise_distances, pairwise_distances_subset, + labels, chosen_ids, cluster_member_ids, + cluster_idx, margin_multiplier, margin_type): + """Updates the cluster medoid per cluster. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + pairwise_distances_subset: 2-D Tensor of pairwise distances for one cluster. + labels: 1-D Tensor of ground truth cluster assignment. + chosen_ids: 1-D Tensor of cluster centroid indices. + cluster_member_ids: 1-D Tensor of cluster member indices for one cluster. + cluster_idx: Index of this one cluster. + margin_multiplier: multiplication constant. + margin_type: Type of structured margin to use. Default is nmi. + + Returns: + chosen_ids: Updated 1-D Tensor of cluster centroid indices. + """ + + def func_cond(iteration, scores_margin): + del scores_margin # Unused variable scores_margin. + return iteration < num_candidates + + def func_body(iteration, scores_margin): + # swap the current medoid with the candidate cluster member + candidate_medoid = math_ops.to_int32(cluster_member_ids[iteration]) + tmp_chosen_ids = update_1d_tensor(chosen_ids, cluster_idx, candidate_medoid) + predictions = get_cluster_assignment(pairwise_distances, tmp_chosen_ids) + metric_score = compute_clustering_score(labels, predictions, margin_type) + pad_before = array_ops.zeros([iteration]) + pad_after = array_ops.zeros([num_candidates - 1 - iteration]) + return iteration + 1, scores_margin + array_ops.concat( + [pad_before, [1.0 - metric_score], pad_after], 0) + + # pairwise_distances_subset is of size [p, 1, 1, p], + # the intermediate dummy dimensions at + # [1, 2] makes this code work in the edge case where p=1. + # this happens if the cluster size is one. + scores_fac = -1.0 * math_ops.reduce_sum( + array_ops.squeeze(pairwise_distances_subset, [1, 2]), axis=0) + + iteration = array_ops.constant(0) + num_candidates = array_ops.size(cluster_member_ids) + scores_margin = array_ops.zeros([num_candidates]) + + _, scores_margin = control_flow_ops.while_loop(func_cond, func_body, + [iteration, scores_margin]) + candidate_scores = math_ops.add(scores_fac, margin_multiplier * scores_margin) + + argmax_index = math_ops.to_int32( + math_ops.argmax(candidate_scores, dimension=0)) + + best_medoid = math_ops.to_int32(cluster_member_ids[argmax_index]) + chosen_ids = update_1d_tensor(chosen_ids, cluster_idx, best_medoid) + return chosen_ids + + +def update_all_medoids(pairwise_distances, predictions, labels, chosen_ids, + margin_multiplier, margin_type): + """Updates all cluster medoids a cluster at a time. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + predictions: 1-D Tensor of predicted cluster assignment. + labels: 1-D Tensor of ground truth cluster assignment. + chosen_ids: 1-D Tensor of cluster centroid indices. + margin_multiplier: multiplication constant. + margin_type: Type of structured margin to use. Default is nmi. + + Returns: + chosen_ids: Updated 1-D Tensor of cluster centroid indices. + """ + + def func_cond_augmented_pam(iteration, chosen_ids): + del chosen_ids # Unused argument. + return iteration < num_classes + + def func_body_augmented_pam(iteration, chosen_ids): + """Call the update_medoid_per_cluster subroutine.""" + mask = math_ops.equal( + math_ops.to_int64(predictions), math_ops.to_int64(iteration)) + this_cluster_ids = array_ops.where(mask) + + pairwise_distances_subset = array_ops.transpose( + array_ops.gather( + array_ops.transpose( + array_ops.gather(pairwise_distances, this_cluster_ids)), + this_cluster_ids)) + + chosen_ids = update_medoid_per_cluster(pairwise_distances, + pairwise_distances_subset, labels, + chosen_ids, this_cluster_ids, + iteration, margin_multiplier, + margin_type) + return iteration + 1, chosen_ids + + unique_class_ids = array_ops.unique(labels)[0] + num_classes = array_ops.size(unique_class_ids) + iteration = array_ops.constant(0) + + _, chosen_ids = control_flow_ops.while_loop( + func_cond_augmented_pam, func_body_augmented_pam, [iteration, chosen_ids]) + return chosen_ids + + +def compute_augmented_facility_locations_pam(pairwise_distances, + labels, + margin_multiplier, + margin_type, + chosen_ids, + pam_max_iter=5): + """Refine the cluster centroids with PAM local search. + + For fixed iterations, alternate between updating the cluster assignment + and updating cluster medoids. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + labels: 1-D Tensor of ground truth cluster assignment. + margin_multiplier: multiplication constant. + margin_type: Type of structured margin to use. Default is nmi. + chosen_ids: 1-D Tensor of initial estimate of cluster centroids. + pam_max_iter: Number of refinement iterations. + + Returns: + chosen_ids: Updated 1-D Tensor of cluster centroid indices. + """ + for _ in range(pam_max_iter): + # update the cluster assignment given the chosen_ids (S_pred) + predictions = get_cluster_assignment(pairwise_distances, chosen_ids) + + # update the medoids per each cluster + chosen_ids = update_all_medoids(pairwise_distances, predictions, labels, + chosen_ids, margin_multiplier, margin_type) + + return chosen_ids + + +def compute_gt_cluster_score(pairwise_distances, labels): + """Compute ground truth facility location score. + + Loop over each unique classes and compute average travel distances. + + Args: + pairwise_distances: 2-D Tensor of pairwise distances. + labels: 1-D Tensor of ground truth cluster assignment. + + Returns: + gt_cluster_score: dtypes.float32 score. + """ + unique_class_ids = array_ops.unique(labels)[0] + num_classes = array_ops.size(unique_class_ids) + iteration = array_ops.constant(0) + gt_cluster_score = array_ops.constant(0.0, dtype=dtypes.float32) + + def func_cond(iteration, gt_cluster_score): + del gt_cluster_score # Unused argument. + return iteration < num_classes + + def func_body(iteration, gt_cluster_score): + """Per each cluster, compute the average travel distance.""" + mask = math_ops.equal(labels, unique_class_ids[iteration]) + this_cluster_ids = array_ops.where(mask) + pairwise_distances_subset = array_ops.transpose( + array_ops.gather( + array_ops.transpose( + array_ops.gather(pairwise_distances, this_cluster_ids)), + this_cluster_ids)) + this_cluster_score = -1.0 * math_ops.reduce_min( + math_ops.reduce_sum( + pairwise_distances_subset, axis=0)) + return iteration + 1, gt_cluster_score + this_cluster_score + + _, gt_cluster_score = control_flow_ops.while_loop( + func_cond, func_body, [iteration, gt_cluster_score]) + return gt_cluster_score + + +def cluster_loss(labels, + embeddings, + margin_multiplier, + enable_pam_finetuning=True, + margin_type='nmi', + print_losses=False): + """Computes the clustering loss. + + The following structured margins are supported: + nmi: normalized mutual information + ami: adjusted mutual information + ari: adjusted random index + vmeasure: v-measure + const: indicator checking whether the two clusterings are the same. + + Args: + labels: 2-D Tensor of labels of shape [batch size, 1] + embeddings: 2-D Tensor of embeddings of shape + [batch size, embedding dimension]. Embeddings should be l2 normalized. + margin_multiplier: float32 scalar. multiplier on the structured margin term + See section 3.2 of paper for discussion. + enable_pam_finetuning: Boolean, Whether to run local pam refinement. + See section 3.4 of paper for discussion. + margin_type: Type of structured margin to use. See section 3.2 of + paper for discussion. Can be 'nmi', 'ami', 'ari', 'vmeasure', 'const'. + print_losses: Boolean. Option to print the loss. + + Paper: https://arxiv.org/abs/1612.01213. + + Returns: + clustering_loss: A float32 scalar `Tensor`. + Raises: + ImportError: If sklearn dependency is not installed. + """ + if not HAS_SKLEARN: + raise ImportError('Cluster loss depends on sklearn.') + pairwise_distances = pairwise_distance(embeddings) + labels = array_ops.squeeze(labels) + all_ids = math_ops.range(array_ops.shape(embeddings)[0]) + + # Compute the loss augmented inference and get the cluster centroids. + chosen_ids = compute_augmented_facility_locations(pairwise_distances, labels, + all_ids, margin_multiplier, + margin_type) + # Given the predicted centroids, compute the clustering score. + score_pred = compute_facility_energy(pairwise_distances, chosen_ids) + + # Branch whether to use PAM finetuning. + if enable_pam_finetuning: + # Initialize with augmented facility solution. + chosen_ids = compute_augmented_facility_locations_pam(pairwise_distances, + labels, + margin_multiplier, + margin_type, + chosen_ids) + score_pred = compute_facility_energy(pairwise_distances, chosen_ids) + + # Given the predicted centroids, compute the cluster assignments. + predictions = get_cluster_assignment(pairwise_distances, chosen_ids) + + # Compute the clustering (i.e. NMI) score between the two assignments. + clustering_score_pred = compute_clustering_score(labels, predictions, + margin_type) + + # Compute the clustering score from labels. + score_gt = compute_gt_cluster_score(pairwise_distances, labels) + + # Compute the hinge loss. + clustering_loss = math_ops.maximum( + score_pred + margin_multiplier * (1.0 - clustering_score_pred) - score_gt, + 0.0, + name='clustering_loss') + clustering_loss.set_shape([]) + + if print_losses: + clustering_loss = logging_ops.Print( + clustering_loss, + ['clustering_loss: ', clustering_loss, array_ops.shape( + clustering_loss)]) + + # Clustering specific summary. + summary.scalar('losses/score_pred', score_pred) + summary.scalar('losses/' + margin_type, clustering_score_pred) + summary.scalar('losses/score_gt', score_gt) + + return clustering_loss diff --git a/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..4ec539ab42b4e0ba90a2a1f379a1d4d4b49d11f3 --- /dev/null +++ b/tensorflow/contrib/losses/python/metric_learning/metric_loss_ops_test.py @@ -0,0 +1,562 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for triplet_semihard_loss.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from tensorflow.contrib.losses.python import metric_learning as metric_loss_ops +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.platform import test +try: + # pylint: disable=g-import-not-at-top + from sklearn import datasets + from sklearn import metrics + HAS_SKLEARN = True +except ImportError: + HAS_SKLEARN = False + + +def pairwise_distance_np(feature, squared=False): + """Computes the pairwise distance matrix in numpy. + + Args: + feature: 2-D numpy array of size [number of data, feature dimension] + squared: Boolean. If true, output is the pairwise squared euclidean + distance matrix; else, output is the pairwise euclidean distance matrix. + + Returns: + pairwise_distances: 2-D numpy array of size + [number of data, number of data]. + """ + triu = np.triu_indices(feature.shape[0], 1) + upper_tri_pdists = np.linalg.norm(feature[triu[1]] - feature[triu[0]], axis=1) + if squared: + upper_tri_pdists **= 2. + num_data = feature.shape[0] + pairwise_distances = np.zeros((num_data, num_data)) + pairwise_distances[np.triu_indices(num_data, 1)] = upper_tri_pdists + # Make symmetrical. + pairwise_distances = pairwise_distances + pairwise_distances.T - np.diag( + pairwise_distances.diagonal()) + return pairwise_distances + + +class ContrastiveLossTest(test.TestCase): + + def testContrastive(self): + with self.test_session(): + num_data = 10 + feat_dim = 6 + margin = 1.0 + + embeddings_anchor = np.random.rand(num_data, feat_dim).astype(np.float32) + embeddings_positive = np.random.rand(num_data, feat_dim).astype( + np.float32) + labels = np.random.randint(0, 2, size=(num_data,)).astype(np.float32) + + # Compute the loss in NP + dist = np.sqrt( + np.sum(np.square(embeddings_anchor - embeddings_positive), axis=1)) + loss_np = np.mean( + labels * np.square(dist) + + (1.0 - labels) * np.square(np.maximum(margin - dist, 0.0))) + # Compute the loss with TF + loss_tf = metric_loss_ops.contrastive_loss( + labels=ops.convert_to_tensor(labels), + embeddings_anchor=ops.convert_to_tensor(embeddings_anchor), + embeddings_positive=ops.convert_to_tensor(embeddings_positive), + margin=margin) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + + +class TripletSemiHardLossTest(test.TestCase): + + def testTripletSemiHard(self): + with self.test_session(): + num_data = 10 + feat_dim = 6 + margin = 1.0 + num_classes = 4 + + embedding = np.random.rand(num_data, feat_dim).astype(np.float32) + labels = np.random.randint( + 0, num_classes, size=(num_data)).astype(np.float32) + + # Reshape labels to compute adjacency matrix. + labels_reshaped = np.reshape(labels, (labels.shape[0], 1)) + # Compute the loss in NP. + adjacency = np.equal(labels_reshaped, labels_reshaped.T) + + pdist_matrix = pairwise_distance_np(embedding, squared=True) + loss_np = 0.0 + num_positives = 0.0 + for i in range(num_data): + for j in range(num_data): + if adjacency[i][j] > 0.0 and i != j: + num_positives += 1.0 + + pos_distance = pdist_matrix[i][j] + neg_distances = [] + + for k in range(num_data): + if adjacency[i][k] == 0: + neg_distances.append(pdist_matrix[i][k]) + + # Sort by distance. + neg_distances.sort() + chosen_neg_distance = neg_distances[0] + + for l in range(len(neg_distances)): + chosen_neg_distance = neg_distances[l] + if chosen_neg_distance > pos_distance: + break + + loss_np += np.maximum( + 0.0, margin - chosen_neg_distance + pos_distance) + + loss_np /= num_positives + + # Compute the loss in TF. + loss_tf = metric_loss_ops.triplet_semihard_loss( + labels=ops.convert_to_tensor(labels), + embeddings=ops.convert_to_tensor(embedding), + margin=margin) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + + +class LiftedStructLossTest(test.TestCase): + + def testLiftedStruct(self): + with self.test_session(): + num_data = 10 + feat_dim = 6 + margin = 1.0 + num_classes = 4 + + embedding = np.random.rand(num_data, feat_dim).astype(np.float32) + labels = np.random.randint( + 0, num_classes, size=(num_data)).astype(np.float32) + # Reshape labels to compute adjacency matrix. + labels_reshaped = np.reshape(labels, (labels.shape[0], 1)) + + # Compute the loss in NP + adjacency = np.equal(labels_reshaped, labels_reshaped.T) + pdist_matrix = pairwise_distance_np(embedding) + loss_np = 0.0 + num_constraints = 0.0 + for i in range(num_data): + for j in range(num_data): + if adjacency[i][j] > 0.0 and i != j: + d_pos = pdist_matrix[i][j] + negs = [] + for k in range(num_data): + if not adjacency[i][k]: + negs.append(margin - pdist_matrix[i][k]) + for l in range(num_data): + if not adjacency[j][l]: + negs.append(margin - pdist_matrix[j][l]) + + negs = np.array(negs) + max_elem = np.max(negs) + negs -= max_elem + negs = np.exp(negs) + soft_maximum = np.log(np.sum(negs)) + max_elem + + num_constraints += 1.0 + this_loss = max(soft_maximum + d_pos, 0) + loss_np += this_loss * this_loss + + loss_np = loss_np / num_constraints / 2.0 + + # Compute the loss in TF + loss_tf = metric_loss_ops.lifted_struct_loss( + labels=ops.convert_to_tensor(labels), + embeddings=ops.convert_to_tensor(embedding), + margin=margin) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + + +def convert_to_list_of_sparse_tensor(np_matrix): + list_of_sparse_tensors = [] + nrows, ncols = np_matrix.shape + for i in range(nrows): + sp_indices = [] + for j in range(ncols): + if np_matrix[i][j] == 1: + sp_indices.append([j]) + + num_non_zeros = len(sp_indices) + list_of_sparse_tensors.append(sparse_tensor.SparseTensor( + indices=np.array(sp_indices), + values=np.ones((num_non_zeros,)), + dense_shape=np.array([ncols,]))) + + return list_of_sparse_tensors + + +class NpairsLossTest(test.TestCase): + + def testNpairs(self): + with self.test_session(): + num_data = 15 + feat_dim = 6 + num_classes = 5 + reg_lambda = 0.02 + + embeddings_anchor = np.random.rand(num_data, feat_dim).astype(np.float32) + embeddings_positive = np.random.rand(num_data, feat_dim).astype( + np.float32) + + labels = np.random.randint( + 0, num_classes, size=(num_data)).astype(np.float32) + # Reshape labels to compute adjacency matrix. + labels_reshaped = np.reshape(labels, (labels.shape[0], 1)) + + # Compute the loss in NP + reg_term = np.mean(np.sum(np.square(embeddings_anchor), 1)) + reg_term += np.mean(np.sum(np.square(embeddings_positive), 1)) + reg_term *= 0.25 * reg_lambda + + similarity_matrix = np.matmul(embeddings_anchor, embeddings_positive.T) + + labels_remapped = np.equal( + labels_reshaped, labels_reshaped.T).astype(np.float32) + labels_remapped /= np.sum(labels_remapped, axis=1, keepdims=True) + + xent_loss = math_ops.reduce_mean(nn.softmax_cross_entropy_with_logits( + logits=ops.convert_to_tensor(similarity_matrix), + labels=ops.convert_to_tensor(labels_remapped))).eval() + loss_np = xent_loss + reg_term + + # Compute the loss in TF + loss_tf = metric_loss_ops.npairs_loss( + labels=ops.convert_to_tensor(labels), + embeddings_anchor=ops.convert_to_tensor(embeddings_anchor), + embeddings_positive=ops.convert_to_tensor(embeddings_positive), + reg_lambda=reg_lambda) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + + +class NpairsLossMultiLabelTest(test.TestCase): + + def testNpairsMultiLabelLossWithSingleLabelEqualsNpairsLoss(self): + with self.test_session(): + num_data = 15 + feat_dim = 6 + reg_lambda = 0.02 + + embeddings_anchor = np.random.rand(num_data, feat_dim).astype(np.float32) + embeddings_positive = np.random.rand(num_data, feat_dim).astype( + np.float32) + labels = np.arange(num_data) + labels = np.reshape(labels, -1) + + # Compute vanila npairs loss. + loss_npairs = metric_loss_ops.npairs_loss( + labels=ops.convert_to_tensor(labels), + embeddings_anchor=ops.convert_to_tensor(embeddings_anchor), + embeddings_positive=ops.convert_to_tensor(embeddings_positive), + reg_lambda=reg_lambda).eval() + + # Compute npairs multilabel loss. + labels_one_hot = np.identity(num_data) + loss_npairs_multilabel = metric_loss_ops.npairs_loss_multilabel( + sparse_labels=convert_to_list_of_sparse_tensor(labels_one_hot), + embeddings_anchor=ops.convert_to_tensor(embeddings_anchor), + embeddings_positive=ops.convert_to_tensor(embeddings_positive), + reg_lambda=reg_lambda).eval() + + self.assertAllClose(loss_npairs, loss_npairs_multilabel) + + def testNpairsMultiLabel(self): + with self.test_session(): + num_data = 15 + feat_dim = 6 + num_classes = 10 + reg_lambda = 0.02 + + embeddings_anchor = np.random.rand(num_data, feat_dim).astype(np.float32) + embeddings_positive = np.random.rand(num_data, feat_dim).astype( + np.float32) + + labels = np.random.randint(0, 2, (num_data, num_classes)) + # set entire column to one so that each row has at least one bit set. + labels[:, -1] = 1 + + # Compute the loss in NP + reg_term = np.mean(np.sum(np.square(embeddings_anchor), 1)) + reg_term += np.mean(np.sum(np.square(embeddings_positive), 1)) + reg_term *= 0.25 * reg_lambda + + similarity_matrix = np.matmul(embeddings_anchor, embeddings_positive.T) + + labels_remapped = np.dot(labels, labels.T).astype(np.float) + labels_remapped /= np.sum(labels_remapped, 1, keepdims=True) + + xent_loss = math_ops.reduce_mean(nn.softmax_cross_entropy_with_logits( + logits=ops.convert_to_tensor(similarity_matrix), + labels=ops.convert_to_tensor(labels_remapped))).eval() + loss_np = xent_loss + reg_term + + # Compute the loss in TF + loss_tf = metric_loss_ops.npairs_loss_multilabel( + sparse_labels=convert_to_list_of_sparse_tensor(labels), + embeddings_anchor=ops.convert_to_tensor(embeddings_anchor), + embeddings_positive=ops.convert_to_tensor(embeddings_positive), + reg_lambda=reg_lambda) + loss_tf = loss_tf.eval() + + self.assertAllClose(loss_np, loss_tf) + + +def compute_ground_truth_cluster_score(feat, y): + y_unique = np.unique(y) + score_gt_np = 0.0 + for c in y_unique: + feat_subset = feat[y == c, :] + pdist_subset = pairwise_distance_np(feat_subset) + score_gt_np += -1.0 * np.min(np.sum(pdist_subset, axis=0)) + score_gt_np = score_gt_np.astype(np.float32) + return score_gt_np + + +def compute_cluster_loss_numpy(feat, + y, + margin_multiplier=1.0, + enable_pam_finetuning=True): + if enable_pam_finetuning: + facility = ForwardGreedyFacility( + n_clusters=np.unique(y).size).pam_augmented_fit(feat, y, + margin_multiplier) + else: + facility = ForwardGreedyFacility( + n_clusters=np.unique(y).size).loss_augmented_fit(feat, y, + margin_multiplier) + + score_augmented = facility.score_aug_ + score_gt = compute_ground_truth_cluster_score(feat, y) + return np.maximum(np.float32(0.0), score_augmented - score_gt) + + +class ForwardGreedyFacility(object): + + def __init__(self, n_clusters=8): + self.n_clusters = n_clusters + self.center_ics_ = None + + def _check_init_args(self): + # Check n_clusters. + if (self.n_clusters is None or self.n_clusters <= 0 or + not isinstance(self.n_clusters, int)): + raise ValueError('n_clusters has to be nonnegative integer.') + + def loss_augmented_fit(self, feat, y, loss_mult): + """Fit K-Medoids to the provided data.""" + self._check_init_args() + # Check that the array is good and attempt to convert it to + # Numpy array if possible. + feat = self._check_array(feat) + # Apply distance metric to get the distance matrix. + pdists = pairwise_distance_np(feat) + + num_data = feat.shape[0] + candidate_ids = list(range(num_data)) + candidate_scores = np.zeros(num_data,) + subset = [] + + k = 0 + while k < self.n_clusters: + candidate_scores = [] + for i in candidate_ids: + # push i to subset. + subset.append(i) + marginal_cost = -1.0 * np.sum(np.min(pdists[:, subset], axis=1)) + loss = 1.0 - metrics.normalized_mutual_info_score( + y, self._get_cluster_ics(pdists, subset)) + candidate_scores.append(marginal_cost + loss_mult * loss) + # remove i from subset. + subset.pop() + + # push i_star to subset. + i_star = candidate_ids[np.argmax(candidate_scores)] + subset.append(i_star) + # remove i_star from candidate indices. + candidate_ids.remove(i_star) + k += 1 + + # Expose labels_ which are the assignments of + # the training data to clusters. + self.labels_ = self._get_cluster_ics(pdists, subset) + # Expose cluster centers, i.e. medoids. + self.cluster_centers_ = feat.take(subset, axis=0) + # Expose indices of chosen cluster centers. + self.center_ics_ = subset + # Expose the score = -\sum_{i \in V} min_{j \in S} || x_i - x_j || + self.score_ = np.float32(-1.0) * self._get_facility_distance(pdists, subset) + self.score_aug_ = self.score_ + loss_mult * ( + 1.0 - metrics.normalized_mutual_info_score( + y, self._get_cluster_ics(pdists, subset))) + self.score_aug_ = self.score_aug_.astype(np.float32) + # Expose the chosen cluster indices. + self.subset_ = subset + return self + + def _augmented_update_medoid_ics_in_place(self, pdists, y_gt, cluster_ics, + medoid_ics, loss_mult): + for cluster_idx in range(self.n_clusters): + # y_pred = self._get_cluster_ics(D, medoid_ics) + # Don't prematurely do the assignment step. + # Do this after we've updated all cluster medoids. + y_pred = cluster_ics + + if sum(y_pred == cluster_idx) == 0: + # Cluster is empty. + continue + + curr_score = ( + -1.0 * np.sum( + pdists[medoid_ics[cluster_idx], y_pred == cluster_idx]) + + loss_mult * (1.0 - metrics.normalized_mutual_info_score( + y_gt, y_pred))) + + pdist_in = pdists[y_pred == cluster_idx, :] + pdist_in = pdist_in[:, y_pred == cluster_idx] + + all_scores_fac = np.sum(-1.0 * pdist_in, axis=1) + all_scores_loss = [] + for i in range(y_pred.size): + if y_pred[i] != cluster_idx: + continue + # remove this cluster's current centroid + medoid_ics_i = medoid_ics[:cluster_idx] + medoid_ics[cluster_idx + 1:] + # add this new candidate to the centroid list + medoid_ics_i += [i] + y_pred_i = self._get_cluster_ics(pdists, medoid_ics_i) + all_scores_loss.append(loss_mult * ( + 1.0 - metrics.normalized_mutual_info_score(y_gt, y_pred_i))) + + all_scores = all_scores_fac + all_scores_loss + max_score_idx = np.argmax(all_scores) + max_score = all_scores[max_score_idx] + + if max_score > curr_score: + medoid_ics[cluster_idx] = np.where( + y_pred == cluster_idx)[0][max_score_idx] + + def pam_augmented_fit(self, feat, y, loss_mult): + pam_max_iter = 5 + self._check_init_args() + feat = self._check_array(feat) + pdists = pairwise_distance_np(feat) + self.loss_augmented_fit(feat, y, loss_mult) + print('PAM -1 (before PAM): score: %f, score_aug: %f' % ( + self.score_, self.score_aug_)) + # Initialize from loss augmented facility location + subset = self.center_ics_ + for iter_ in range(pam_max_iter): + # update the cluster assignment + cluster_ics = self._get_cluster_ics(pdists, subset) + # update the medoid for each clusters + self._augmented_update_medoid_ics_in_place(pdists, y, cluster_ics, subset, + loss_mult) + self.score_ = np.float32(-1.0) * self._get_facility_distance( + pdists, subset) + self.score_aug_ = self.score_ + loss_mult * ( + 1.0 - metrics.normalized_mutual_info_score( + y, self._get_cluster_ics(pdists, subset))) + self.score_aug_ = self.score_aug_.astype(np.float32) + print('PAM iter: %d, score: %f, score_aug: %f' % (iter_, self.score_, + self.score_aug_)) + + self.center_ics_ = subset + self.labels_ = cluster_ics + return self + + def _check_array(self, feat): + # Check that the number of clusters is less than or equal to + # the number of samples + if self.n_clusters > feat.shape[0]: + raise ValueError('The number of medoids ' + '({}) '.format( + self.n_clusters) + 'must be larger than the number ' + + 'of samples ({})'.format(feat.shape[0])) + return feat + + def _get_cluster_ics(self, pdists, subset): + """Returns cluster indices for pdist and current medoid indices.""" + # Assign data points to clusters based on + # which cluster assignment yields + # the smallest distance` + cluster_ics = np.argmin(pdists[subset, :], axis=0) + return cluster_ics + + def _get_facility_distance(self, pdists, subset): + return np.sum(np.min(pdists[subset, :], axis=0)) + + +class ClusterLossTest(test.TestCase): + + def _genClusters(self, n_samples, n_clusters): + blobs = datasets.make_blobs( + n_samples=n_samples, centers=n_clusters) + embedding, labels = blobs + embedding = (embedding - embedding.mean(axis=0)) / embedding.std(axis=0) + embedding = embedding.astype(np.float32) + return embedding, labels + + def testClusteringLossPAMOff(self): + if not HAS_SKLEARN: + return + with self.test_session(): + margin_multiplier = 10.0 + embeddings, labels = self._genClusters(n_samples=128, n_clusters=64) + + loss_np = compute_cluster_loss_numpy( + embeddings, labels, margin_multiplier, enable_pam_finetuning=False) + loss_tf = metric_loss_ops.cluster_loss( + labels=ops.convert_to_tensor(labels), + embeddings=ops.convert_to_tensor(embeddings), + margin_multiplier=margin_multiplier, + enable_pam_finetuning=False) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + + def testClusteringLossPAMOn(self): + if not HAS_SKLEARN: + return + with self.test_session(): + margin_multiplier = 10.0 + embeddings, labels = self._genClusters(n_samples=128, n_clusters=64) + + loss_np = compute_cluster_loss_numpy( + embeddings, labels, margin_multiplier, enable_pam_finetuning=True) + loss_tf = metric_loss_ops.cluster_loss( + labels=ops.convert_to_tensor(labels), + embeddings=ops.convert_to_tensor(embeddings), + margin_multiplier=margin_multiplier, + enable_pam_finetuning=True) + loss_tf = loss_tf.eval() + self.assertAllClose(loss_np, loss_tf) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/makefile/BUILD b/tensorflow/contrib/makefile/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..a8dd59f32a7f3b27993a7ee48ee7cc07ada59a4c --- /dev/null +++ b/tensorflow/contrib/makefile/BUILD @@ -0,0 +1,31 @@ +# Necessary build rules for makefile build in our CI. + +licenses(["notice"]) # Apache 2.0 + +package(default_visibility = ["//visibility:private"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = ["**/OWNERS"], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +sh_test( + name = "build_all_linux", + size = "enormous", + srcs = ["build_all_linux.sh"], + data = [ + "//tensorflow:all_opensource_files", + "//third_party/eigen3:all_files", + "//third_party/fft2d:all_files", + ], + tags = [ + "manual", + "no_gpu", + "no_oss", + "notap", + ], +) diff --git a/tensorflow/contrib/makefile/Dockerfile b/tensorflow/contrib/makefile/Dockerfile index 341f22e692687fe24f4f4be596180ce0f8b16368..64d571a4edfffd82a82318b797ba1edf96f69027 100644 --- a/tensorflow/contrib/makefile/Dockerfile +++ b/tensorflow/contrib/makefile/Dockerfile @@ -1,6 +1,6 @@ FROM ubuntu:16.04 -MAINTAINER Gunhan Gulsoy +LABEL maintainer="Gunhan Gulsoy " # Install make build dependencies for TensorFlow. RUN apt-get update diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile index e0cfab0b26d8f106e83f6223d057c9ef5f395f4f..dba14646536b077020d861940cf7b1184c651b54 100644 --- a/tensorflow/contrib/makefile/Makefile +++ b/tensorflow/contrib/makefile/Makefile @@ -11,10 +11,15 @@ # the first for the host (the machine you're compiling on) and the second for # the target (the machine you want the program to run on). +SHELL := /bin/bash + # Host compilation settings # Find where we're running from, so we can store generated files here. -MAKEFILE_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) +ifeq ($(origin MAKEFILE_DIR), undefined) + MAKEFILE_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST)))) +endif + HAS_GEN_HOST_PROTOC := \ $(shell test -f $(MAKEFILE_DIR)/gen/protobuf-host/bin/protoc && echo "true" ||\ echo "false") @@ -41,6 +46,11 @@ ifdef HEXAGON_LIBS endif endif # HEXAGON_LIBS +# If ANDROID_TYPES is not set assume __ANDROID_TYPES_SLIM__ +ifeq ($(ANDROID_TYPES),) + ANDROID_TYPES := -D__ANDROID_TYPES_SLIM__ +endif + # Try to figure out the host system HOST_OS := ifeq ($(OS),Windows_NT) @@ -55,6 +65,8 @@ else endif endif +HOST_ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi) + # Where compiled objects are stored. HOST_OBJDIR := $(MAKEFILE_DIR)/gen/host_obj/ HOST_BINDIR := $(MAKEFILE_DIR)/gen/host_bin/ @@ -71,6 +83,7 @@ HOST_LDOPTS += -L/usr/local/lib HOST_INCLUDES := \ -I. \ +-I$(MAKEFILE_DIR)/../../../ \ -I$(MAKEFILE_DIR)/downloads/ \ -I$(MAKEFILE_DIR)/downloads/eigen \ -I$(MAKEFILE_DIR)/downloads/gemmlowp \ @@ -190,6 +203,10 @@ LIBFLAGS := # If we're on OS X, make sure that globals aren't stripped out. ifeq ($(TARGET),OSX) +ifeq ($(HAS_GEN_HOST_PROTOC),true) + LIBFLAGS += -L$(MAKEFILE_DIR)/gen/protobuf-host/lib + export LD_LIBRARY_PATH=$(MAKEFILE_DIR)/gen/protobuf-host/lib +endif LDFLAGS += -all_load endif # Make sure that we don't strip global constructors on Linux. @@ -208,7 +225,7 @@ ifeq ($(TARGET),LINUX) endif # If we're cross-compiling for the Raspberry Pi, use the right gcc. ifeq ($(TARGET),PI) - CXXFLAGS += -D__ANDROID_TYPES_SLIM__ -DRASPBERRY_PI + CXXFLAGS += $(ANDROID_TYPES) -DRASPBERRY_PI LDFLAGS := -Wl,--no-whole-archive LIBS += -ldl -lpthread LIBFLAGS += -Wl,--allow-multiple-definition -Wl,--whole-archive @@ -222,43 +239,93 @@ ifeq ($(TARGET),ANDROID) # NDK_ROOT=/path/to/your/ndk # You need to have an Android version of the protobuf libraries compiled to link # in. The compile_android_protobuf.sh script may help. -# TODO(satok): Support all CPU architectures (Currently only armv7 is supported) - OS_PATH := + ANDROID_HOST_OS_ARCH := ifeq ($(HOST_OS),LINUX) - OS_PATH=linux + ANDROID_HOST_OS_ARCH=linux endif ifeq ($(HOST_OS),OSX) - OS_PATH=darwin + ANDROID_HOST_OS_ARCH=darwin endif ifeq ($(HOST_OS),WINDOWS) $(error "windows is not supported.") endif + ifeq ($(HOST_ARCH),x86_32) + ANDROID_HOST_OS_ARCH := $(ANDROID_HOST_OS_ARCH)-x86 + else + ANDROID_HOST_OS_ARCH := $(ANDROID_HOST_OS_ARCH)-$(HOST_ARCH) + endif + + ifndef ANDROID_ARCH + ANDROID_ARCH := armeabi-v7a + endif + + ifeq ($(ANDROID_ARCH),arm64-v8a) + TOOLCHAIN := aarch64-linux-android-4.9 + SYSROOT_ARCH := arm64 + BIN_PREFIX := aarch64-linux-android + MARCH_OPTION := + endif + ifeq ($(ANDROID_ARCH),armeabi) + TOOLCHAIN := arm-linux-androideabi-4.9 + SYSROOT_ARCH := arm + BIN_PREFIX := arm-linux-androideabi + MARCH_OPTION := + endif + ifeq ($(ANDROID_ARCH),armeabi-v7a) + TOOLCHAIN := arm-linux-androideabi-4.9 + SYSROOT_ARCH := arm + BIN_PREFIX := arm-linux-androideabi + MARCH_OPTION := -march=armv7-a -mfloat-abi=softfp -mfpu=neon + endif + ifeq ($(ANDROID_ARCH),mips) + TOOLCHAIN := mipsel-linux-android-4.9 + SYSROOT_ARCH := mips + BIN_PREFIX := mipsel-linux-android + MARCH_OPTION := + endif + ifeq ($(ANDROID_ARCH),mips64) + TOOLCHAIN := mips64el-linux-android-4.9 + SYSROOT_ARCH := mips64 + BIN_PREFIX := mips64el-linux-android + MARCH_OPTION := + endif + ifeq ($(ANDROID_ARCH),x86) + TOOLCHAIN := x86-4.9 + SYSROOT_ARCH := x86 + BIN_PREFIX := i686-linux-android + MARCH_OPTION := + endif + ifeq ($(ANDROID_ARCH),x86_64) + TOOLCHAIN := x86_64-4.9 + SYSROOT_ARCH := x86_64 + BIN_PREFIX := x86-64-linux-android + MARCH_OPTION := + endif + ifndef NDK_ROOT $(error "NDK_ROOT is not defined.") endif - CXX := $(CC_PREFIX) $(NDK_ROOT)/toolchains/arm-linux-androideabi-4.9/prebuilt/$(OS_PATH)-x86_64/bin/arm-linux-androideabi-g++ - CC := $(CC_PREFIX) $(NDK_ROOT)/toolchains/arm-linux-androideabi-4.9/prebuilt/$(OS_PATH)-x86_64/bin/arm-linux-androideabi-gcc + CXX := $(CC_PREFIX) $(NDK_ROOT)/toolchains/$(TOOLCHAIN)/prebuilt/$(ANDROID_HOST_OS_ARCH)/bin/$(BIN_PREFIX)-g++ + CC := $(CC_PREFIX) $(NDK_ROOT)/toolchains/$(TOOLCHAIN)/prebuilt/$(ANDROID_HOST_OS_ARCH)/bin/$(BIN_PREFIX)-gcc CXXFLAGS +=\ ---sysroot $(NDK_ROOT)/platforms/android-21/arch-arm \ +--sysroot $(NDK_ROOT)/platforms/android-21/arch-$(SYSROOT_ARCH) \ -Wno-narrowing \ -fomit-frame-pointer \ --march=armv7-a \ --mfloat-abi=softfp \ --mfpu=neon \ +$(MARCH_OPTION) \ -fPIE INCLUDES = \ -I$(NDK_ROOT)/sources/android/support/include \ -I$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/include \ --I$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/libs/armeabi/include \ +-I$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/libs/$(ANDROID_ARCH)/include \ -I. \ -I$(MAKEFILE_DIR)/downloads/ \ -I$(MAKEFILE_DIR)/downloads/eigen \ -I$(MAKEFILE_DIR)/downloads/gemmlowp \ -I$(MAKEFILE_DIR)/downloads/nsync/public \ -I$(MAKEFILE_DIR)/downloads/fft2d \ --I$(MAKEFILE_DIR)/gen/protobuf/include \ +-I$(MAKEFILE_DIR)/gen/protobuf_android/$(ANDROID_ARCH)/include \ -I$(PROTOGENDIR) \ -I$(PBTGENDIR) @@ -269,19 +336,20 @@ $(TARGET_NSYNC_LIB) \ -llog \ -lz \ -lm \ --ldl +-ldl \ +-latomic - LD := $(NDK_ROOT)/toolchains/arm-linux-androideabi-4.9/prebuilt/$(OS_PATH)-x86_64/arm-linux-androideabi/bin/ld + LD := $(NDK_ROOT)/toolchains/$(TOOLCHAIN)/prebuilt/$(ANDROID_HOST_OS_ARCH)/$(BIN_PREFIX)/bin/ld LDFLAGS := \ --march=armv7-a \ --L$(MAKEFILE_DIR)/gen/protobuf/lib \ --L$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/libs/armeabi-v7a \ +$(MARCH_OPTION) \ +-L$(MAKEFILE_DIR)/gen/protobuf_android/$(ANDROID_ARCH)/lib \ +-L$(NDK_ROOT)/sources/cxx-stl/gnu-libstdc++/4.9/libs/$(ANDROID_ARCH) \ -fPIE \ -pie \ -v - AR := $(NDK_ROOT)/toolchains/arm-linux-androideabi-4.9/prebuilt/$(OS_PATH)-x86_64/bin/arm-linux-androideabi-ar + AR := $(NDK_ROOT)/toolchains/$(TOOLCHAIN)/prebuilt/$(ANDROID_HOST_OS_ARCH)/bin/$(BIN_PREFIX)-ar ARFLAGS := r LIBFLAGS += -Wl,--allow-multiple-definition -Wl,--whole-archive @@ -305,6 +373,11 @@ $(TARGET_NSYNC_LIB) \ ifdef ENABLE_EXPERIMENTAL_HEXNN_OPS CXXFLAGS += -DENABLE_EXPERIMENTAL_HEXNN_OPS endif + + OBJDIR := $(OBJDIR)android_$(ANDROID_ARCH)/ + LIBDIR := $(LIBDIR)android_$(ANDROID_ARCH)/ + BINDIR := $(BINDIR)android_$(ANDROID_ARCH)/ + DEPDIR := $(DEPDIR)android_$(ANDROID_ARCH)/ endif # ANDROID # LINT.ThenChange(//tensorflow/contrib/android/cmake/CMakeLists.txt) @@ -330,7 +403,7 @@ ifeq ($(TARGET),IOS) -Wno-c++11-narrowing \ -mno-thumb \ -DTF_LEAN_BINARY \ - -D__ANDROID_TYPES_SLIM__ \ + $(ANDROID_TYPES) \ -fno-exceptions \ -isysroot \ ${IPHONEOS_SYSROOT} @@ -354,7 +427,7 @@ ifeq ($(TARGET),IOS) -Wno-c++11-narrowing \ -mno-thumb \ -DTF_LEAN_BINARY \ - -D__ANDROID_TYPES_SLIM__ \ + $(ANDROID_TYPES) \ -fno-exceptions \ -isysroot \ ${IPHONEOS_SYSROOT} @@ -377,7 +450,7 @@ ifeq ($(TARGET),IOS) -DUSE_GEMM_FOR_CONV \ -Wno-c++11-narrowing \ -DTF_LEAN_BINARY \ - -D__ANDROID_TYPES_SLIM__ \ + $(ANDROID_TYPES) \ -fno-exceptions \ -isysroot \ ${IPHONEOS_SYSROOT} @@ -401,7 +474,7 @@ ifeq ($(TARGET),IOS) -DUSE_GEMM_FOR_CONV \ -Wno-c++11-narrowing \ -DTF_LEAN_BINARY \ - -D__ANDROID_TYPES_SLIM__ \ + $(ANDROID_TYPES) \ -fno-exceptions \ -isysroot \ ${IPHONESIMULATOR_SYSROOT} @@ -424,7 +497,7 @@ ifeq ($(TARGET),IOS) -DUSE_GEMM_FOR_CONV \ -Wno-c++11-narrowing \ -DTF_LEAN_BINARY \ - -D__ANDROID_TYPES_SLIM__ \ + $(ANDROID_TYPES) \ -fno-exceptions \ -isysroot \ ${IPHONESIMULATOR_SYSROOT} @@ -484,6 +557,7 @@ $(wildcard tensorflow/core/*/*/*main.cc) \ $(wildcard tensorflow/core/debug/*.cc) \ $(wildcard tensorflow/core/framework/op_gen_lib.cc) \ $(wildcard tensorflow/core/graph/dot.*) \ +$(wildcard tensorflow/core/lib/db/*) \ $(wildcard tensorflow/core/lib/gif/*) \ $(wildcard tensorflow/core/lib/io/zlib*) \ $(wildcard tensorflow/core/lib/io/record*) \ @@ -501,6 +575,7 @@ $(wildcard tensorflow/core/platform/google/*) \ $(wildcard tensorflow/core/platform/google/*/*) \ $(wildcard tensorflow/core/platform/jpeg.*) \ $(wildcard tensorflow/core/platform/png.*) \ +$(wildcard tensorflow/core/platform/s3/*) \ $(wildcard tensorflow/core/platform/stream_executor.*) \ $(wildcard tensorflow/core/platform/windows/*) \ $(wildcard tensorflow/core/user_ops/*.cu.cc) \ @@ -645,12 +720,12 @@ clean: # Gets rid of all generated files except protobuf libs generated # before calling make. This allows users not to recompile proto libs everytime. clean_except_protobuf_libs: - find $(MAKEFILE_DIR)/gen -mindepth 1 -maxdepth 1 ! -name "protobuf" ! -name "protobuf-host" -exec rm -r "{}" \; + find $(MAKEFILE_DIR)/gen -mindepth 1 -maxdepth 1 ! -name "protobuf*" -exec rm -r "{}" \; rm -rf tensorflow/core/util/version_info.cc # Gets rid of target files only, leaving the host alone. Also leaves the lib # directory untouched deliberately, so we can persist multiple architectures -# across builds for iOS. +# across builds for iOS and Android. cleantarget: rm -rf $(OBJDIR) rm -rf $(BINDIR) diff --git a/tensorflow/contrib/makefile/build_all_android.sh b/tensorflow/contrib/makefile/build_all_android.sh index 9944f71950ac59ba147bf33c344c3478cdd175be..81cb17a311fd94aa397eb7a766cd8c668268759a 100755 --- a/tensorflow/contrib/makefile/build_all_android.sh +++ b/tensorflow/contrib/makefile/build_all_android.sh @@ -18,12 +18,15 @@ set -e usage() { - echo "Usage: NDK_ROOT= $(basename "$0") [-s:t:Tx:X]" + echo "Usage: NDK_ROOT= $(basename "$0") [-Es:t:Tx:a:X]" echo "-E enable experimental hexnn ops" echo "-s [sub_makefiles] sub makefiles separated by white space" echo "-t [build_target] build target for Android makefile [default=all]" echo "-T only build tensorflow" echo "-x [hexagon library path] copy and hexagon libraries in the specified path" + echo "-a [architecture] Architecture of target android [default=armeabi-v7a] \ +(supported architecture list: \ +arm64-v8a armeabi armeabi-v7a mips mips64 x86 x86_64)" exit 1 } @@ -32,13 +35,16 @@ if [[ -z "${NDK_ROOT}" ]]; then exit 1 fi -while getopts "Es:t:Tx:" opt_name; do +ARCH=armeabi-v7a + +while getopts "Es:t:Tx:a:" opt_name; do case "$opt_name" in E) ENABLE_EXPERIMENTAL_HEXNN_OPS="true";; s) SUB_MAKEFILES="${OPTARG}";; t) BUILD_TARGET="${OPTARG}";; T) ONLY_MAKE_TENSORFLOW="true";; x) HEXAGON_LIB_PATH="${OPTARG}";; + a) ARCH="${OPTARG}";; *) usage;; esac done @@ -53,25 +59,23 @@ JOB_COUNT="${JOB_COUNT:-$(get_job_count)}" HEXAGON_DOWNLOAD_PATH="tensorflow/contrib/makefile/downloads/hexagon" +# Remove any old files first. +make -f tensorflow/contrib/makefile/Makefile cleantarget + if [[ "${ONLY_MAKE_TENSORFLOW}" != "true" ]]; then - # Remove any old files first. - make -f tensorflow/contrib/makefile/Makefile clean rm -rf tensorflow/contrib/makefile/downloads # Pull down the required versions of the frameworks we need. tensorflow/contrib/makefile/download_dependencies.sh # Compile protobuf for the target Android device architectures. CC_PREFIX="${CC_PREFIX}" NDK_ROOT="${NDK_ROOT}" \ -tensorflow/contrib/makefile/compile_android_protobuf.sh -c -else - # Only clean files generated by make - make -f tensorflow/contrib/makefile/Makefile clean_except_protobuf_libs +tensorflow/contrib/makefile/compile_android_protobuf.sh -c -a ${ARCH} fi # Compile nsync for the host and the target Android device architecture. # Don't use export var=`something` syntax; it swallows the exit status. HOST_NSYNC_LIB=`tensorflow/contrib/makefile/compile_nsync.sh` TARGET_NSYNC_LIB=`CC_PREFIX="${CC_PREFIX}" NDK_ROOT="${NDK_ROOT}" \ - tensorflow/contrib/makefile/compile_nsync.sh -t android -a armeabi-v7a` + tensorflow/contrib/makefile/compile_nsync.sh -t android -a ${ARCH}` export HOST_NSYNC_LIB TARGET_NSYNC_LIB if [[ ! -z "${HEXAGON_LIB_PATH}" ]]; then @@ -98,7 +102,8 @@ fi if [[ -z "${BUILD_TARGET}" ]]; then make -j"${JOB_COUNT}" -f tensorflow/contrib/makefile/Makefile \ - TARGET=ANDROID NDK_ROOT="${NDK_ROOT}" CC_PREFIX="${CC_PREFIX}" \ + TARGET=ANDROID NDK_ROOT="${NDK_ROOT}" ANDROID_ARCH="${ARCH}" \ + CC_PREFIX="${CC_PREFIX}" \ HOST_NSYNC_LIB="$HOST_NSYNC_LIB" TARGET_NSYNC_LIB="$TARGET_NSYNC_LIB" \ HEXAGON_LIBS="${HEXAGON_LIBS}" HEXAGON_INCLUDE="${HEXAGON_INCLUDE}" \ SUB_MAKEFILES="${SUB_MAKEFILES}" ${EXTRA_MAKE_ARGS[@]} @@ -106,7 +111,8 @@ else # BUILD_TARGET explicitly uncommented to allow multiple targets to be # passed to make in a single build_all_android.sh invocation. make -j"${JOB_COUNT}" -f tensorflow/contrib/makefile/Makefile \ - TARGET=ANDROID NDK_ROOT="${NDK_ROOT}" CC_PREFIX="${CC_PREFIX}" \ + TARGET=ANDROID NDK_ROOT="${NDK_ROOT}" ANDROID_ARCH="${ARCH}" \ + CC_PREFIX="${CC_PREFIX}" \ HOST_NSYNC_LIB="$HOST_NSYNC_LIB" TARGET_NSYNC_LIB="$TARGET_NSYNC_LIB" \ HEXAGON_LIBS="${HEXAGON_LIBS}" HEXAGON_INCLUDE="${HEXAGON_INCLUDE}" \ SUB_MAKEFILES="${SUB_MAKEFILES}" ${EXTRA_MAKE_ARGS[@]} ${BUILD_TARGET} diff --git a/tensorflow/contrib/makefile/build_all_linux.sh b/tensorflow/contrib/makefile/build_all_linux.sh index 5d73f697f4ef0b2a566deb04397b0def5a442cfa..a440633cfc23a7c606586a3b53180aaed6fe27ad 100755 --- a/tensorflow/contrib/makefile/build_all_linux.sh +++ b/tensorflow/contrib/makefile/build_all_linux.sh @@ -44,4 +44,5 @@ tensorflow/contrib/makefile/compile_linux_protobuf.sh # Build TensorFlow. make -j"${JOB_COUNT}" -f tensorflow/contrib/makefile/Makefile \ OPTFLAGS="-O3 -march=native" \ - HOST_CXXFLAGS="--std=c++11 -march=native" + HOST_CXXFLAGS="--std=c++11 -march=native" \ + MAKEFILE_DIR=$SCRIPT_DIR diff --git a/tensorflow/contrib/makefile/compile_android_protobuf.sh b/tensorflow/contrib/makefile/compile_android_protobuf.sh index fadbe271b85e6812953c1a00345a5e7f92bf9dbe..4355e3e5974e7ec4626773feca808631f2dbf1a8 100755 --- a/tensorflow/contrib/makefile/compile_android_protobuf.sh +++ b/tensorflow/contrib/makefile/compile_android_protobuf.sh @@ -71,10 +71,10 @@ then exit 1 fi -GENDIR="$(pwd)/gen/protobuf" +GENDIR="$(pwd)/gen/protobuf_android" HOST_GENDIR="$(pwd)/gen/protobuf-host" mkdir -p "${GENDIR}" -mkdir -p "${HOST_GENDIR}" +mkdir -p "${GENDIR}/${ARCHITECTURE}" if [[ ! -f "./downloads/protobuf/autogen.sh" ]]; then echo "You need to download dependencies before running this script." 1>&2 @@ -153,7 +153,7 @@ then exit 1 fi -./configure --prefix="${GENDIR}" \ +./configure --prefix="${GENDIR}/${ARCHITECTURE}" \ --host="${bin_prefix}" \ --with-sysroot="${SYSROOT}" \ --disable-shared \ diff --git a/tensorflow/contrib/makefile/download_dependencies.sh b/tensorflow/contrib/makefile/download_dependencies.sh index 39c89628d96ad1d7d8a28ec76071d4aa31085225..a2b444d53aeb5738786483a451bcc529686a92fd 100755 --- a/tensorflow/contrib/makefile/download_dependencies.sh +++ b/tensorflow/contrib/makefile/download_dependencies.sh @@ -20,12 +20,13 @@ DOWNLOADS_DIR=tensorflow/contrib/makefile/downloads BZL_FILE_PATH=tensorflow/workspace.bzl EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" -GEMMLOWP_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" +GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)" GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz" -NSYNC_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -PROTOBUF_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" -RE2_URL="$(grep -o 'http://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +NSYNC_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/nsync/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +PROTOBUF_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/protobuf/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" +RE2_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/re2/.*tar\.gz' "${BZL_FILE_PATH}" | head -n1)" FFT2D_URL="$(grep -o 'http.*fft\.tgz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)" +ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)" # TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64, # so work around it by patching the source. @@ -54,7 +55,7 @@ download_and_extract() { elif [[ "${url}" == *zip ]]; then tempdir=$(mktemp -d) tempdir2=$(mktemp -d) - wget ${url} -P ${tempdir} + wget -P ${tempdir} ${url} unzip ${tempdir}/* -d ${tempdir2} # unzip has no strip components, so unzip to a temp dir, and move the files # we want from the tempdir to destination. @@ -73,6 +74,7 @@ download_and_extract "${NSYNC_URL}" "${DOWNLOADS_DIR}/nsync" download_and_extract "${PROTOBUF_URL}" "${DOWNLOADS_DIR}/protobuf" download_and_extract "${RE2_URL}" "${DOWNLOADS_DIR}/re2" download_and_extract "${FFT2D_URL}" "${DOWNLOADS_DIR}/fft2d" +download_and_extract "${ABSL_URL}" "${DOWNLOADS_DIR}/absl" replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \ "${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h" diff --git a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt index 5ade8942af39f1d308c5f6e308e1cee754510926..938c4a53ab3fff72b028276eac5aad76ff01880d 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_cc_files.txt @@ -24,6 +24,7 @@ tensorflow/core/framework/summary.pb.cc tensorflow/core/framework/step_stats.pb.cc tensorflow/core/framework/resource_handle.pb.cc tensorflow/core/framework/remote_fused_graph_execute_info.pb.cc +tensorflow/core/framework/api_def.pb.cc tensorflow/core/framework/op_def.pb.cc tensorflow/core/framework/node_def.pb.cc tensorflow/core/framework/log_memory.pb.cc diff --git a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt index 1f0ad06cdc5b98ae9c08ea63dad70eb02b6ef46b..aa91b2f954504c42d33838c728abd666ef100e14 100644 --- a/tensorflow/contrib/makefile/proto_text_pb_h_files.txt +++ b/tensorflow/contrib/makefile/proto_text_pb_h_files.txt @@ -25,6 +25,7 @@ tensorflow/core/framework/summary.pb.h tensorflow/core/framework/step_stats.pb.h tensorflow/core/framework/resource_handle.pb.h tensorflow/core/framework/remote_fused_graph_execute_info.pb.h +tensorflow/core/framework/api_def.pb.h tensorflow/core/framework/op_def.pb.h tensorflow/core/framework/node_def.pb.h tensorflow/core/framework/log_memory.pb.h diff --git a/tensorflow/contrib/makefile/tf_op_files.txt b/tensorflow/contrib/makefile/tf_op_files.txt index a7f2be9790d2b3e8fd7163f4a9f1b14a9519f2c6..8b77c99cb574123c2af5d8f9f17cd403613cfffd 100644 --- a/tensorflow/contrib/makefile/tf_op_files.txt +++ b/tensorflow/contrib/makefile/tf_op_files.txt @@ -1,4 +1,3 @@ -tensorflow/contrib/boosted_trees/ops/ensemble_optimizer_ops.cc tensorflow/contrib/boosted_trees/ops/model_ops.cc tensorflow/contrib/boosted_trees/ops/prediction_ops.cc tensorflow/contrib/boosted_trees/ops/quantile_ops.cc @@ -143,6 +142,7 @@ tensorflow/core/kernels/cwise_op_sqrt.cc tensorflow/core/kernels/cwise_op_sigmoid.cc tensorflow/core/kernels/cwise_op_sign.cc tensorflow/core/kernels/cwise_op_select.cc +tensorflow/core/kernels/cwise_op_round.cc tensorflow/core/kernels/cwise_op_rsqrt.cc tensorflow/core/kernels/cwise_op_reciprocal.cc tensorflow/core/kernels/cwise_op_neg.cc @@ -161,6 +161,7 @@ tensorflow/core/kernels/cwise_op_invert.cc tensorflow/core/kernels/cwise_op_greater_equal.cc tensorflow/core/kernels/cwise_op_greater.cc tensorflow/core/kernels/cwise_op_floor_div.cc +tensorflow/core/kernels/cwise_op_floor_mod.cc tensorflow/core/kernels/cwise_op_floor.cc tensorflow/core/kernels/cwise_op_exp.cc tensorflow/core/kernels/cwise_op_equal_to_2.cc @@ -169,6 +170,8 @@ tensorflow/core/kernels/cwise_op_div.cc tensorflow/core/kernels/cwise_op_bitwise_xor.cc tensorflow/core/kernels/cwise_op_bitwise_or.cc tensorflow/core/kernels/cwise_op_bitwise_and.cc +tensorflow/core/kernels/cwise_op_left_shift.cc +tensorflow/core/kernels/cwise_op_right_shift.cc tensorflow/core/kernels/cwise_op_add_2.cc tensorflow/core/kernels/cwise_op_add_1.cc tensorflow/core/kernels/cwise_op_abs.cc @@ -261,3 +264,4 @@ tensorflow/core/kernels/spacetobatch_functor.cc tensorflow/core/kernels/spacetobatch_op.cc tensorflow/core/kernels/batchtospace_op.cc tensorflow/core/kernels/warn_about_ints.cc +tensorflow/core/kernels/segment_reduction_ops.cc diff --git a/tensorflow/contrib/makefile/tf_pb_text_files.txt b/tensorflow/contrib/makefile/tf_pb_text_files.txt index c39257ffa91fef184e8bd5258b19c4323a1b7fe0..b5431df2eb016d010c51bdbb33fd747b3569ce83 100644 --- a/tensorflow/contrib/makefile/tf_pb_text_files.txt +++ b/tensorflow/contrib/makefile/tf_pb_text_files.txt @@ -17,6 +17,7 @@ tensorflow/core/framework/summary.pb_text.cc tensorflow/core/framework/step_stats.pb_text.cc tensorflow/core/framework/resource_handle.pb_text.cc tensorflow/core/framework/remote_fused_graph_execute_info.pb_text.cc +tensorflow/core/framework/api_def.pb_text.cc tensorflow/core/framework/op_def.pb_text.cc tensorflow/core/framework/node_def.pb_text.cc tensorflow/core/framework/log_memory.pb_text.cc diff --git a/tensorflow/contrib/makefile/tf_proto_files.txt b/tensorflow/contrib/makefile/tf_proto_files.txt index a1a9aa7190205d9f3c34ef01b65db85f89f2ac85..d569bde637b20e0ca55c48c616855332abd9fb13 100644 --- a/tensorflow/contrib/makefile/tf_proto_files.txt +++ b/tensorflow/contrib/makefile/tf_proto_files.txt @@ -30,6 +30,7 @@ tensorflow/core/framework/step_stats.proto tensorflow/core/framework/resource_handle.proto tensorflow/core/framework/remote_fused_graph_execute_info.proto tensorflow/core/framework/reader_base.proto +tensorflow/core/framework/api_def.proto tensorflow/core/framework/op_def.proto tensorflow/core/framework/node_def.proto tensorflow/core/framework/log_memory.proto diff --git a/tensorflow/contrib/memory_stats/BUILD b/tensorflow/contrib/memory_stats/BUILD index 8b9d30dcfd088902ded36c7513ffc419e6bf7c7a..72424c32e7b756e6c50965f38135869e03ba730f 100644 --- a/tensorflow/contrib/memory_stats/BUILD +++ b/tensorflow/contrib/memory_stats/BUILD @@ -63,6 +63,8 @@ tf_custom_op_py_library( deps = [ ":memory_stats_ops", "//tensorflow/contrib/util:util_py", + "//tensorflow/python:platform", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/memory_stats/__init__.py b/tensorflow/contrib/memory_stats/__init__.py index a2b2b65692917ac73b62219f71bd5be677234673..a32302c854b68ed1b211a221f3026e8d5b6091ac 100644 --- a/tensorflow/contrib/memory_stats/__init__.py +++ b/tensorflow/contrib/memory_stats/__init__.py @@ -14,10 +14,12 @@ # ============================================================================== """Ops for memory statistics. +@@BytesInUse @@BytesLimit @@MaxBytesInUse """ +from tensorflow.contrib.memory_stats.python.ops.memory_stats_ops import BytesInUse from tensorflow.contrib.memory_stats.python.ops.memory_stats_ops import BytesLimit from tensorflow.contrib.memory_stats.python.ops.memory_stats_ops import MaxBytesInUse diff --git a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc index 3b88535dce4931fabb3bfcd9e7a52f7ef09d3252..7e2e96e160167ae68d3bdabacbbbeb45df61778f 100644 --- a/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc +++ b/tensorflow/contrib/memory_stats/kernels/memory_stats_ops.cc @@ -40,6 +40,28 @@ class MemoryStatsOp : public OpKernel { const AllocatorStats& allocator_stats) const = 0; }; +// Op that measures current memory in bytes. +class BytesInUseOp : public MemoryStatsOp { + public: + explicit BytesInUseOp(OpKernelConstruction* context) + : MemoryStatsOp(context) {} + + private: + int64 ExtractAllocatorStats( + const AllocatorStats& allocator_stats) const override { + return allocator_stats.bytes_in_use; + } +}; + +// Register this op on GPU only, see comment for MaxBytesInUse for reason +REGISTER_KERNEL_BUILDER(Name("BytesInUse").Device(DEVICE_GPU).HostMemory("out"), + BytesInUseOp); + +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER( + Name("BytesInUse").Device(DEVICE_SYCL).HostMemory("out"), MaxBytesInUseOp); +#endif // TENSORFLOW_USE_SYCL + // Op that measures the total memory (in bytes) of a device. class BytesLimitOp : public MemoryStatsOp { public: diff --git a/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc b/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc index 08859c86135eeee395bfde5f814f5ae5c033c385..42020cf7f6b98ce883e5b0128ba5b08127d03b90 100644 --- a/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc +++ b/tensorflow/contrib/memory_stats/ops/memory_stats_ops.cc @@ -17,6 +17,10 @@ limitations under the License. namespace tensorflow { +REGISTER_OP("BytesInUse") + .Output("out: int64") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape); REGISTER_OP("BytesLimit") .Output("out: int64") .SetIsStateful() diff --git a/tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py b/tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py index ec25c032f0588e5aaa0192349288d45e503baecf..d1b430b8039fcf7e10bcb842c3f34b960b9026b3 100644 --- a/tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py +++ b/tensorflow/contrib/memory_stats/python/kernel_tests/memory_stats_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.memory_stats.python.ops import memory_stats_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import math_ops @@ -64,10 +65,29 @@ class MemoryStatsOpsTest(test_util.TensorFlowTestCase): d = math_ops.matmul(c, b) sess.run(d) - max_bytes_in_use = sess.run(memory_stats_ops.MaxBytesInUse()) + max_bytes_in_use_op = memory_stats_ops.MaxBytesInUse() + max_bytes_in_use = sess.run(max_bytes_in_use_op) self.assertGreaterEqual(max_bytes_in_use, matrix_size_in_bytes * 3) self.assertLess(max_bytes_in_use, matrix_size_in_bytes * 4) + # run chain with 2 ops, make sure BytesInUse captures intermediate + # memory usage + a = random_ops.random_uniform(matrix_shape, dtype=dtype) + with ops.control_dependencies([a]): + bytes_in_use_op = memory_stats_ops.BytesInUse() + with ops.control_dependencies([bytes_in_use_op]): + b = random_ops.random_uniform(matrix_shape, dtype=dtype) + + _, bytes_in_use, max_bytes_in_use = sess.run([a, bytes_in_use_op, + max_bytes_in_use_op]) + + # intermediate result allocates 1 matrix, max usage is at least 2 + self.assertGreaterEqual(bytes_in_use, matrix_size_in_bytes * 1) + self.assertLess(bytes_in_use, matrix_size_in_bytes * 2) + + # max usage is still 3 because it reflects maxium from previous .run call + self.assertGreaterEqual(max_bytes_in_use, matrix_size_in_bytes * 3) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/memory_stats/python/ops/memory_stats_ops.py b/tensorflow/contrib/memory_stats/python/ops/memory_stats_ops.py index d35c6583ed05ac3b50f79b01d4db98baa0b5442c..c0f7788c1c6ae86d7ea34d54f07a43d358abb951 100644 --- a/tensorflow/contrib/memory_stats/python/ops/memory_stats_ops.py +++ b/tensorflow/contrib/memory_stats/python/ops/memory_stats_ops.py @@ -26,6 +26,11 @@ _memory_stats_ops_so = loader.load_op_library( resource_loader.get_path_to_datafile("_memory_stats_ops.so")) +def BytesInUse(): + """Generates an op that computes the current memory of a device.""" + return gen_memory_stats_ops.bytes_in_use() + + def BytesLimit(): """Generates an op that measures the total memory (in bytes) of a device.""" return gen_memory_stats_ops.bytes_limit() diff --git a/tensorflow/contrib/meta_graph_transform/BUILD b/tensorflow/contrib/meta_graph_transform/BUILD index d47ac5bcfe002ca8aaf4b8130c7b7fd58d1faeb9..4b5b1c3e15d36b7602791856416ece54d24798b2 100644 --- a/tensorflow/contrib/meta_graph_transform/BUILD +++ b/tensorflow/contrib/meta_graph_transform/BUILD @@ -21,7 +21,12 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/core:protos_all_py", - "//tensorflow/python:ops", + "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", + "//tensorflow/python:graph_util", + "//tensorflow/python:session", + "//tensorflow/python:training", + "//tensorflow/python:util", "//tensorflow/python/saved_model:constants", "//tensorflow/tools/graph_transforms:transform_graph_py", ], diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py index ff4afbb4ce13c4b26154cd01af3a9f44a0695d58..2932ae1c8df32cd936cff932b061571c513fda79 100644 --- a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py +++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py @@ -706,7 +706,8 @@ def meta_graph_transform( output_names: Names of output nodes. transforms: A list of strings naming the graph transforms to be applied in order. These transform names are exactly those supported by the Graph - Transform Tool, with the addition of the 'freeze_graph' transform. + Transform Tool, with the addition of the 'freeze_graph' and + 'sparsify_gather' transforms. tags: A list of tags with which to annotate the transformed MetaGraphDef. checkpoint_path: A path to a checkpoint to restore during freezing, if needed (default None). @@ -748,7 +749,7 @@ def meta_graph_transform( base_meta_graph_def, meta_graph_def, collection_name, removed_op_names) - # Append newly added initalizers to collection. + # Append newly added initializers to collection. _add_new_inits_to_collection(meta_graph_def, updated_initializer_names) # Copy signature_defs, excluding any pruned nodes diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD index e11dff08f853139fa19dd1dc418c4d3ac965ce71..9de664c822bf7a9abf7b8082f444c61dfa45f499 100644 --- a/tensorflow/contrib/metrics/BUILD +++ b/tensorflow/contrib/metrics/BUILD @@ -42,6 +42,7 @@ py_library( "//tensorflow/python:state_ops", "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//tensorflow/python:weights_broadcast_ops", ], ) diff --git a/tensorflow/contrib/metrics/__init__.py b/tensorflow/contrib/metrics/__init__.py index 4c16fb50407c0d81665fb35d2265d078805475a6..302042c4dd6ad294238672b11ce51dd8e255d919 100644 --- a/tensorflow/contrib/metrics/__init__.py +++ b/tensorflow/contrib/metrics/__init__.py @@ -22,6 +22,10 @@ See the @{$python/contrib.metrics} guide. @@streaming_recall_at_thresholds @@streaming_precision @@streaming_precision_at_thresholds +@@streaming_false_positive_rate +@@streaming_false_positive_rate_at_thresholds +@@streaming_false_negative_rate +@@streaming_false_negative_rate_at_thresholds @@streaming_auc @@streaming_curve_points @@streaming_recall_at_k @@ -51,6 +55,7 @@ See the @{$python/contrib.metrics} guide. @@streaming_true_negatives_at_thresholds @@streaming_true_positives @@streaming_true_positives_at_thresholds +@@sparse_recall_at_top_k @@auc_using_histogram @@accuracy @@aggregate_metrics @@ -60,6 +65,8 @@ See the @{$python/contrib.metrics} guide. @@set_intersection @@set_size @@set_union +@@count +@@recall_at_precision """ from __future__ import absolute_import @@ -73,13 +80,20 @@ from tensorflow.contrib.metrics.python.ops.confusion_matrix_ops import confusion from tensorflow.contrib.metrics.python.ops.histogram_ops import auc_using_histogram from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metric_map from tensorflow.contrib.metrics.python.ops.metric_ops import aggregate_metrics +from tensorflow.contrib.metrics.python.ops.metric_ops import count +from tensorflow.contrib.metrics.python.ops.metric_ops import recall_at_precision +from tensorflow.contrib.metrics.python.ops.metric_ops import sparse_recall_at_top_k from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_accuracy from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_auc from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_concat from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_covariance from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_curve_points +from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate +from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negative_rate_at_thresholds from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_negatives_at_thresholds +from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positive_rate +from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positive_rate_at_thresholds from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_false_positives_at_thresholds from tensorflow.contrib.metrics.python.ops.metric_ops import streaming_mean diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops.py b/tensorflow/contrib/metrics/python/ops/metric_ops.py index 76986d0156dada75abcc559d9db6b9addf26cccc..33377a70c2506261b497c1b0fe8ab5ba0c680c7e 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops.py @@ -22,11 +22,12 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections as collections_lib + from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops -from tensorflow.python.ops import confusion_matrix from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import metrics @@ -37,6 +38,9 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import weights_broadcast_ops from tensorflow.python.util.deprecation import deprecated +# Epsilon constant used to represent extremely small quantity. +_EPSILON = 1e-7 + def _safe_div(numerator, denominator, name): """Divides two values, returning 0 if the denominator is <= 0. @@ -56,38 +60,15 @@ def _safe_div(numerator, denominator, name): name=name) -def _create_local(name, shape, collections=None, validate_shape=True, - dtype=dtypes.float32): - """Creates a new local variable. - - Args: - name: The name of the new or existing variable. - shape: Shape of the new or existing variable. - collections: A list of collection names to which the Variable will be added. - validate_shape: Whether to validate the shape of the variable. - dtype: Data type of the variables. - - Returns: - The created variable. - """ - # Make sure local variables are added to tf.GraphKeys.LOCAL_VARIABLES - collections = list(collections or []) - collections += [ops.GraphKeys.LOCAL_VARIABLES] - return variable_scope.variable( - initial_value=array_ops.zeros(shape, dtype=dtype), - name=name, - trainable=False, - collections=collections, - validate_shape=validate_shape) - - # TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py. def _assert_weights_rank(weights, values): """`weights` rank must be either `0`, or the same as 'values'.""" return check_ops.assert_rank_in(weights, (0, array_ops.rank(values))) -def _count_condition(values, weights=None, metrics_collections=None, +def _count_condition(values, + weights=None, + metrics_collections=None, updates_collections=None): """Sums the weights of cases where the given values are True. @@ -114,7 +95,7 @@ def _count_condition(values, weights=None, metrics_collections=None, or tuple. """ check_ops.assert_type(values, dtypes.bool) - count = _create_local('count', shape=[]) + count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') values = math_ops.to_float(values) if weights is not None: @@ -122,8 +103,8 @@ def _count_condition(values, weights=None, metrics_collections=None, with ops.control_dependencies((_assert_weights_rank(weights, values),)): values = math_ops.multiply(values, weights) - value_tensor = array_ops.identity(count) - update_op = state_ops.assign_add(count, math_ops.reduce_sum(values)) + value_tensor = array_ops.identity(count_) + update_op = state_ops.assign_add(count_, math_ops.reduce_sum(values)) if metrics_collections: ops.add_to_collections(metrics_collections, value_tensor) @@ -134,7 +115,9 @@ def _count_condition(values, weights=None, metrics_collections=None, return value_tensor, update_op -def streaming_true_positives(predictions, labels, weights=None, +def streaming_true_positives(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -168,12 +151,17 @@ def streaming_true_positives(predictions, labels, weights=None, tuple. """ return metrics.true_positives( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_true_negatives(predictions, labels, weights=None, +def streaming_true_negatives(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -206,20 +194,22 @@ def streaming_true_negatives(predictions, labels, weights=None, either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope( - name, 'true_negatives', (predictions, labels, weights)): + with variable_scope.variable_scope(name, 'true_negatives', + (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions=math_ops.cast(predictions, dtype=dtypes.bool), labels=math_ops.cast(labels, dtype=dtypes.bool), weights=weights) - is_true_negative = math_ops.logical_and(math_ops.equal(labels, False), - math_ops.equal(predictions, False)) + is_true_negative = math_ops.logical_and( + math_ops.equal(labels, False), math_ops.equal(predictions, False)) return _count_condition(is_true_negative, weights, metrics_collections, updates_collections) -def streaming_false_positives(predictions, labels, weights=None, +def streaming_false_positives(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -253,12 +243,17 @@ def streaming_false_positives(predictions, labels, weights=None, tuple. """ return metrics.false_positives( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_false_negatives(predictions, labels, weights=None, +def streaming_false_negatives(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -291,9 +286,12 @@ def streaming_false_negatives(predictions, labels, weights=None, or tuple. """ return metrics.false_negatives( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) # TODO(ptucker): Move this somewhere common, to share with ops/losses/losses.py. @@ -317,17 +315,18 @@ def _broadcast_weights(weights, values): with ops.name_scope(None, 'broadcast_weights', (values, weights)) as scope: weights_shape = weights.get_shape() values_shape = values.get_shape() - if (weights_shape.is_fully_defined() and - values_shape.is_fully_defined() and + if (weights_shape.is_fully_defined() and values_shape.is_fully_defined() and weights_shape.is_compatible_with(values_shape)): return weights with ops.control_dependencies((_assert_weights_rank(weights, values),)): - return math_ops.multiply( - weights, array_ops.ones_like(values), name=scope) + return math_ops.multiply(weights, array_ops.ones_like(values), name=scope) -def streaming_mean(values, weights=None, metrics_collections=None, - updates_collections=None, name=None): +def streaming_mean(values, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the (weighted) mean of the given values. The `streaming_mean` function creates two local variables, `total` and `count` @@ -365,12 +364,18 @@ def streaming_mean(values, weights=None, metrics_collections=None, or tuple. """ return metrics.mean( - values=values, weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + values=values, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) -def streaming_mean_tensor(values, weights=None, metrics_collections=None, - updates_collections=None, name=None): +def streaming_mean_tensor(values, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the element-wise (weighted) mean of the given tensors. In contrast to the `streaming_mean` function which returns a scalar with the @@ -412,12 +417,19 @@ def streaming_mean_tensor(values, weights=None, metrics_collections=None, or tuple. """ return metrics.mean_tensor( - values=values, weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) - + values=values, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) -def streaming_accuracy(predictions, labels, weights=None, - metrics_collections=None, updates_collections=None, +@deprecated(None, "Please switch to tf.metrics.accuracy. Note that the order " + "of the inputs of labels and predictions have been switched.") +def streaming_accuracy(predictions, + labels, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Calculates how often `predictions` matches `labels`. @@ -462,13 +474,19 @@ def streaming_accuracy(predictions, labels, weights=None, tuple. """ return metrics.accuracy( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_precision(predictions, labels, weights=None, - metrics_collections=None, updates_collections=None, +def streaming_precision(predictions, + labels, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Computes the precision of the predictions with respect to the labels. @@ -512,13 +530,19 @@ def streaming_precision(predictions, labels, weights=None, tuple. """ return metrics.precision( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_recall(predictions, labels, weights=None, - metrics_collections=None, updates_collections=None, +def streaming_recall(predictions, + labels, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Computes the recall of the predictions with respect to the labels. @@ -560,13 +584,242 @@ def streaming_recall(predictions, labels, weights=None, tuple. """ return metrics.recall( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) + + +def _true_negatives(labels, + predictions, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Sum the weights of true negatives. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + labels: The ground truth values, a `Tensor` whose dimensions must match + `predictions`. Will be cast to `bool`. + predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will + be cast to `bool`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that the metric + value variable should be added to. + updates_collections: An optional list of collections that the metric update + ops should be added to. + name: An optional variable_scope name. + + Returns: + value_tensor: A `Tensor` representing the current value of the metric. + update_op: An operation that accumulates the error from a batch of data. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope(name, 'true_negatives', + (predictions, labels, weights)): + + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=math_ops.cast(predictions, dtype=dtypes.bool), + labels=math_ops.cast(labels, dtype=dtypes.bool), + weights=weights) + is_true_negative = math_ops.logical_and( + math_ops.equal(labels, False), math_ops.equal(predictions, False)) + return _count_condition(is_true_negative, weights, metrics_collections, + updates_collections) + + +def streaming_false_positive_rate(predictions, + labels, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes the false positive rate of predictions with respect to labels. + + The `false_positive_rate` function creates two local variables, + `false_positives` and `true_negatives`, that are used to compute the + false positive rate. This value is ultimately returned as + `false_positive_rate`, an idempotent operation that simply divides + `false_positives` by the sum of `false_positives` and `true_negatives`. + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the + `false_positive_rate`. `update_op` weights each prediction by the + corresponding value in `weights`. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will + be cast to `bool`. + labels: The ground truth values, a `Tensor` whose dimensions must match + `predictions`. Will be cast to `bool`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that + `false_positive_rate` should be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + name: An optional variable_scope name. + + Returns: + false_positive_rate: Scalar float `Tensor` with the value of + `false_positives` divided by the sum of `false_positives` and + `true_negatives`. + update_op: `Operation` that increments `false_positives` and + `true_negatives` variables appropriately and whose value matches + `false_positive_rate`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope(name, 'false_positive_rate', + (predictions, labels, weights)): + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=math_ops.cast(predictions, dtype=dtypes.bool), + labels=math_ops.cast(labels, dtype=dtypes.bool), + weights=weights) + + false_p, false_positives_update_op = metrics.false_positives( + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) + true_n, true_negatives_update_op = _true_negatives( + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) + + def compute_fpr(fp, tn, name): + return array_ops.where( + math_ops.greater(fp + tn, 0), math_ops.div(fp, fp + tn), 0, name) + + fpr = compute_fpr(false_p, true_n, 'value') + update_op = compute_fpr(false_positives_update_op, true_negatives_update_op, + 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, fpr) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return fpr, update_op + + +def streaming_false_negative_rate(predictions, + labels, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes the false negative rate of predictions with respect to labels. + + The `false_negative_rate` function creates two local variables, + `false_negatives` and `true_positives`, that are used to compute the + false positive rate. This value is ultimately returned as + `false_negative_rate`, an idempotent operation that simply divides + `false_negatives` by the sum of `false_negatives` and `true_positives`. + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the + `false_negative_rate`. `update_op` weights each prediction by the + corresponding value in `weights`. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will + be cast to `bool`. + labels: The ground truth values, a `Tensor` whose dimensions must match + `predictions`. Will be cast to `bool`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that + `false_negative_rate` should be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + name: An optional variable_scope name. + + Returns: + false_negative_rate: Scalar float `Tensor` with the value of + `false_negatives` divided by the sum of `false_negatives` and + `true_positives`. + update_op: `Operation` that increments `false_negatives` and + `true_positives` variables appropriately and whose value matches + `false_negative_rate`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope(name, 'false_negative_rate', + (predictions, labels, weights)): + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=math_ops.cast(predictions, dtype=dtypes.bool), + labels=math_ops.cast(labels, dtype=dtypes.bool), + weights=weights) + + false_n, false_negatives_update_op = metrics.false_negatives( + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) + true_p, true_positives_update_op = metrics.true_positives( + labels, + predictions, + weights, + metrics_collections=None, + updates_collections=None, + name=None) + + def compute_fnr(fn, tp, name): + return array_ops.where( + math_ops.greater(fn + tp, 0), math_ops.div(fn, fn + tp), 0, name) + + fnr = compute_fnr(false_n, true_p, 'value') + update_op = compute_fnr(false_negatives_update_op, true_positives_update_op, + 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, fnr) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + return fnr, update_op -def _streaming_confusion_matrix_at_thresholds( - predictions, labels, thresholds, weights=None, includes=None): + +def _streaming_confusion_matrix_at_thresholds(predictions, + labels, + thresholds, + weights=None, + includes=None): """Computes true_positives, false_negatives, true_negatives, false_positives. This function creates up to four local variables, `true_positives`, @@ -618,7 +871,7 @@ def _streaming_confusion_matrix_at_thresholds( if include not in all_includes: raise ValueError('Invaild key: %s.' % include) - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) @@ -654,8 +907,8 @@ def _streaming_confusion_matrix_at_thresholds( if weights is not None: broadcast_weights = weights_broadcast_ops.broadcast_weights( math_ops.to_float(weights), predictions) - weights_tiled = array_ops.tile(array_ops.reshape( - broadcast_weights, [1, -1]), [num_thresholds, 1]) + weights_tiled = array_ops.tile( + array_ops.reshape(broadcast_weights, [1, -1]), [num_thresholds, 1]) thresh_tiled.get_shape().assert_is_compatible_with( weights_tiled.get_shape()) else: @@ -665,71 +918,87 @@ def _streaming_confusion_matrix_at_thresholds( update_ops = {} if 'tp' in includes: - true_positives = _create_local('true_positives', shape=[num_thresholds]) + true_positives = metrics_impl.metric_variable( + [num_thresholds], dtypes.float32, name='true_positives') is_true_positive = math_ops.to_float( math_ops.logical_and(label_is_pos, pred_is_pos)) if weights_tiled is not None: is_true_positive *= weights_tiled - update_ops['tp'] = state_ops.assign_add( - true_positives, math_ops.reduce_sum(is_true_positive, 1)) + update_ops['tp'] = state_ops.assign_add(true_positives, + math_ops.reduce_sum( + is_true_positive, 1)) values['tp'] = true_positives if 'fn' in includes: - false_negatives = _create_local('false_negatives', shape=[num_thresholds]) + false_negatives = metrics_impl.metric_variable( + [num_thresholds], dtypes.float32, name='false_negatives') is_false_negative = math_ops.to_float( math_ops.logical_and(label_is_pos, pred_is_neg)) if weights_tiled is not None: is_false_negative *= weights_tiled - update_ops['fn'] = state_ops.assign_add( - false_negatives, math_ops.reduce_sum(is_false_negative, 1)) + update_ops['fn'] = state_ops.assign_add(false_negatives, + math_ops.reduce_sum( + is_false_negative, 1)) values['fn'] = false_negatives if 'tn' in includes: - true_negatives = _create_local('true_negatives', shape=[num_thresholds]) + true_negatives = metrics_impl.metric_variable( + [num_thresholds], dtypes.float32, name='true_negatives') is_true_negative = math_ops.to_float( math_ops.logical_and(label_is_neg, pred_is_neg)) if weights_tiled is not None: is_true_negative *= weights_tiled - update_ops['tn'] = state_ops.assign_add( - true_negatives, math_ops.reduce_sum(is_true_negative, 1)) + update_ops['tn'] = state_ops.assign_add(true_negatives, + math_ops.reduce_sum( + is_true_negative, 1)) values['tn'] = true_negatives if 'fp' in includes: - false_positives = _create_local('false_positives', shape=[num_thresholds]) + false_positives = metrics_impl.metric_variable( + [num_thresholds], dtypes.float32, name='false_positives') is_false_positive = math_ops.to_float( math_ops.logical_and(label_is_neg, pred_is_pos)) if weights_tiled is not None: is_false_positive *= weights_tiled - update_ops['fp'] = state_ops.assign_add( - false_positives, math_ops.reduce_sum(is_false_positive, 1)) + update_ops['fp'] = state_ops.assign_add(false_positives, + math_ops.reduce_sum( + is_false_positive, 1)) values['fp'] = false_positives return values, update_ops -def streaming_true_positives_at_thresholds( - predictions, labels, thresholds, weights=None): +def streaming_true_positives_at_thresholds(predictions, + labels, + thresholds, + weights=None): values, update_ops = _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights=weights, includes=('tp',)) return values['tp'], update_ops['tp'] -def streaming_false_negatives_at_thresholds( - predictions, labels, thresholds, weights=None): +def streaming_false_negatives_at_thresholds(predictions, + labels, + thresholds, + weights=None): values, update_ops = _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights=weights, includes=('fn',)) return values['fn'], update_ops['fn'] -def streaming_false_positives_at_thresholds( - predictions, labels, thresholds, weights=None): +def streaming_false_positives_at_thresholds(predictions, + labels, + thresholds, + weights=None): values, update_ops = _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights=weights, includes=('fp',)) return values['fp'], update_ops['fp'] -def streaming_true_negatives_at_thresholds( - predictions, labels, thresholds, weights=None): +def streaming_true_negatives_at_thresholds(predictions, + labels, + thresholds, + weights=None): values, update_ops = _streaming_confusion_matrix_at_thresholds( predictions, labels, thresholds, weights=weights, includes=('tn',)) return values['tn'], update_ops['tn'] @@ -788,12 +1057,15 @@ def streaming_curve_points(labels=None, `weights` is not `None` and its shape doesn't match `predictions`, or if either `metrics_collections` or `updates_collections` are not a list or tuple. + + TODO(chizeng): Consider rewriting this method to make use of logic within the + streaming_precision_recall_at_equal_thresholds method (to improve run time). """ - with variable_scope.variable_scope(name, 'curve_points', (labels, predictions, - weights)): + with variable_scope.variable_scope(name, 'curve_points', + (labels, predictions, weights)): if curve != 'ROC' and curve != 'PR': raise ValueError('curve must be either ROC or PR, %s unknown' % (curve)) - kepsilon = 1e-7 # to account for floating point imprecisions + kepsilon = _EPSILON # to account for floating point imprecisions thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)] thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] @@ -830,10 +1102,16 @@ def streaming_curve_points(labels=None, return points, update_op - -def streaming_auc(predictions, labels, weights=None, num_thresholds=200, - metrics_collections=None, updates_collections=None, - curve='ROC', name=None): +@deprecated(None, "Please switch to tf.metrics.auc. Note that the order of " + "the inputs of labels and predictions have been switched.") +def streaming_auc(predictions, + labels, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + curve='ROC', + name=None): """Computes the approximate AUC via a Riemann sum. The `streaming_auc` function creates four local variables, `true_positives`, @@ -890,14 +1168,201 @@ def streaming_auc(predictions, labels, weights=None, num_thresholds=200, tuple. """ return metrics.auc( - predictions=predictions, labels=labels, weights=weights, - metrics_collections=metrics_collections, num_thresholds=num_thresholds, - curve=curve, updates_collections=updates_collections, name=name) + predictions=predictions, + labels=labels, + weights=weights, + metrics_collections=metrics_collections, + num_thresholds=num_thresholds, + curve=curve, + updates_collections=updates_collections, + name=name) + + +def streaming_precision_recall_at_equal_thresholds(predictions, + labels, + num_thresholds=None, + weights=None, + name=None, + use_locking=None): + """A helper method for creating metrics related to precision-recall curves. + These values are true positives, false negatives, true negatives, false + positives, precision, and recall. This function returns a data structure that + contains ops within it. + + Unlike _streaming_confusion_matrix_at_thresholds (which exhibits O(T * N) + space and run time), this op exhibits O(T + N) space and run time, where T is + the number of thresholds and N is the size of the predictions tensor. Hence, + it may be advantageous to use this function when `predictions` is big. + + For instance, prefer this method for per-pixel classification tasks, for which + the predictions tensor may be very large. + + Each number in `predictions`, a float in `[0, 1]`, is compared with its + corresponding label in `labels`, and counts as a single tp/fp/tn/fn value at + each threshold. This is then multiplied with `weights` which can be used to + reweight certain values, or more commonly used for masking values. + + Args: + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + labels: A bool `Tensor` whose shape matches `predictions`. + num_thresholds: Optional; Number of thresholds, evenly distributed in + `[0, 1]`. Should be `>= 2`. Defaults to 201. Note that the number of bins + is 1 less than `num_thresholds`. Using an even `num_thresholds` value + instead of an odd one may yield unfriendly edges for bins. + weights: Optional; If provided, a `Tensor` that has the same dtype as, + and broadcastable to, `predictions`. This tensor is multplied by counts. + name: Optional; variable_scope name. If not provided, the string + 'precision_recall_at_equal_threshold' is used. + use_locking: Optional; If True, the op will be protected by a lock. + Otherwise, the behavior is undefined, but may exhibit less contention. + Defaults to True. -def streaming_specificity_at_sensitivity( - predictions, labels, sensitivity, weights=None, num_thresholds=200, - metrics_collections=None, updates_collections=None, name=None): + Returns: + result: A named tuple (See PrecisionRecallData within the implementation of + this function) with properties that are variables of shape + `[num_thresholds]`. The names of the properties are tp, fp, tn, fn, + precision, recall, thresholds. + update_op: An op that accumulates values. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + `includes` contains invalid keys. + """ + # Disable the invalid-name checker so that we can capitalize the name. + # pylint: disable=invalid-name + PrecisionRecallData = collections_lib.namedtuple( + 'PrecisionRecallData', + ['tp', 'fp', 'tn', 'fn', 'precision', 'recall', 'thresholds']) + # pylint: enable=invalid-name + + if num_thresholds is None: + num_thresholds = 201 + + if weights is None: + weights = 1.0 + + if use_locking is None: + use_locking = True + + check_ops.assert_type(labels, dtypes.bool) + + dtype = predictions.dtype + with variable_scope.variable_scope(name, + 'precision_recall_at_equal_thresholds', + (labels, predictions, weights)): + # Make sure that predictions are within [0.0, 1.0]. + with ops.control_dependencies([ + check_ops.assert_greater_equal( + predictions, + math_ops.cast(0.0, dtype=predictions.dtype), + message='predictions must be in [0, 1]'), + check_ops.assert_less_equal( + predictions, + math_ops.cast(1.0, dtype=predictions.dtype), + message='predictions must be in [0, 1]') + ]): + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=predictions, + labels=labels, + weights=weights) + + predictions.get_shape().assert_is_compatible_with(labels.get_shape()) + + # We cast to float to ensure we have 0.0 or 1.0. + f_labels = math_ops.cast(labels, dtype) + + # Get weighted true/false labels. + true_labels = f_labels * weights + false_labels = (1.0 - f_labels) * weights + + # Flatten predictions and labels. + predictions = array_ops.reshape(predictions, [-1]) + true_labels = array_ops.reshape(true_labels, [-1]) + false_labels = array_ops.reshape(false_labels, [-1]) + + # To compute TP/FP/TN/FN, we are measuring a binary classifier + # C(t) = (predictions >= t) + # at each threshold 't'. So we have + # TP(t) = sum( C(t) * true_labels ) + # FP(t) = sum( C(t) * false_labels ) + # + # But, computing C(t) requires computation for each t. To make it fast, + # observe that C(t) is a cumulative integral, and so if we have + # thresholds = [t_0, ..., t_{n-1}]; t_0 < ... < t_{n-1} + # where n = num_thresholds, and if we can compute the bucket function + # B(i) = Sum( (predictions == t), t_i <= t < t{i+1} ) + # then we get + # C(t_i) = sum( B(j), j >= i ) + # which is the reversed cumulative sum in tf.cumsum(). + # + # We can compute B(i) efficiently by taking advantage of the fact that + # our thresholds are evenly distributed, in that + # width = 1.0 / (num_thresholds - 1) + # thresholds = [0.0, 1*width, 2*width, 3*width, ..., 1.0] + # Given a prediction value p, we can map it to its bucket by + # bucket_index(p) = floor( p * (num_thresholds - 1) ) + # so we can use tf.scatter_add() to update the buckets in one pass. + # + # This implementation exhibits a run time and space complexity of O(T + N), + # where T is the number of thresholds and N is the size of predictions. + # Metrics that rely on _streaming_confusion_matrix_at_thresholds instead + # exhibit a complexity of O(T * N). + + # Compute the bucket indices for each prediction value. + bucket_indices = math_ops.cast( + math_ops.floor(predictions * (num_thresholds - 1)), dtypes.int32) + + with ops.name_scope('variables'): + tp_buckets_v = metrics_impl.metric_variable( + [num_thresholds], dtype, name='tp_buckets') + fp_buckets_v = metrics_impl.metric_variable( + [num_thresholds], dtype, name='fp_buckets') + + with ops.name_scope('update_op'): + update_tp = state_ops.scatter_add( + tp_buckets_v, bucket_indices, true_labels, use_locking=use_locking) + update_fp = state_ops.scatter_add( + fp_buckets_v, bucket_indices, false_labels, use_locking=use_locking) + + # Set up the cumulative sums to compute the actual metrics. + tp = math_ops.cumsum(tp_buckets_v, reverse=True, name='tp') + fp = math_ops.cumsum(fp_buckets_v, reverse=True, name='fp') + # fn = sum(true_labels) - tp + # = sum(tp_buckets) - tp + # = tp[0] - tp + # Similarly, + # tn = fp[0] - fp + tn = fp[0] - fp + fn = tp[0] - tp + + # We use a minimum to prevent division by 0. + epsilon = 1e-7 + precision = tp / math_ops.maximum(epsilon, tp + fp) + recall = tp / math_ops.maximum(epsilon, tp + fn) + + result = PrecisionRecallData( + tp=tp, + fp=fp, + tn=tn, + fn=fn, + precision=precision, + recall=recall, + thresholds=math_ops.lin_space(0.0, 1.0, num_thresholds)) + update_op = control_flow_ops.group(update_tp, update_fp) + return result, update_op + + +def streaming_specificity_at_sensitivity(predictions, + labels, + sensitivity, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the specificity at a given sensitivity. The `streaming_specificity_at_sensitivity` function creates four local @@ -947,15 +1412,24 @@ def streaming_specificity_at_sensitivity( or `updates_collections` are not a list or tuple. """ return metrics.specificity_at_sensitivity( - sensitivity=sensitivity, num_thresholds=num_thresholds, - predictions=predictions, labels=labels, weights=weights, + sensitivity=sensitivity, + num_thresholds=num_thresholds, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_sensitivity_at_specificity( - predictions, labels, specificity, weights=None, num_thresholds=200, - metrics_collections=None, updates_collections=None, name=None): +def streaming_sensitivity_at_specificity(predictions, + labels, + specificity, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + name=None): """Computes the sensitivity at a given specificity. The `streaming_sensitivity_at_specificity` function creates four local @@ -1005,16 +1479,25 @@ def streaming_sensitivity_at_specificity( or `updates_collections` are not a list or tuple. """ return metrics.sensitivity_at_specificity( - specificity=specificity, num_thresholds=num_thresholds, - predictions=predictions, labels=labels, weights=weights, + specificity=specificity, + num_thresholds=num_thresholds, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) - + updates_collections=updates_collections, + name=name) -def streaming_precision_at_thresholds(predictions, labels, thresholds, +@deprecated( + None, "Please switch to tf.metrics.precision_at_thresholds. Note that the " + "order of of the inputs of labels and predictions have been switched.") +def streaming_precision_at_thresholds(predictions, + labels, + thresholds, weights=None, metrics_collections=None, - updates_collections=None, name=None): + updates_collections=None, + name=None): """Computes precision values for different `thresholds` on `predictions`. The `streaming_precision_at_thresholds` function creates four local variables, @@ -1059,14 +1542,23 @@ def streaming_precision_at_thresholds(predictions, labels, thresholds, """ return metrics.precision_at_thresholds( thresholds=thresholds, - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) - + updates_collections=updates_collections, + name=name) -def streaming_recall_at_thresholds(predictions, labels, thresholds, - weights=None, metrics_collections=None, - updates_collections=None, name=None): +@deprecated( + None, "Please switch to tf.metrics.recall_at_thresholds. Note that the " + "order of of the inputs of labels and predictions have been switched.") +def streaming_recall_at_thresholds(predictions, + labels, + thresholds, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): """Computes various recall values for different `thresholds` on `predictions`. The `streaming_recall_at_thresholds` function creates four local variables, @@ -1109,9 +1601,154 @@ def streaming_recall_at_thresholds(predictions, labels, thresholds, """ return metrics.recall_at_thresholds( thresholds=thresholds, - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) + + +def streaming_false_positive_rate_at_thresholds(predictions, + labels, + thresholds, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes various fpr values for different `thresholds` on `predictions`. + + The `streaming_false_positive_rate_at_thresholds` function creates two + local variables, `false_positives`, `true_negatives`, for various values of + thresholds. `false_positive_rate[i]` is defined as the total weight + of values in `predictions` above `thresholds[i]` whose corresponding entry in + `labels` is `False`, divided by the total weight of `False` values in `labels` + (`false_positives[i] / (false_positives[i] + true_negatives[i])`). + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the + `false_positive_rate`. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + labels: A `bool` `Tensor` whose shape matches `predictions`. + thresholds: A python list or tuple of float thresholds in `[0, 1]`. + weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and + must be broadcastable to `labels` (i.e., all dimensions must be either + `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that + `false_positive_rate` should be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + name: An optional variable_scope name. + + Returns: + false_positive_rate: A float `Tensor` of shape `[len(thresholds)]`. + update_op: An operation that increments the `false_positives` and + `true_negatives` variables that are used in the computation of + `false_positive_rate`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope(name, 'false_positive_rate_at_thresholds', + (predictions, labels, weights)): + values, update_ops = _streaming_confusion_matrix_at_thresholds( + predictions, labels, thresholds, weights, includes=('fp', 'tn')) + + # Avoid division by zero. + epsilon = _EPSILON + + def compute_fpr(fp, tn, name): + return math_ops.div(fp, epsilon + fp + tn, name='fpr_' + name) + + fpr = compute_fpr(values['fp'], values['tn'], 'value') + update_op = compute_fpr(update_ops['fp'], update_ops['tn'], 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, fpr) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return fpr, update_op + + +def streaming_false_negative_rate_at_thresholds(predictions, + labels, + thresholds, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes various fnr values for different `thresholds` on `predictions`. + + The `streaming_false_negative_rate_at_thresholds` function creates two + local variables, `false_negatives`, `true_positives`, for various values of + thresholds. `false_negative_rate[i]` is defined as the total weight + of values in `predictions` above `thresholds[i]` whose corresponding entry in + `labels` is `False`, divided by the total weight of `True` values in `labels` + (`false_negatives[i] / (false_negatives[i] + true_positives[i])`). + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the + `false_positive_rate`. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + labels: A `bool` `Tensor` whose shape matches `predictions`. + thresholds: A python list or tuple of float thresholds in `[0, 1]`. + weights: `Tensor` whose rank is either 0, or the same rank as `labels`, and + must be broadcastable to `labels` (i.e., all dimensions must be either + `1`, or the same as the corresponding `labels` dimension). + metrics_collections: An optional list of collections that + `false_negative_rate` should be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + name: An optional variable_scope name. + + Returns: + false_negative_rate: A float `Tensor` of shape `[len(thresholds)]`. + update_op: An operation that increments the `false_negatives` and + `true_positives` variables that are used in the computation of + `false_negative_rate`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, or if + `weights` is not `None` and its shape doesn't match `predictions`, or if + either `metrics_collections` or `updates_collections` are not a list or + tuple. + """ + with variable_scope.variable_scope(name, 'false_negative_rate_at_thresholds', + (predictions, labels, weights)): + values, update_ops = _streaming_confusion_matrix_at_thresholds( + predictions, labels, thresholds, weights, includes=('fn', 'tp')) + + # Avoid division by zero. + epsilon = _EPSILON + + def compute_fnr(fn, tp, name): + return math_ops.div(fn, epsilon + fn + tp, name='fnr_' + name) + + fnr = compute_fnr(values['fn'], values['tp'], 'value') + update_op = compute_fnr(update_ops['fn'], update_ops['tp'], 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, fnr) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return fnr, update_op def _at_k_name(name, k=None, class_id=None): @@ -1124,10 +1761,14 @@ def _at_k_name(name, k=None, class_id=None): return name -@deprecated('2016-11-08', 'Please use `streaming_sparse_recall_at_k`, ' - 'and reshape labels from [batch_size] to [batch_size, 1].') -def streaming_recall_at_k(predictions, labels, k, weights=None, - metrics_collections=None, updates_collections=None, +@deprecated("2016-11-08", "Please use `streaming_sparse_recall_at_k`, " + "and reshape labels from [batch_size] to [batch_size, 1].") +def streaming_recall_at_k(predictions, + labels, + k, + weights=None, + metrics_collections=None, + updates_collections=None, name=None): """Computes the recall@k of the predictions with respect to dense labels. @@ -1173,11 +1814,8 @@ def streaming_recall_at_k(predictions, labels, k, weights=None, tuple. """ in_top_k = math_ops.to_float(nn.in_top_k(predictions, labels, k)) - return streaming_mean(in_top_k, - weights, - metrics_collections, - updates_collections, - name or _at_k_name('recall', k)) + return streaming_mean(in_top_k, weights, metrics_collections, + updates_collections, name or _at_k_name('recall', k)) # TODO(ptucker): Validate range of values in labels? @@ -1256,10 +1894,14 @@ def streaming_sparse_recall_at_k(predictions, are not a list or tuple. """ return metrics.recall_at_k( - k=k, class_id=class_id, - predictions=predictions, labels=labels, weights=weights, + k=k, + class_id=class_id, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) # TODO(ptucker): Validate range of values in labels? @@ -1341,10 +1983,14 @@ def streaming_sparse_precision_at_k(predictions, are not a list or tuple. """ return metrics.sparse_precision_at_k( - k=k, class_id=class_id, - predictions=predictions, labels=labels, weights=weights, + k=k, + class_id=class_id, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) # TODO(ptucker): Validate range of values in labels? @@ -1423,10 +2069,9 @@ def streaming_sparse_precision_at_top_k(top_k_predictions, ValueError: If `top_k_predictions` has rank < 2. """ default_name = _at_k_name('precision', class_id=class_id) - with ops.name_scope( - name, default_name, - (top_k_predictions, labels, weights)) as name_scope: - return metrics_impl._sparse_precision_at_top_k( # pylint: disable=protected-access + with ops.name_scope(name, default_name, + (top_k_predictions, labels, weights)) as name_scope: + return metrics_impl.precision_at_top_k( labels=labels, predictions_idx=top_k_predictions, class_id=class_id, @@ -1505,9 +2150,9 @@ def sparse_recall_at_top_k(labels, are not a list or tuple. """ default_name = _at_k_name('recall', class_id=class_id) - with ops.name_scope(name, default_name, (top_k_predictions, labels, - weights)) as name_scope: - return metrics_impl._sparse_recall_at_top_k( # pylint: disable=protected-access + with ops.name_scope(name, default_name, + (top_k_predictions, labels, weights)) as name_scope: + return metrics_impl.recall_at_top_k( labels=labels, predictions_idx=top_k_predictions, class_id=class_id, @@ -1517,6 +2162,109 @@ def sparse_recall_at_top_k(labels, name=name_scope) +def _compute_recall_at_precision(tp, fp, fn, precision, name): + """Helper function to compute recall at a given `precision`. + + Args: + tp: The number of true positives. + fp: The number of false positives. + fn: The number of false negatives. + precision: The precision for which the recall will be calculated. + name: An optional variable_scope name. + + Returns: + The recall at a the given `precision`. + """ + precisions = math_ops.div(tp, tp + fp + _EPSILON) + tf_index = math_ops.argmin( + math_ops.abs(precisions - precision), 0, output_type=dtypes.int32) + + # Now, we have the implicit threshold, so compute the recall: + return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + _EPSILON, + name) + + +def recall_at_precision(labels, + predictions, + precision, + weights=None, + num_thresholds=200, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes `recall` at `precision`. + + The `recall_at_precision` function creates four local variables, + `tp` (true positives), `fp` (false positives) and `fn` (false negatives) + that are used to compute the `recall` at the given `precision` value. The + threshold for the given `precision` value is computed and used to evaluate the + corresponding `recall`. + + For estimation of the metric over a stream of data, the function creates an + `update_op` operation that updates these variables and returns the + `recall`. `update_op` increments the `tp`, `fp` and `fn` counts with the + weight of each case found in the `predictions` and `labels`. + + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. + + Args: + labels: The ground truth values, a `Tensor` whose dimensions must match + `predictions`. Will be cast to `bool`. + predictions: A floating point `Tensor` of arbitrary shape and whose values + are in the range `[0, 1]`. + precision: A scalar value in range `[0, 1]`. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions must + be either `1`, or the same as the corresponding `labels` dimension). + num_thresholds: The number of thresholds to use for matching the given + `precision`. + metrics_collections: An optional list of collections that `recall` + should be added to. + updates_collections: An optional list of collections that `update_op` should + be added to. + name: An optional variable_scope name. + + Returns: + recall: A scalar `Tensor` representing the recall at the given + `precision` value. + update_op: An operation that increments the `tp`, `fp` and `fn` + variables appropriately and whose value matches `recall`. + + Raises: + ValueError: If `predictions` and `labels` have mismatched shapes, if + `weights` is not `None` and its shape doesn't match `predictions`, or if + `precision` is not between 0 and 1, or if either `metrics_collections` + or `updates_collections` are not a list or tuple. + + """ + if not 0 <= precision <= 1: + raise ValueError('`precision` must be in the range [0, 1].') + + with variable_scope.variable_scope(name, 'recall_at_precision', + (predictions, labels, weights)): + thresholds = [ + i * 1.0 / (num_thresholds - 1) for i in range(1, num_thresholds - 1) + ] + thresholds = [0.0 - _EPSILON] + thresholds + [1.0 + _EPSILON] + + values, update_ops = _streaming_confusion_matrix_at_thresholds( + labels, predictions, thresholds, weights) + + recall = _compute_recall_at_precision(values['tp'], values['fp'], + values['fn'], precision, 'value') + update_op = _compute_recall_at_precision(update_ops['tp'], update_ops['fp'], + update_ops['fn'], precision, + 'update_op') + + if metrics_collections: + ops.add_to_collections(metrics_collections, recall) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return recall, update_op + + def streaming_sparse_average_precision_at_k(predictions, labels, k, @@ -1576,9 +2324,13 @@ def streaming_sparse_average_precision_at_k(predictions, value matches `metric`. """ return metrics.sparse_average_precision_at_k( - k=k, predictions=predictions, labels=labels, weights=weights, + k=k, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) def streaming_sparse_average_precision_at_top_k(top_k_predictions, @@ -1643,8 +2395,10 @@ def streaming_sparse_average_precision_at_top_k(top_k_predictions, updates_collections=updates_collections, name=name) - -def streaming_mean_absolute_error(predictions, labels, weights=None, +@deprecated(None, "Please switch to tf.metrics.mean.") +def streaming_mean_absolute_error(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1692,12 +2446,18 @@ def streaming_mean_absolute_error(predictions, labels, weights=None, tuple. """ return metrics.mean_absolute_error( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_mean_relative_error(predictions, labels, normalizer, weights=None, +def streaming_mean_relative_error(predictions, + labels, + normalizer, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1746,12 +2506,18 @@ def streaming_mean_relative_error(predictions, labels, normalizer, weights=None, tuple. """ return metrics.mean_relative_error( - normalizer=normalizer, predictions=predictions, labels=labels, - weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + normalizer=normalizer, + predictions=predictions, + labels=labels, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) -def streaming_mean_squared_error(predictions, labels, weights=None, +def streaming_mean_squared_error(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1799,12 +2565,17 @@ def streaming_mean_squared_error(predictions, labels, weights=None, tuple. """ return metrics.mean_squared_error( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) -def streaming_root_mean_squared_error(predictions, labels, weights=None, +def streaming_root_mean_squared_error(predictions, + labels, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -1852,9 +2623,12 @@ def streaming_root_mean_squared_error(predictions, labels, weights=None, tuple. """ return metrics.root_mean_squared_error( - predictions=predictions, labels=labels, weights=weights, + predictions=predictions, + labels=labels, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) def streaming_covariance(predictions, @@ -1910,15 +2684,18 @@ def streaming_covariance(predictions, ValueError: If labels and predictions are of different sizes or if either `metrics_collections` or `updates_collections` are not a list or tuple. """ - with variable_scope.variable_scope( - name, 'covariance', (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + with variable_scope.variable_scope(name, 'covariance', + (predictions, labels, weights)): + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - count = _create_local('count', []) - mean_prediction = _create_local('mean_prediction', []) - mean_label = _create_local('mean_label', []) - comoment = _create_local('comoment', []) # C_A in update equation + count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') + mean_prediction = metrics_impl.metric_variable( + [], dtypes.float32, name='mean_prediction') + mean_label = metrics_impl.metric_variable( + [], dtypes.float32, name='mean_label') + comoment = metrics_impl.metric_variable( # C_A in update equation + [], dtypes.float32, name='comoment') if weights is None: batch_count = math_ops.to_float(array_ops.size(labels)) # n_B in eqn @@ -1930,7 +2707,7 @@ def streaming_covariance(predictions, weighted_predictions = math_ops.multiply(predictions, weights) weighted_labels = math_ops.multiply(labels, weights) - update_count = state_ops.assign_add(count, batch_count) # n_AB in eqn + update_count = state_ops.assign_add(count_, batch_count) # n_AB in eqn prev_count = update_count - batch_count # n_A in update equation # We update the means by Delta=Error*BatchCount/(BatchCount+PrevCount) @@ -1955,34 +2732,34 @@ def streaming_covariance(predictions, # prev_mean_label is E[y_A] in the update equation prev_mean_label = update_mean_label - delta_mean_label - unweighted_batch_coresiduals = ( - (predictions - batch_mean_prediction) * (labels - batch_mean_label)) + unweighted_batch_coresiduals = ((predictions - batch_mean_prediction) * + (labels - batch_mean_label)) # batch_comoment is C_B in the update equation if weights is None: batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals) else: - batch_comoment = math_ops.reduce_sum(unweighted_batch_coresiduals * - weights) + batch_comoment = math_ops.reduce_sum( + unweighted_batch_coresiduals * weights) # View delta_comoment as = C_AB - C_A in the update equation above. # Since C_A is stored in a var, by how much do we need to increment that var # to make the var = C_AB? - delta_comoment = (batch_comoment + - (prev_mean_prediction - batch_mean_prediction) * - (prev_mean_label - batch_mean_label) * - (prev_count * batch_count / update_count)) + delta_comoment = ( + batch_comoment + (prev_mean_prediction - batch_mean_prediction) * + (prev_mean_label - batch_mean_label) * + (prev_count * batch_count / update_count)) update_comoment = state_ops.assign_add(comoment, delta_comoment) covariance = array_ops.where( - math_ops.less_equal(count, 1.), + math_ops.less_equal(count_, 1.), float('nan'), - math_ops.truediv(comoment, count - 1), + math_ops.truediv(comoment, count_ - 1), name='covariance') with ops.control_dependencies([update_comoment]): update_op = array_ops.where( - math_ops.less_equal(count, 1.), + math_ops.less_equal(count_, 1.), float('nan'), - math_ops.truediv(comoment, count - 1), + math_ops.truediv(comoment, count_ - 1), name='update_op') if metrics_collections: @@ -2044,9 +2821,9 @@ def streaming_pearson_correlation(predictions, `weights` is the wrong size, or if either `metrics_collections` or `updates_collections` are not a `list` or `tuple`. """ - with variable_scope.variable_scope( - name, 'pearson_r', (predictions, labels, weights)): - predictions, labels, weights = _remove_squeezable_dimensions( + with variable_scope.variable_scope(name, 'pearson_r', + (predictions, labels, weights)): + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) # Broadcast weights here to avoid duplicate broadcasting in each call to @@ -2062,13 +2839,14 @@ def streaming_pearson_correlation(predictions, pearson_r = math_ops.truediv( cov, - math_ops.multiply(math_ops.sqrt(var_predictions), - math_ops.sqrt(var_labels)), + math_ops.multiply( + math_ops.sqrt(var_predictions), math_ops.sqrt(var_labels)), name='pearson_r') update_op = math_ops.truediv( update_cov, - math_ops.multiply(math_ops.sqrt(update_var_predictions), - math_ops.sqrt(update_var_labels)), + math_ops.multiply( + math_ops.sqrt(update_var_predictions), + math_ops.sqrt(update_var_labels)), name='update_op') if metrics_collections: @@ -2082,7 +2860,10 @@ def streaming_pearson_correlation(predictions, # TODO(nsilberman): add a 'normalized' flag so that the user can request # normalization if the inputs are not normalized. -def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, +def streaming_mean_cosine_distance(predictions, + labels, + dim, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2124,16 +2905,15 @@ def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, either `metrics_collections` or `updates_collections` are not a list or tuple. """ - predictions, labels, weights = _remove_squeezable_dimensions( + predictions, labels, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access predictions, labels, weights) predictions.get_shape().assert_is_compatible_with(labels.get_shape()) radial_diffs = math_ops.multiply(predictions, labels) - radial_diffs = math_ops.reduce_sum(radial_diffs, - reduction_indices=[dim,], - keep_dims=True) - mean_distance, update_op = streaming_mean(radial_diffs, weights, - None, - None, + radial_diffs = math_ops.reduce_sum( + radial_diffs, reduction_indices=[ + dim, + ], keep_dims=True) + mean_distance, update_op = streaming_mean(radial_diffs, weights, None, None, name or 'mean_cosine_distance') mean_distance = math_ops.subtract(1.0, mean_distance) update_op = math_ops.subtract(1.0, update_op) @@ -2147,7 +2927,9 @@ def streaming_mean_cosine_distance(predictions, labels, dim, weights=None, return mean_distance, update_op -def streaming_percentage_less(values, threshold, weights=None, +def streaming_percentage_less(values, + threshold, + weights=None, metrics_collections=None, updates_collections=None, name=None): @@ -2187,9 +2969,12 @@ def streaming_percentage_less(values, threshold, weights=None, or tuple. """ return metrics.percentage_below( - values=values, threshold=threshold, weights=weights, + values=values, + threshold=threshold, + weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + updates_collections=updates_collections, + name=name) def streaming_mean_iou(predictions, @@ -2241,9 +3026,13 @@ def streaming_mean_iou(predictions, tuple. """ return metrics.mean_iou( - num_classes=num_classes, predictions=predictions, labels=labels, - weights=weights, metrics_collections=metrics_collections, - updates_collections=updates_collections, name=name) + num_classes=num_classes, + predictions=predictions, + labels=labels, + weights=weights, + metrics_collections=metrics_collections, + updates_collections=updates_collections, + name=name) def _next_array_size(required_size, growth_factor=1.5): @@ -2258,9 +3047,9 @@ def _next_array_size(required_size, growth_factor=1.5): tf.Tensor with dtype=int32 giving the next array size. """ exponent = math_ops.ceil( - math_ops.log(math_ops.cast(required_size, dtypes.float32)) - / math_ops.log(math_ops.cast(growth_factor, dtypes.float32))) - return math_ops.cast(math_ops.ceil(growth_factor ** exponent), dtypes.int32) + math_ops.log(math_ops.cast(required_size, dtypes.float32)) / math_ops.log( + math_ops.cast(growth_factor, dtypes.float32))) + return math_ops.cast(math_ops.ceil(growth_factor**exponent), dtypes.int32) def streaming_concat(values, @@ -2317,8 +3106,7 @@ def streaming_concat(values, if not 0 <= axis < ndim: raise ValueError('axis = %r not in [0, %r)' % (axis, ndim)) - fixed_shape = [dim.value for n, dim in enumerate(values_shape) - if n != axis] + fixed_shape = [dim.value for n, dim in enumerate(values_shape) if n != axis] if any(value is None for value in fixed_shape): raise ValueError('all dimensions of `values` other than the dimension to ' 'concatenate along must have statically known size') @@ -2327,9 +3115,9 @@ def streaming_concat(values, # applied to contiguous slices init_size = 0 if max_size is None else max_size init_shape = [init_size] + fixed_shape - array = _create_local( - 'array', shape=init_shape, validate_shape=False, dtype=values.dtype) - size = _create_local('size', shape=[], dtype=dtypes.int32) + array = metrics_impl.metric_variable( + init_shape, values.dtype, validate_shape=False, name='array') + size = metrics_impl.metric_variable([], dtypes.int32, name='size') perm = [0 if n == axis else n + 1 if n < axis else n for n in range(ndim)] valid_array = array[:size] @@ -2427,60 +3215,82 @@ def aggregate_metric_map(names_to_tuples): return dict(zip(metric_names, value_ops)), dict(zip(metric_names, update_ops)) -def _remove_squeezable_dimensions(predictions, labels, weights): - """Squeeze last dim if needed. +def count(values, + weights=None, + metrics_collections=None, + updates_collections=None, + name=None): + """Computes the number of examples, or sum of `weights`. - Squeezes `predictions` and `labels` if their rank differs by 1. - Squeezes `weights` if its rank is 1 more than the new rank of `predictions` + When evaluating some metric (e.g. mean) on one or more subsets of the data, + this auxiliary metric is useful for keeping track of how many examples there + are in each subset. - This will use static shape if available. Otherwise, it will add graph - operations, which could result in a performance hit. + If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. Args: - predictions: Predicted values, a `Tensor` of arbitrary dimensions. - labels: Label values, a `Tensor` whose dimensions match `predictions`. - weights: Optional weight `Tensor`. It will be squeezed if its rank is 1 - more than the new rank of `predictions` + values: A `Tensor` of arbitrary dimensions. Only it's shape is used. + weights: Optional `Tensor` whose rank is either 0, or the same rank as + `labels`, and must be broadcastable to `labels` (i.e., all dimensions + must be either `1`, or the same as the corresponding `labels` + dimension). + metrics_collections: An optional list of collections that the metric + value variable should be added to. + updates_collections: An optional list of collections that the metric update + ops should be added to. + name: An optional variable_scope name. Returns: - Tuple of `predictions`, `labels` and `weights`, possibly with the last - dimension squeezed. + count: A `Tensor` representing the current value of the metric. + update_op: An operation that accumulates the metric from a batch of data. + + Raises: + ValueError: If `weights` is not `None` and its shape doesn't match `values`, + or if either `metrics_collections` or `updates_collections` are not a list + or tuple. """ - labels, predictions = confusion_matrix.remove_squeezable_dimensions( - labels, predictions) - predictions.get_shape().assert_is_compatible_with(labels.get_shape()) - if weights is not None: - weights = ops.convert_to_tensor(weights) - predictions_shape = predictions.get_shape() - predictions_rank = predictions_shape.ndims - weights_shape = weights.get_shape() - weights_rank = weights_shape.ndims - - if (predictions_rank is not None) and (weights_rank is not None): - # Use static rank. - if weights_rank - predictions_rank == 1: - weights = array_ops.squeeze(weights, [-1]) - elif (weights_rank is None) or ( - weights_shape.dims[-1].is_compatible_with(1)): - # Use dynamic rank - weights = control_flow_ops.cond( - math_ops.equal(array_ops.rank(weights), - math_ops.add(array_ops.rank(predictions), 1)), - lambda: array_ops.squeeze(weights, [-1]), - lambda: weights) - return predictions, labels, weights + with variable_scope.variable_scope(name, 'count', (values, weights)): + count_ = metrics_impl.metric_variable([], dtypes.float32, name='count') + + if weights is None: + num_values = math_ops.to_float(array_ops.size(values)) + else: + _, _, weights = metrics_impl._remove_squeezable_dimensions( # pylint: disable=protected-access + predictions=values, + labels=None, + weights=weights) + weights = weights_broadcast_ops.broadcast_weights( + math_ops.to_float(weights), values) + num_values = math_ops.reduce_sum(weights) + + with ops.control_dependencies([values]): + update_op = state_ops.assign_add(count_, num_values) + + if metrics_collections: + ops.add_to_collections(metrics_collections, count_) + + if updates_collections: + ops.add_to_collections(updates_collections, update_op) + + return count_, update_op __all__ = [ 'aggregate_metric_map', 'aggregate_metrics', + 'count', + 'recall_at_precision', 'sparse_recall_at_top_k', 'streaming_accuracy', 'streaming_auc', 'streaming_curve_points', + 'streaming_false_negative_rate', + 'streaming_false_negative_rate_at_thresholds', 'streaming_false_negatives', 'streaming_false_negatives_at_thresholds', + 'streaming_false_positive_rate', + 'streaming_false_positive_rate_at_thresholds', 'streaming_false_positives', 'streaming_false_positives_at_thresholds', 'streaming_mean', diff --git a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py index 9b959b43a9db8baac5b37524e81bfbb11d6ad868..6a8e58b4daf9c49b9033b6e8bab3656bfc68b989 100644 --- a/tensorflow/contrib/metrics/python/ops/metric_ops_test.py +++ b/tensorflow/contrib/metrics/python/ops/metric_ops_test.py @@ -149,9 +149,12 @@ def _assert_nan(test_case, actual): test_case.assertTrue(math.isnan(actual), 'Expected NAN, got %s.' % actual) -def _assert_local_variables(test_case, expected): +def _assert_metric_variables(test_case, expected): test_case.assertEquals( set(expected), set(v.name for v in variables.local_variables())) + test_case.assertEquals( + set(expected), + set(v.name for v in ops.get_collection(ops.GraphKeys.METRIC_VARIABLES))) class StreamingMeanTest(test.TestCase): @@ -161,7 +164,7 @@ class StreamingMeanTest(test.TestCase): def testVars(self): metrics.streaming_mean(array_ops.ones([4, 3])) - _assert_local_variables(self, ('mean/count:0', 'mean/total:0')) + _assert_metric_variables(self, ('mean/count:0', 'mean/total:0')) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -319,8 +322,8 @@ class StreamingMeanTensorTest(test.TestCase): def testVars(self): metrics.streaming_mean_tensor(array_ops.ones([4, 3])) - _assert_local_variables(self, ('mean/total_tensor:0', - 'mean/count_tensor:0')) + _assert_metric_variables(self, + ('mean/total_tensor:0', 'mean/count_tensor:0')) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -485,8 +488,8 @@ class StreamingAccuracyTest(test.TestCase): predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)), name='my_accuracy') - _assert_local_variables(self, ('my_accuracy/count:0', - 'my_accuracy/total:0')) + _assert_metric_variables(self, + ('my_accuracy/count:0', 'my_accuracy/total:0')) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -660,7 +663,7 @@ class StreamingTruePositivesTest(test.TestCase): def testVars(self): metrics.streaming_true_positives((0, 1, 0), (0, 1, 1)) - _assert_local_variables(self, ('true_positives/count:0',)) + _assert_metric_variables(self, ('true_positives/count:0',)) def testUnweighted(self): for expand_predictions in [True, False]: @@ -716,7 +719,7 @@ class StreamingFalseNegativesTest(test.TestCase): def testVars(self): metrics.streaming_false_negatives((0, 1, 0), (0, 1, 1)) - _assert_local_variables(self, ('false_negatives/count:0',)) + _assert_metric_variables(self, ('false_negatives/count:0',)) def testUnweighted(self): for expand_predictions in [True, False]: @@ -772,7 +775,7 @@ class StreamingFalsePositivesTest(test.TestCase): def testVars(self): metrics.streaming_false_positives((0, 1, 0), (0, 1, 1)) - _assert_local_variables(self, ('false_positives/count:0',)) + _assert_metric_variables(self, ('false_positives/count:0',)) def testUnweighted(self): for expand_predictions in [True, False]: @@ -832,7 +835,7 @@ class StreamingTrueNegativesTest(test.TestCase): def testVars(self): metrics.streaming_true_negatives((0, 1, 0), (0, 1, 1)) - _assert_local_variables(self, ('true_negatives/count:0',)) + _assert_metric_variables(self, ('true_negatives/count:0',)) def testUnweighted(self): for expand_predictions in [True, False]: @@ -888,7 +891,7 @@ class StreamingTruePositivesAtThresholdsTest(test.TestCase): def testVars(self): metrics.streaming_true_positives_at_thresholds( (0.0, 1.0, 0.0), (0, 1, 1), thresholds=(0.15, 0.5, 0.85)) - _assert_local_variables(self, ('true_positives:0',)) + _assert_metric_variables(self, ('true_positives:0',)) def testUnweighted(self): predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), @@ -935,7 +938,7 @@ class StreamingFalseNegativesAtThresholdsTest(test.TestCase): 0.15, 0.5, 0.85,)) - _assert_local_variables(self, ('false_negatives:0',)) + _assert_metric_variables(self, ('false_negatives:0',)) def testUnweighted(self): predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), @@ -982,7 +985,7 @@ class StreamingFalsePositivesAtThresholdsTest(test.TestCase): def testVars(self): metrics.streaming_false_positives_at_thresholds( (0.0, 1.0, 0.0), (0, 1, 1), thresholds=(0.15, 0.5, 0.85)) - _assert_local_variables(self, ('false_positives:0',)) + _assert_metric_variables(self, ('false_positives:0',)) def testUnweighted(self): predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), @@ -1031,7 +1034,7 @@ class StreamingTrueNegativesAtThresholdsTest(test.TestCase): def testVars(self): metrics.streaming_true_negatives_at_thresholds( (0.0, 1.0, 0.0), (0, 1, 1), thresholds=(0.15, 0.5, 0.85)) - _assert_local_variables(self, ('true_negatives:0',)) + _assert_metric_variables(self, ('true_negatives:0',)) def testUnweighted(self): predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), @@ -1078,8 +1081,8 @@ class StreamingPrecisionTest(test.TestCase): def testVars(self): metrics.streaming_precision( predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) - _assert_local_variables(self, ('precision/false_positives/count:0', - 'precision/true_positives/count:0')) + _assert_metric_variables(self, ('precision/false_positives/count:0', + 'precision/true_positives/count:0')) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -1101,7 +1104,7 @@ class StreamingPrecisionTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) precision, update_op = metrics.streaming_precision(predictions, labels) with self.test_session() as sess: @@ -1242,8 +1245,9 @@ class StreamingRecallTest(test.TestCase): def testVars(self): metrics.streaming_recall( predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) - _assert_local_variables(self, ('recall/false_negatives/count:0', - 'recall/true_positives/count:0')) + _assert_metric_variables( + self, + ('recall/false_negatives/count:0', 'recall/true_positives/count:0')) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -1265,7 +1269,7 @@ class StreamingRecallTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) recall, update_op = metrics.streaming_recall(predictions, labels) with self.test_session() as sess: @@ -1355,6 +1359,262 @@ class StreamingRecallTest(test.TestCase): self.assertEqual(0, recall.eval()) +class StreamingFPRTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.streaming_false_positive_rate( + predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) + _assert_metric_variables(self, + ('false_positive_rate/false_positives/count:0', + 'false_positive_rate/true_negatives/count:0')) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + mean, _ = metrics.streaming_false_positive_rate( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [mean]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metrics.streaming_false_positive_rate( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + updates_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + # Run several updates. + for _ in range(10): + sess.run(update_op) + + # Then verify idempotency. + initial_fpr = fpr.eval() + for _ in range(10): + self.assertEqual(initial_fpr, fpr.eval()) + + def testAllCorrect(self): + np_inputs = np.random.randint(0, 2, size=(100, 1)) + + predictions = constant_op.constant(np_inputs) + labels = constant_op.constant(np_inputs) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(0, fpr.eval()) + + def testSomeCorrect(self): + predictions = constant_op.constant([1, 0, 1, 0], shape=(1, 4)) + labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAlmostEqual(0.5, update_op.eval()) + self.assertAlmostEqual(0.5, fpr.eval()) + + def testWeighted1d(self): + predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]]) + labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]]) + weights = constant_op.constant([[2], [5]]) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + weighted_fp = 2.0 + 5.0 + weighted_f = (2.0 + 2.0) + (5.0 + 5.0) + expected_fpr = weighted_fp / weighted_f + self.assertAlmostEqual(expected_fpr, update_op.eval()) + self.assertAlmostEqual(expected_fpr, fpr.eval()) + + def testWeighted2d(self): + predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]]) + labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]]) + weights = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + weighted_fp = 1.0 + 3.0 + weighted_f = (1.0 + 4.0) + (2.0 + 3.0) + expected_fpr = weighted_fp / weighted_f + self.assertAlmostEqual(expected_fpr, update_op.eval()) + self.assertAlmostEqual(expected_fpr, fpr.eval()) + + def testAllIncorrect(self): + np_inputs = np.random.randint(0, 2, size=(100, 1)) + + predictions = constant_op.constant(np_inputs) + labels = constant_op.constant(1 - np_inputs) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(1, fpr.eval()) + + def testZeroFalsePositivesAndTrueNegativesGivesZeroFPR(self): + predictions = array_ops.ones((1, 4)) + labels = array_ops.ones((1, 4)) + fpr, update_op = metrics.streaming_false_positive_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(0, fpr.eval()) + + +class StreamingFNRTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.streaming_false_negative_rate( + predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) + _assert_metric_variables(self, + ('false_negative_rate/false_negatives/count:0', + 'false_negative_rate/true_positives/count:0')) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + mean, _ = metrics.streaming_false_negative_rate( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [mean]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metrics.streaming_false_negative_rate( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + updates_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + # Run several updates. + for _ in range(10): + sess.run(update_op) + + # Then verify idempotency. + initial_fnr = fnr.eval() + for _ in range(10): + self.assertEqual(initial_fnr, fnr.eval()) + + def testAllCorrect(self): + np_inputs = np.random.randint(0, 2, size=(100, 1)) + + predictions = constant_op.constant(np_inputs) + labels = constant_op.constant(np_inputs) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(0, fnr.eval()) + + def testSomeCorrect(self): + predictions = constant_op.constant([1, 0, 1, 0], shape=(1, 4)) + labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAlmostEqual(0.5, update_op.eval()) + self.assertAlmostEqual(0.5, fnr.eval()) + + def testWeighted1d(self): + predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]]) + labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]]) + weights = constant_op.constant([[2], [5]]) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + weighted_fn = 2.0 + 5.0 + weighted_t = (2.0 + 2.0) + (5.0 + 5.0) + expected_fnr = weighted_fn / weighted_t + self.assertAlmostEqual(expected_fnr, update_op.eval()) + self.assertAlmostEqual(expected_fnr, fnr.eval()) + + def testWeighted2d(self): + predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]]) + labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]]) + weights = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels, weights=weights) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + weighted_fn = 2.0 + 4.0 + weighted_t = (2.0 + 3.0) + (1.0 + 4.0) + expected_fnr = weighted_fn / weighted_t + self.assertAlmostEqual(expected_fnr, update_op.eval()) + self.assertAlmostEqual(expected_fnr, fnr.eval()) + + def testAllIncorrect(self): + np_inputs = np.random.randint(0, 2, size=(100, 1)) + + predictions = constant_op.constant(np_inputs) + labels = constant_op.constant(1 - np_inputs) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(1, fnr.eval()) + + def testZeroFalseNegativesAndTruePositivesGivesZeroFNR(self): + predictions = array_ops.zeros((1, 4)) + labels = array_ops.zeros((1, 4)) + fnr, update_op = metrics.streaming_false_negative_rate( + predictions, labels) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + self.assertEqual(0, fnr.eval()) + + class StreamingCurvePointsTest(test.TestCase): def setUp(self): @@ -1364,7 +1624,7 @@ class StreamingCurvePointsTest(test.TestCase): def testVars(self): metric_ops.streaming_curve_points( predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) - _assert_local_variables( + _assert_metric_variables( self, ('curve_points/true_positives:0', 'curve_points/false_negatives:0', 'curve_points/false_positives:0', 'curve_points/true_negatives:0')) @@ -1457,9 +1717,9 @@ class StreamingAUCTest(test.TestCase): def testVars(self): metrics.streaming_auc( predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) - _assert_local_variables(self, - ('auc/true_positives:0', 'auc/false_negatives:0', - 'auc/false_positives:0', 'auc/true_negatives:0')) + _assert_metric_variables(self, + ('auc/true_positives:0', 'auc/false_negatives:0', + 'auc/false_positives:0', 'auc/true_negatives:0')) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -1481,7 +1741,7 @@ class StreamingAUCTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) auc, update_op = metrics.streaming_auc(predictions, labels) with self.test_session() as sess: @@ -1714,6 +1974,167 @@ class StreamingAUCTest(test.TestCase): self.assertAlmostEqual(expected_auc, auc.eval(), 2) +class StreamingPrecisionRecallAtEqualThresholdsTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def _testResultsEqual(self, expected_dict, gotten_result): + """Tests that 2 results (dicts) represent the same data. + + Args: + expected_dict: A dictionary with keys that are the names of properties + of PrecisionRecallData and whose values are lists of floats. + gotten_result: A PrecisionRecallData object. + """ + gotten_dict = {k: t.eval() for k, t in gotten_result._asdict().items()} + self.assertItemsEqual( + list(expected_dict.keys()), list(gotten_dict.keys())) + + for key, expected_values in expected_dict.items(): + self.assertAllClose(expected_values, gotten_dict[key]) + + def _testCase(self, predictions, labels, expected_result, weights=None): + """Performs a test given a certain scenario of labels, predictions, weights. + + Args: + predictions: The predictions tensor. Of type float32. + labels: The labels tensor. Of type bool. + expected_result: The expected result (dict) that maps to tensors. + weights: Optional weights tensor. + """ + with self.test_session() as sess: + predictions_tensor = constant_op.constant( + predictions, dtype=dtypes_lib.float32) + labels_tensor = constant_op.constant(labels, dtype=dtypes_lib.bool) + weights_tensor = None + if weights: + weights_tensor = constant_op.constant(weights, dtype=dtypes_lib.float32) + gotten_result, update_op = ( + metric_ops.streaming_precision_recall_at_equal_thresholds( + predictions=predictions_tensor, + labels=labels_tensor, + num_thresholds=3, + weights=weights_tensor)) + + sess.run(variables.local_variables_initializer()) + sess.run(update_op) + + self._testResultsEqual(expected_result, gotten_result) + + def testVars(self): + metric_ops.streaming_precision_recall_at_equal_thresholds( + predictions=constant_op.constant([0.42], dtype=dtypes_lib.float32), + labels=constant_op.constant([True], dtype=dtypes_lib.bool)) + _assert_metric_variables( + self, ('precision_recall_at_equal_thresholds/variables/tp_buckets:0', + 'precision_recall_at_equal_thresholds/variables/fp_buckets:0')) + + def testVarsWithName(self): + metric_ops.streaming_precision_recall_at_equal_thresholds( + predictions=constant_op.constant([0.42], dtype=dtypes_lib.float32), + labels=constant_op.constant([True], dtype=dtypes_lib.bool), + name='foo') + _assert_metric_variables( + self, ('foo/variables/tp_buckets:0', 'foo/variables/fp_buckets:0')) + + def testValuesAreIdempotent(self): + predictions = constant_op.constant( + np.random.uniform(size=(10, 3)), dtype=dtypes_lib.float32) + labels = constant_op.constant( + np.random.uniform(size=(10, 3)) > 0.5, dtype=dtypes_lib.bool) + + result, update_op = ( + metric_ops.streaming_precision_recall_at_equal_thresholds( + predictions=predictions, labels=labels)) + + with self.test_session() as sess: + # Run several updates. + sess.run(variables.local_variables_initializer()) + for _ in range(3): + sess.run(update_op) + + # Then verify idempotency. + initial_result = {k: value.eval().tolist() for k, value in + result._asdict().items()} + for _ in range(3): + self._testResultsEqual(initial_result, result) + + def testAllTruePositives(self): + self._testCase([[1]], [[True]], { + 'tp': [1, 1, 1], + 'fp': [0, 0, 0], + 'tn': [0, 0, 0], + 'fn': [0, 0, 0], + 'precision': [1.0, 1.0, 1.0], + 'recall': [1.0, 1.0, 1.0], + 'thresholds': [0.0, 0.5, 1.0], + }) + + def testAllTrueNegatives(self): + self._testCase([[0]], [[False]], { + 'tp': [0, 0, 0], + 'fp': [1, 0, 0], + 'tn': [0, 1, 1], + 'fn': [0, 0, 0], + 'precision': [0.0, 0.0, 0.0], + 'recall': [0.0, 0.0, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }) + + def testAllFalsePositives(self): + self._testCase([[1]], [[False]], { + 'tp': [0, 0, 0], + 'fp': [1, 1, 1], + 'tn': [0, 0, 0], + 'fn': [0, 0, 0], + 'precision': [0.0, 0.0, 0.0], + 'recall': [0.0, 0.0, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }) + + def testAllFalseNegatives(self): + self._testCase([[0]], [[True]], { + 'tp': [1, 0, 0], + 'fp': [0, 0, 0], + 'tn': [0, 0, 0], + 'fn': [0, 1, 1], + 'precision': [1.0, 0.0, 0.0], + 'recall': [1.0, 0.0, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }) + + def testManyValues(self): + self._testCase( + [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], + [[True, False, False, True, True, True]], + { + 'tp': [4, 3, 0], + 'fp': [2, 0, 0], + 'tn': [0, 2, 2], + 'fn': [0, 1, 4], + 'precision': [2.0 / 3.0, 1.0, 0.0], + 'recall': [1.0, 0.75, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }) + + def testManyValuesWithWeights(self): + self._testCase( + [[0.2, 0.3, 0.4, 0.6, 0.7, 0.8]], + [[True, False, False, True, True, True]], + { + 'tp': [1.5, 1.5, 0.0], + 'fp': [2.5, 0.0, 0.0], + 'tn': [0.0, 2.5, 2.5], + 'fn': [0.0, 0.0, 1.5], + 'precision': [0.375, 1.0, 0.0], + 'recall': [1.0, 1.0, 0.0], + 'thresholds': [0.0, 0.5, 1.0], + }, + weights=[[0.0, 0.5, 2.0, 0.0, 0.5, 1.0]]) + + class StreamingSpecificityAtSensitivityTest(test.TestCase): def setUp(self): @@ -1725,11 +2146,11 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)), sensitivity=0.7) - _assert_local_variables(self, - ('specificity_at_sensitivity/true_positives:0', - 'specificity_at_sensitivity/false_negatives:0', - 'specificity_at_sensitivity/false_positives:0', - 'specificity_at_sensitivity/true_negatives:0')) + _assert_metric_variables(self, + ('specificity_at_sensitivity/true_positives:0', + 'specificity_at_sensitivity/false_negatives:0', + 'specificity_at_sensitivity/false_positives:0', + 'specificity_at_sensitivity/true_negatives:0')) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -1753,7 +2174,7 @@ class StreamingSpecificityAtSensitivityTest(test.TestCase): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) specificity, update_op = metrics.streaming_specificity_at_sensitivity( predictions, labels, sensitivity=0.7) @@ -1861,11 +2282,11 @@ class StreamingSensitivityAtSpecificityTest(test.TestCase): predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)), specificity=0.7) - _assert_local_variables(self, - ('sensitivity_at_specificity/true_positives:0', - 'sensitivity_at_specificity/false_negatives:0', - 'sensitivity_at_specificity/false_positives:0', - 'sensitivity_at_specificity/true_negatives:0')) + _assert_metric_variables(self, + ('sensitivity_at_specificity/true_positives:0', + 'sensitivity_at_specificity/false_negatives:0', + 'sensitivity_at_specificity/false_positives:0', + 'sensitivity_at_specificity/true_negatives:0')) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -1978,64 +2399,697 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)), thresholds=[0, 0.5, 1.0]) - _assert_local_variables(self, ( + _assert_metric_variables(self, ( 'precision_at_thresholds/true_positives:0', - 'precision_at_thresholds/false_positives:0',)) + 'precision_at_thresholds/false_positives:0', + )) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + prec, _ = metrics.streaming_precision_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + metrics_collections=[my_collection_name]) + rec, _ = metrics.streaming_recall_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [prec, rec]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, precision_op = metrics.streaming_precision_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + updates_collections=[my_collection_name]) + _, recall_op = metrics.streaming_recall_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + updates_collections=[my_collection_name]) + self.assertListEqual( + ops.get_collection(my_collection_name), [precision_op, recall_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) + thresholds = [0, 0.5, 1.0] + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + # Run several updates. + for _ in range(10): + sess.run([prec_op, rec_op]) + + # Then verify idempotency. + initial_prec = prec.eval() + initial_rec = rec.eval() + for _ in range(10): + self.assertAllClose(initial_prec, prec.eval()) + self.assertAllClose(initial_rec, rec.eval()) + + # TODO(nsilberman): fix tests (passing but incorrect). + def testAllCorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) + labels = constant_op.constant(inputs) + thresholds = [0.5] + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertEqual(1, prec.eval()) + self.assertEqual(1, rec.eval()) + + def testSomeCorrect(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) + thresholds = [0.5] + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(0.5, prec.eval()) + self.assertAlmostEqual(0.5, rec.eval()) + + def testAllIncorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) + labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) + thresholds = [0.5] + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(0, prec.eval()) + self.assertAlmostEqual(0, rec.eval()) + + def testWeights1d(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) + labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) + weights = constant_op.constant( + [[0], [1]], shape=(2, 1), dtype=dtypes_lib.float32) + thresholds = [0.5, 1.1] + prec, prec_op = metrics.streaming_precision_at_thresholds( + predictions, labels, thresholds, weights=weights) + rec, rec_op = metrics.streaming_recall_at_thresholds( + predictions, labels, thresholds, weights=weights) + + prec_low = prec[0] + prec_high = prec[1] + rec_low = rec[0] + rec_high = rec[1] + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(1.0, prec_low.eval(), places=5) + self.assertAlmostEqual(0.0, prec_high.eval(), places=5) + self.assertAlmostEqual(1.0, rec_low.eval(), places=5) + self.assertAlmostEqual(0.0, rec_high.eval(), places=5) + + def testWeights2d(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) + labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) + weights = constant_op.constant( + [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes_lib.float32) + thresholds = [0.5, 1.1] + prec, prec_op = metrics.streaming_precision_at_thresholds( + predictions, labels, thresholds, weights=weights) + rec, rec_op = metrics.streaming_recall_at_thresholds( + predictions, labels, thresholds, weights=weights) + + prec_low = prec[0] + prec_high = prec[1] + rec_low = rec[0] + rec_high = rec[1] + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(1.0, prec_low.eval(), places=5) + self.assertAlmostEqual(0.0, prec_high.eval(), places=5) + self.assertAlmostEqual(1.0, rec_low.eval(), places=5) + self.assertAlmostEqual(0.0, rec_high.eval(), places=5) + + def testExtremeThresholds(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) + thresholds = [-1.0, 2.0] # lower/higher than any values + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + prec_low = prec[0] + prec_high = prec[1] + rec_low = rec[0] + rec_high = rec[1] + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(0.75, prec_low.eval()) + self.assertAlmostEqual(0.0, prec_high.eval()) + self.assertAlmostEqual(1.0, rec_low.eval()) + self.assertAlmostEqual(0.0, rec_high.eval()) + + def testZeroLabelsPredictions(self): + with self.test_session() as sess: + predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) + labels = array_ops.zeros([4]) + thresholds = [0.5] + prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, + labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, + thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run([prec_op, rec_op]) + + self.assertAlmostEqual(0, prec.eval(), 6) + self.assertAlmostEqual(0, rec.eval(), 6) + + def testWithMultipleUpdates(self): + num_samples = 1000 + batch_size = 10 + num_batches = int(num_samples / batch_size) + + # Create the labels and data. + labels = np.random.randint(0, 2, size=(num_samples, 1)) + noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1)) + predictions = 0.4 + 0.2 * labels + noise + predictions[predictions > 1] = 1 + predictions[predictions < 0] = 0 + thresholds = [0.3] + + tp = 0 + fp = 0 + fn = 0 + tn = 0 + for i in range(num_samples): + if predictions[i] > thresholds[0]: + if labels[i] == 1: + tp += 1 + else: + fp += 1 + else: + if labels[i] == 1: + fn += 1 + else: + tn += 1 + epsilon = 1e-7 + expected_prec = tp / (epsilon + tp + fp) + expected_rec = tp / (epsilon + tp + fn) + + labels = labels.astype(np.float32) + predictions = predictions.astype(np.float32) + + with self.test_session() as sess: + # Reshape the data so its easy to queue up: + predictions_batches = predictions.reshape((batch_size, num_batches)) + labels_batches = labels.reshape((batch_size, num_batches)) + + # Enqueue the data: + predictions_queue = data_flow_ops.FIFOQueue( + num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,)) + labels_queue = data_flow_ops.FIFOQueue( + num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,)) + + for i in range(int(num_batches)): + tf_prediction = constant_op.constant(predictions_batches[:, i]) + tf_label = constant_op.constant(labels_batches[:, i]) + sess.run([ + predictions_queue.enqueue(tf_prediction), + labels_queue.enqueue(tf_label) + ]) + + tf_predictions = predictions_queue.dequeue() + tf_labels = labels_queue.dequeue() + + prec, prec_op = metrics.streaming_precision_at_thresholds(tf_predictions, + tf_labels, + thresholds) + rec, rec_op = metrics.streaming_recall_at_thresholds(tf_predictions, + tf_labels, + thresholds) + + sess.run(variables.local_variables_initializer()) + for _ in range(int(num_samples / batch_size)): + sess.run([prec_op, rec_op]) + # Since this is only approximate, we can't expect a 6 digits match. + # Although with higher number of samples/thresholds we should see the + # accuracy improving + self.assertAlmostEqual(expected_prec, prec.eval(), 2) + self.assertAlmostEqual(expected_rec, rec.eval(), 2) + + +class StreamingFPRThresholdsTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.streaming_false_positive_rate_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0]) + _assert_metric_variables(self, ( + 'false_positive_rate_at_thresholds/false_positives:0', + 'false_positive_rate_at_thresholds/true_negatives:0', + )) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + fpr, _ = metrics.streaming_false_positive_rate_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [fpr]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0], + updates_collections=[my_collection_name]) + self.assertListEqual( + ops.get_collection(my_collection_name), [update_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) + thresholds = [0, 0.5, 1.0] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + # Run several updates. + for _ in range(10): + sess.run(fpr_op) + + # Then verify idempotency. + initial_fpr = fpr.eval() + for _ in range(10): + self.assertAllClose(initial_fpr, fpr.eval()) + + def testAllCorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) + labels = constant_op.constant(inputs) + thresholds = [0.5] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertEqual(0, fpr.eval()) + + def testSomeCorrect(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) + thresholds = [0.5] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertAlmostEqual(0.5, fpr.eval()) + + def testAllIncorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + with self.test_session() as sess: + predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) + labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) + thresholds = [0.5] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertAlmostEqual(1, fpr.eval()) + + def testWeights1d(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) + labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) + weights = constant_op.constant( + [[0], [1]], shape=(2, 1), dtype=dtypes_lib.float32) + thresholds = [0.5, 1.1] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds, weights=weights) + + fpr_low = fpr[0] + fpr_high = fpr[1] + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertAlmostEqual(0.0, fpr_low.eval(), places=5) + self.assertAlmostEqual(0.0, fpr_high.eval(), places=5) + + def testWeights2d(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) + labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) + weights = constant_op.constant( + [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes_lib.float32) + thresholds = [0.5, 1.1] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds, weights=weights) + + fpr_low = fpr[0] + fpr_high = fpr[1] + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertAlmostEqual(0.0, fpr_low.eval(), places=5) + self.assertAlmostEqual(0.0, fpr_high.eval(), places=5) + + def testExtremeThresholds(self): + with self.test_session() as sess: + predictions = constant_op.constant( + [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) + labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) + thresholds = [-1.0, 2.0] # lower/higher than any values + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + + fpr_low = fpr[0] + fpr_high = fpr[1] + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertAlmostEqual(1.0, fpr_low.eval(), places=5) + self.assertAlmostEqual(0.0, fpr_high.eval(), places=5) + + def testZeroLabelsPredictions(self): + with self.test_session() as sess: + predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) + labels = array_ops.zeros([4]) + thresholds = [0.5] + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + predictions, labels, thresholds) + + sess.run(variables.local_variables_initializer()) + sess.run(fpr_op) + + self.assertAlmostEqual(0, fpr.eval(), 6) + + def testWithMultipleUpdates(self): + num_samples = 1000 + batch_size = 10 + num_batches = int(num_samples / batch_size) + + # Create the labels and data. + labels = np.random.randint(0, 2, size=(num_samples, 1)) + noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1)) + predictions = 0.4 + 0.2 * labels + noise + predictions[predictions > 1] = 1 + predictions[predictions < 0] = 0 + thresholds = [0.3] + + fp = 0 + tn = 0 + for i in range(num_samples): + if predictions[i] > thresholds[0]: + if labels[i] == 0: + fp += 1 + else: + if labels[i] == 0: + tn += 1 + epsilon = 1e-7 + expected_fpr = fp / (epsilon + fp + tn) + + labels = labels.astype(np.float32) + predictions = predictions.astype(np.float32) + + with self.test_session() as sess: + # Reshape the data so its easy to queue up: + predictions_batches = predictions.reshape((batch_size, num_batches)) + labels_batches = labels.reshape((batch_size, num_batches)) + + # Enqueue the data: + predictions_queue = data_flow_ops.FIFOQueue( + num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,)) + labels_queue = data_flow_ops.FIFOQueue( + num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,)) + + for i in range(int(num_batches)): + tf_prediction = constant_op.constant(predictions_batches[:, i]) + tf_label = constant_op.constant(labels_batches[:, i]) + sess.run([ + predictions_queue.enqueue(tf_prediction), + labels_queue.enqueue(tf_label) + ]) + + tf_predictions = predictions_queue.dequeue() + tf_labels = labels_queue.dequeue() + + fpr, fpr_op = metrics.streaming_false_positive_rate_at_thresholds( + tf_predictions, tf_labels, thresholds) + + sess.run(variables.local_variables_initializer()) + for _ in range(int(num_samples / batch_size)): + sess.run(fpr_op) + # Since this is only approximate, we can't expect a 6 digits match. + # Although with higher number of samples/thresholds we should see the + # accuracy improving + self.assertAlmostEqual(expected_fpr, fpr.eval(), 2) + + +class RecallAtPrecisionTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.recall_at_precision( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + precision=0.7) + _assert_metric_variables(self, ('recall_at_precision/true_positives:0', + 'recall_at_precision/false_negatives:0', + 'recall_at_precision/false_positives:0', + 'recall_at_precision/true_negatives:0')) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + mean, _ = metrics.recall_at_precision( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + precision=0.7, + metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [mean]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metrics.recall_at_precision( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + precision=0.7, + updates_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) + + def testValueTensorIsIdempotent(self): + predictions = random_ops.random_uniform( + (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) + labels = random_ops.random_uniform( + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) + recall, update_op = metrics.recall_at_precision( + predictions, labels, precision=0.7) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + + # Run several updates. + for _ in range(10): + sess.run(update_op) + + # Then verify idempotency. + initial_recall = recall.eval() + for _ in range(10): + self.assertAlmostEqual(initial_recall, recall.eval(), 5) + + def testAllCorrect(self): + inputs = np.random.randint(0, 2, size=(100, 1)) + + predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) + labels = constant_op.constant(inputs) + recall, update_op = metrics.recall_at_precision( + predictions, labels, precision=1.0) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertEqual(1, sess.run(update_op)) + self.assertEqual(1, recall.eval()) + + def testSomeCorrectHighPrecision(self): + predictions_values = [1, .9, .8, .7, .6, .5, .4, .3] + labels_values = [1, 1, 1, 1, 0, 0, 0, 1] + + predictions = constant_op.constant( + predictions_values, dtype=dtypes_lib.float32) + labels = constant_op.constant(labels_values) + recall, update_op = metrics.recall_at_precision( + predictions, labels, precision=0.8) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + self.assertAlmostEqual(0.8, sess.run(update_op)) + self.assertAlmostEqual(0.8, recall.eval()) + + def testSomeCorrectLowPrecision(self): + predictions_values = [1, .9, .8, .7, .6, .5, .4, .3, .2, .1] + labels_values = [1, 1, 0, 0, 0, 0, 0, 0, 0, 1] + + predictions = constant_op.constant( + predictions_values, dtype=dtypes_lib.float32) + labels = constant_op.constant(labels_values) + recall, update_op = metrics.recall_at_precision( + predictions, labels, precision=0.4) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + target_recall = 2.0 / 3.0 + self.assertAlmostEqual(target_recall, sess.run(update_op)) + self.assertAlmostEqual(target_recall, recall.eval()) + + def testWeighted(self): + predictions_values = [1, .9, .8, .7, .6] + labels_values = [1, 1, 0, 0, 1] + weights_values = [1, 1, 3, 4, 1] + + predictions = constant_op.constant( + predictions_values, dtype=dtypes_lib.float32) + labels = constant_op.constant(labels_values) + weights = constant_op.constant(weights_values) + recall, update_op = metrics.recall_at_precision( + predictions, labels, weights=weights, precision=0.4) + + with self.test_session() as sess: + sess.run(variables.local_variables_initializer()) + target_recall = 2.0 / 3.0 + self.assertAlmostEqual(target_recall, sess.run(update_op)) + self.assertAlmostEqual(target_recall, recall.eval()) + + +class StreamingFNRThresholdsTest(test.TestCase): + + def setUp(self): + np.random.seed(1) + ops.reset_default_graph() + + def testVars(self): + metrics.streaming_false_negative_rate_at_thresholds( + predictions=array_ops.ones((10, 1)), + labels=array_ops.ones((10, 1)), + thresholds=[0, 0.5, 1.0]) + _assert_metric_variables(self, ( + 'false_negative_rate_at_thresholds/false_negatives:0', + 'false_negative_rate_at_thresholds/true_positives:0', + )) def testMetricsCollection(self): my_collection_name = '__metrics__' - prec, _ = metrics.streaming_precision_at_thresholds( - predictions=array_ops.ones((10, 1)), - labels=array_ops.ones((10, 1)), - thresholds=[0, 0.5, 1.0], - metrics_collections=[my_collection_name]) - rec, _ = metrics.streaming_recall_at_thresholds( + fnr, _ = metrics.streaming_false_negative_rate_at_thresholds( predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)), thresholds=[0, 0.5, 1.0], metrics_collections=[my_collection_name]) - self.assertListEqual(ops.get_collection(my_collection_name), [prec, rec]) + self.assertListEqual(ops.get_collection(my_collection_name), [fnr]) def testUpdatesCollection(self): my_collection_name = '__updates__' - _, precision_op = metrics.streaming_precision_at_thresholds( - predictions=array_ops.ones((10, 1)), - labels=array_ops.ones((10, 1)), - thresholds=[0, 0.5, 1.0], - updates_collections=[my_collection_name]) - _, recall_op = metrics.streaming_recall_at_thresholds( + _, update_op = metrics.streaming_false_negative_rate_at_thresholds( predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)), thresholds=[0, 0.5, 1.0], updates_collections=[my_collection_name]) self.assertListEqual( - ops.get_collection(my_collection_name), [precision_op, recall_op]) + ops.get_collection(my_collection_name), [update_op]) def testValueTensorIsIdempotent(self): predictions = random_ops.random_uniform( (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) labels = random_ops.random_uniform( - (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=2) + (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2) thresholds = [0, 0.5, 1.0] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) with self.test_session() as sess: sess.run(variables.local_variables_initializer()) - # Run several updates, then verify idempotency. - sess.run([prec_op, rec_op]) - initial_prec = prec.eval() - initial_rec = rec.eval() + # Run several updates. for _ in range(10): - sess.run([prec_op, rec_op]) - self.assertAllClose(initial_prec, prec.eval()) - self.assertAllClose(initial_rec, rec.eval()) + sess.run(fnr_op) + + # Then verify idempotency. + initial_fnr = fnr.eval() + for _ in range(10): + self.assertAllClose(initial_fnr, fnr.eval()) - # TODO(nsilberman): fix tests (passing but incorrect). def testAllCorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) @@ -2043,17 +3097,13 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(inputs) thresholds = [0.5] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertEqual(1, prec.eval()) - self.assertEqual(1, rec.eval()) + self.assertEqual(0, fnr.eval()) def testSomeCorrect(self): with self.test_session() as sess: @@ -2061,17 +3111,13 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4)) thresholds = [0.5] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(0.5, prec.eval()) - self.assertAlmostEqual(0.5, rec.eval()) + self.assertAlmostEqual(0.5, fnr.eval()) def testAllIncorrect(self): inputs = np.random.randint(0, 2, size=(100, 1)) @@ -2080,17 +3126,13 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) thresholds = [0.5] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(0, prec.eval()) - self.assertAlmostEqual(0, rec.eval()) + self.assertAlmostEqual(1, fnr.eval()) def testWeights1d(self): with self.test_session() as sess: @@ -2100,27 +3142,17 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): weights = constant_op.constant( [[0], [1]], shape=(2, 1), dtype=dtypes_lib.float32) thresholds = [0.5, 1.1] - prec, prec_op = metrics.streaming_precision_at_thresholds( - predictions, labels, thresholds, weights=weights) - rec, rec_op = metrics.streaming_recall_at_thresholds( + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( predictions, labels, thresholds, weights=weights) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - prec_low = array_ops.reshape(prec_low, shape=()) - prec_high = array_ops.reshape(prec_high, shape=()) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) - rec_low = array_ops.reshape(rec_low, shape=()) - rec_high = array_ops.reshape(rec_high, shape=()) + fnr_low = fnr[0] + fnr_high = fnr[1] sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(1.0, prec_low.eval(), places=5) - self.assertAlmostEqual(0.0, prec_high.eval(), places=5) - self.assertAlmostEqual(1.0, rec_low.eval(), places=5) - self.assertAlmostEqual(0.0, rec_high.eval(), places=5) + self.assertAlmostEqual(0.0, fnr_low.eval(), places=5) + self.assertAlmostEqual(1.0, fnr_high.eval(), places=5) def testWeights2d(self): with self.test_session() as sess: @@ -2130,27 +3162,17 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): weights = constant_op.constant( [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes_lib.float32) thresholds = [0.5, 1.1] - prec, prec_op = metrics.streaming_precision_at_thresholds( - predictions, labels, thresholds, weights=weights) - rec, rec_op = metrics.streaming_recall_at_thresholds( + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( predictions, labels, thresholds, weights=weights) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - prec_low = array_ops.reshape(prec_low, shape=()) - prec_high = array_ops.reshape(prec_high, shape=()) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) - rec_low = array_ops.reshape(rec_low, shape=()) - rec_high = array_ops.reshape(rec_high, shape=()) + fnr_low = fnr[0] + fnr_high = fnr[1] sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(1.0, prec_low.eval(), places=5) - self.assertAlmostEqual(0.0, prec_high.eval(), places=5) - self.assertAlmostEqual(1.0, rec_low.eval(), places=5) - self.assertAlmostEqual(0.0, rec_high.eval(), places=5) + self.assertAlmostEqual(0.0, fnr_low.eval(), places=5) + self.assertAlmostEqual(1.0, fnr_high.eval(), places=5) def testExtremeThresholds(self): with self.test_session() as sess: @@ -2158,41 +3180,30 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) thresholds = [-1.0, 2.0] # lower/higher than any values - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) - [prec_low, prec_high] = array_ops.split( - value=prec, num_or_size_splits=2, axis=0) - [rec_low, rec_high] = array_ops.split( - value=rec, num_or_size_splits=2, axis=0) + fnr_low = fnr[0] + fnr_high = fnr[1] sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(0.75, prec_low.eval()) - self.assertAlmostEqual(0.0, prec_high.eval()) - self.assertAlmostEqual(1.0, rec_low.eval()) - self.assertAlmostEqual(0.0, rec_high.eval()) + self.assertAlmostEqual(0.0, fnr_low.eval()) + self.assertAlmostEqual(1.0, fnr_high.eval()) def testZeroLabelsPredictions(self): with self.test_session() as sess: predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) labels = array_ops.zeros([4]) thresholds = [0.5] - prec, prec_op = metrics.streaming_precision_at_thresholds(predictions, - labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(predictions, labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + predictions, labels, thresholds) sess.run(variables.local_variables_initializer()) - sess.run([prec_op, rec_op]) + sess.run(fnr_op) - self.assertAlmostEqual(0, prec.eval(), 6) - self.assertAlmostEqual(0, rec.eval(), 6) + self.assertAlmostEqual(0, fnr.eval(), 6) def testWithMultipleUpdates(self): num_samples = 1000 @@ -2207,24 +3218,17 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): predictions[predictions < 0] = 0 thresholds = [0.3] - tp = 0 - fp = 0 fn = 0 - tn = 0 + tp = 0 for i in range(num_samples): if predictions[i] > thresholds[0]: if labels[i] == 1: tp += 1 - else: - fp += 1 else: if labels[i] == 1: fn += 1 - else: - tn += 1 epsilon = 1e-7 - expected_prec = tp / (epsilon + tp + fp) - expected_rec = tp / (epsilon + tp + fn) + expected_fnr = fn / (epsilon + fn + tp) labels = labels.astype(np.float32) predictions = predictions.astype(np.float32) @@ -2251,21 +3255,16 @@ class StreamingPrecisionRecallThresholdsTest(test.TestCase): tf_predictions = predictions_queue.dequeue() tf_labels = labels_queue.dequeue() - prec, prec_op = metrics.streaming_precision_at_thresholds(tf_predictions, - tf_labels, - thresholds) - rec, rec_op = metrics.streaming_recall_at_thresholds(tf_predictions, - tf_labels, - thresholds) + fnr, fnr_op = metrics.streaming_false_negative_rate_at_thresholds( + tf_predictions, tf_labels, thresholds) sess.run(variables.local_variables_initializer()) for _ in range(int(num_samples / batch_size)): - sess.run([prec_op, rec_op]) + sess.run(fnr_op) # Since this is only approximate, we can't expect a 6 digits match. # Although with higher number of samples/thresholds we should see the # accuracy improving - self.assertAlmostEqual(expected_prec, prec.eval(), 2) - self.assertAlmostEqual(expected_rec, rec.eval(), 2) + self.assertAlmostEqual(expected_fnr, fnr.eval(), 2) # TODO(ptucker): Remove when we remove `streaming_recall_at_k`. @@ -2291,8 +3290,8 @@ class StreamingRecallAtKTest(test.TestCase): labels=array_ops.ones( (self._batch_size,), dtype=dtypes_lib.int32), k=1) - _assert_local_variables(self, ('recall_at_1/count:0', - 'recall_at_1/total:0')) + _assert_metric_variables(self, + ('recall_at_1/count:0', 'recall_at_1/total:0')) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -3783,8 +4782,8 @@ class StreamingMeanAbsoluteErrorTest(test.TestCase): def testVars(self): metrics.streaming_mean_absolute_error( predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) - _assert_local_variables(self, ('mean_absolute_error/count:0', - 'mean_absolute_error/total:0')) + _assert_metric_variables( + self, ('mean_absolute_error/count:0', 'mean_absolute_error/total:0')) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -3846,8 +4845,8 @@ class StreamingMeanRelativeErrorTest(test.TestCase): predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)), normalizer=array_ops.ones((10, 1))) - _assert_local_variables(self, ('mean_relative_error/count:0', - 'mean_relative_error/total:0')) + _assert_metric_variables( + self, ('mean_relative_error/count:0', 'mean_relative_error/total:0')) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -3929,8 +4928,8 @@ class StreamingMeanSquaredErrorTest(test.TestCase): def testVars(self): metrics.streaming_mean_squared_error( predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) - _assert_local_variables(self, ('mean_squared_error/count:0', - 'mean_squared_error/total:0')) + _assert_metric_variables( + self, ('mean_squared_error/count:0', 'mean_squared_error/total:0')) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -4109,8 +5108,9 @@ class StreamingRootMeanSquaredErrorTest(test.TestCase): def testVars(self): metrics.streaming_root_mean_squared_error( predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) - _assert_local_variables(self, ('root_mean_squared_error/count:0', - 'root_mean_squared_error/total:0')) + _assert_metric_variables( + self, + ('root_mean_squared_error/count:0', 'root_mean_squared_error/total:0')) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -4202,11 +5202,12 @@ class StreamingCovarianceTest(test.TestCase): predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones( [10, 10]), labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10])) - _assert_local_variables(self, ( + _assert_metric_variables(self, ( 'covariance/comoment:0', 'covariance/count:0', 'covariance/mean_label:0', - 'covariance/mean_prediction:0',)) + 'covariance/mean_prediction:0', + )) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -4371,7 +5372,7 @@ class StreamingPearsonRTest(test.TestCase): predictions=math_ops.to_float(math_ops.range(10)) + array_ops.ones( [10, 10]), labels=math_ops.to_float(math_ops.range(10)) + array_ops.ones([10, 10])) - _assert_local_variables(self, ( + _assert_metric_variables(self, ( 'pearson_r/covariance/comoment:0', 'pearson_r/covariance/count:0', 'pearson_r/covariance/mean_label:0', @@ -4383,7 +5384,8 @@ class StreamingPearsonRTest(test.TestCase): 'pearson_r/variance_predictions/comoment:0', 'pearson_r/variance_predictions/count:0', 'pearson_r/variance_predictions/mean_label:0', - 'pearson_r/variance_predictions/mean_prediction:0',)) + 'pearson_r/variance_predictions/mean_prediction:0', + )) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -4596,9 +5598,10 @@ class StreamingMeanCosineDistanceTest(test.TestCase): predictions=array_ops.ones((10, 3)), labels=array_ops.ones((10, 3)), dim=1) - _assert_local_variables(self, ( + _assert_metric_variables(self, ( 'mean_cosine_distance/count:0', - 'mean_cosine_distance/total:0',)) + 'mean_cosine_distance/total:0', + )) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -4737,9 +5740,10 @@ class PcntBelowThreshTest(test.TestCase): def testVars(self): metrics.streaming_percentage_less(values=array_ops.ones((10,)), threshold=2) - _assert_local_variables(self, ( + _assert_metric_variables(self, ( 'percentage_below_threshold/count:0', - 'percentage_below_threshold/total:0',)) + 'percentage_below_threshold/total:0', + )) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -4812,7 +5816,7 @@ class StreamingMeanIOUTest(test.TestCase): predictions=array_ops.ones([10, 1]), labels=array_ops.ones([10, 1]), num_classes=2) - _assert_local_variables(self, ('mean_iou/total_confusion_matrix:0',)) + _assert_metric_variables(self, ('mean_iou/total_confusion_matrix:0',)) def testMetricsCollections(self): my_collection_name = '__metrics__' @@ -4978,7 +5982,7 @@ class StreamingMeanIOUTest(test.TestCase): sess.run(variables.local_variables_initializer()) for _ in range(5): sess.run(update_op) - desired_output = np.mean([1.0 / 3.0, 2.0 / 4.0, 0.]) + desired_output = np.mean([1.0 / 3.0, 2.0 / 4.0]) self.assertAlmostEqual(desired_output, miou.eval()) def testUpdateOpEvalIsAccumulatedConfusionMatrix(self): @@ -5060,6 +6064,58 @@ class StreamingMeanIOUTest(test.TestCase): desired_miou = np.mean([2. / 4., 4. / 6.]) self.assertAlmostEqual(desired_miou, miou.eval()) + def testMissingClassInLabels(self): + labels = constant_op.constant([ + [[0, 0, 1, 1, 0, 0], + [1, 0, 0, 0, 0, 1]], + [[1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]]]) + predictions = constant_op.constant([ + [[0, 0, 2, 1, 1, 0], + [0, 1, 2, 2, 0, 1]], + [[0, 0, 2, 1, 1, 1], + [1, 1, 2, 0, 0, 0]]]) + num_classes = 3 + with self.test_session() as sess: + miou, update_op = metrics.streaming_mean_iou( + predictions, labels, num_classes) + sess.run(variables.local_variables_initializer()) + self.assertAllEqual([[7, 4, 3], [3, 5, 2], [0, 0, 0]], update_op.eval()) + self.assertAlmostEqual( + 1 / 3 * (7 / (7 + 3 + 7) + 5 / (5 + 4 + 5) + 0 / (0 + 5 + 0)), + miou.eval()) + + def testMissingClassOverallSmall(self): + labels = constant_op.constant([0]) + predictions = constant_op.constant([0]) + num_classes = 2 + with self.test_session() as sess: + miou, update_op = metrics.streaming_mean_iou( + predictions, labels, num_classes) + sess.run(variables.local_variables_initializer()) + self.assertAllEqual([[1, 0], [0, 0]], update_op.eval()) + self.assertAlmostEqual(1, miou.eval()) + + def testMissingClassOverallLarge(self): + labels = constant_op.constant([ + [[0, 0, 1, 1, 0, 0], + [1, 0, 0, 0, 0, 1]], + [[1, 1, 1, 1, 1, 1], + [0, 0, 0, 0, 0, 0]]]) + predictions = constant_op.constant([ + [[0, 0, 1, 1, 0, 0], + [1, 1, 0, 0, 1, 1]], + [[0, 0, 0, 1, 1, 1], + [1, 1, 1, 0, 0, 0]]]) + num_classes = 3 + with self.test_session() as sess: + miou, update_op = metrics.streaming_mean_iou( + predictions, labels, num_classes) + sess.run(variables.local_variables_initializer()) + self.assertAllEqual([[9, 5, 0], [3, 7, 0], [0, 0, 0]], update_op.eval()) + self.assertAlmostEqual( + 1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)), miou.eval()) + class StreamingConcatTest(test.TestCase): @@ -5068,9 +6124,10 @@ class StreamingConcatTest(test.TestCase): def testVars(self): metrics.streaming_concat(values=array_ops.ones((10,))) - _assert_local_variables(self, ( + _assert_metric_variables(self, ( 'streaming_concat/array:0', - 'streaming_concat/size:0',)) + 'streaming_concat/size:0', + )) def testMetricsCollection(self): my_collection_name = '__metrics__' @@ -5240,5 +6297,163 @@ class AggregateMetricMapTest(test.TestCase): self.assertEqual(4, names_to_values['m2'].eval()) +class CountTest(test.TestCase): + + def setUp(self): + ops.reset_default_graph() + + def testVars(self): + metrics.count(array_ops.ones([4, 3])) + _assert_metric_variables(self, ['count/count:0']) + + def testMetricsCollection(self): + my_collection_name = '__metrics__' + mean, _ = metrics.count( + array_ops.ones([4, 3]), metrics_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [mean]) + + def testUpdatesCollection(self): + my_collection_name = '__updates__' + _, update_op = metrics.count( + array_ops.ones([4, 3]), updates_collections=[my_collection_name]) + self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) + + def testBasic(self): + with self.test_session() as sess: + values_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) + _enqueue_vector(sess, values_queue, [0, 1]) + _enqueue_vector(sess, values_queue, [-4.2, 9.1]) + _enqueue_vector(sess, values_queue, [6.5, 0]) + _enqueue_vector(sess, values_queue, [-3.2, 4.0]) + values = values_queue.dequeue() + + result, update_op = metrics.count(values) + + sess.run(variables.local_variables_initializer()) + for _ in range(4): + sess.run(update_op) + self.assertAlmostEqual(8.0, sess.run(result), 5) + + def testUpdateOpsReturnsCurrentValue(self): + with self.test_session() as sess: + values_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) + _enqueue_vector(sess, values_queue, [0, 1]) + _enqueue_vector(sess, values_queue, [-4.2, 9.1]) + _enqueue_vector(sess, values_queue, [6.5, 0]) + _enqueue_vector(sess, values_queue, [-3.2, 4.0]) + values = values_queue.dequeue() + + result, update_op = metrics.count(values) + + sess.run(variables.local_variables_initializer()) + + self.assertAlmostEqual(2.0, sess.run(update_op), 5) + self.assertAlmostEqual(4.0, sess.run(update_op), 5) + self.assertAlmostEqual(6.0, sess.run(update_op), 5) + self.assertAlmostEqual(8.0, sess.run(update_op), 5) + + self.assertAlmostEqual(8.0, sess.run(result), 5) + + def test1dWeightedValues(self): + with self.test_session() as sess: + # Create the queue that populates the values. + values_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) + _enqueue_vector(sess, values_queue, [0, 1]) + _enqueue_vector(sess, values_queue, [-4.2, 9.1]) + _enqueue_vector(sess, values_queue, [6.5, 0]) + _enqueue_vector(sess, values_queue, [-3.2, 4.0]) + values = values_queue.dequeue() + + # Create the queue that populates the weighted labels. + weights_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(1, 1)) + _enqueue_vector(sess, weights_queue, [0.5]) + _enqueue_vector(sess, weights_queue, [0]) + _enqueue_vector(sess, weights_queue, [0]) + _enqueue_vector(sess, weights_queue, [1.2]) + weights = weights_queue.dequeue() + + result, update_op = metrics.count(values, weights) + + variables.local_variables_initializer().run() + for _ in range(4): + update_op.eval() + self.assertAlmostEqual(3.4, result.eval(), 5) + + def test1dWeightedValues_placeholders(self): + with self.test_session() as sess: + # Create the queue that populates the values. + feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0)) + values = array_ops.placeholder(dtype=dtypes_lib.float32) + + # Create the queue that populates the weighted labels. + weights_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(1,)) + _enqueue_vector(sess, weights_queue, 0.5, shape=(1,)) + _enqueue_vector(sess, weights_queue, 0, shape=(1,)) + _enqueue_vector(sess, weights_queue, 0, shape=(1,)) + _enqueue_vector(sess, weights_queue, 1.2, shape=(1,)) + weights = weights_queue.dequeue() + + result, update_op = metrics.count(values, weights) + + variables.local_variables_initializer().run() + for i in range(4): + update_op.eval(feed_dict={values: feed_values[i]}) + self.assertAlmostEqual(3.4, result.eval(), 5) + + def test2dWeightedValues(self): + with self.test_session() as sess: + # Create the queue that populates the values. + values_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) + _enqueue_vector(sess, values_queue, [0, 1]) + _enqueue_vector(sess, values_queue, [-4.2, 9.1]) + _enqueue_vector(sess, values_queue, [6.5, 0]) + _enqueue_vector(sess, values_queue, [-3.2, 4.0]) + values = values_queue.dequeue() + + # Create the queue that populates the weighted labels. + weights_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) + _enqueue_vector(sess, weights_queue, [1.1, 1]) + _enqueue_vector(sess, weights_queue, [1, 0]) + _enqueue_vector(sess, weights_queue, [0, 1]) + _enqueue_vector(sess, weights_queue, [0, 0]) + weights = weights_queue.dequeue() + + result, update_op = metrics.count(values, weights) + + variables.local_variables_initializer().run() + for _ in range(4): + update_op.eval() + self.assertAlmostEqual(4.1, result.eval(), 5) + + def test2dWeightedValues_placeholders(self): + with self.test_session() as sess: + # Create the queue that populates the values. + feed_values = ((0, 1), (-4.2, 9.1), (6.5, 0), (-3.2, 4.0)) + values = array_ops.placeholder(dtype=dtypes_lib.float32) + + # Create the queue that populates the weighted labels. + weights_queue = data_flow_ops.FIFOQueue( + 4, dtypes=dtypes_lib.float32, shapes=(2,)) + _enqueue_vector(sess, weights_queue, [1.1, 1], shape=(2,)) + _enqueue_vector(sess, weights_queue, [1, 0], shape=(2,)) + _enqueue_vector(sess, weights_queue, [0, 1], shape=(2,)) + _enqueue_vector(sess, weights_queue, [0, 0], shape=(2,)) + weights = weights_queue.dequeue() + + result, update_op = metrics.count(values, weights) + + variables.local_variables_initializer().run() + for i in range(4): + update_op.eval(feed_dict={values: feed_values[i]}) + self.assertAlmostEqual(4.1, result.eval(), 5) + + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..ca3f13479ed32e9ab3d43dfe9a392ef8466ce5f2 --- /dev/null +++ b/tensorflow/contrib/model_pruning/BUILD @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +package(default_visibility = ["//tensorflow:__subpackages__"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "core_layers", + srcs = ["python/layers/core_layers.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:layers", + "//tensorflow/python:ops", + "//tensorflow/python:platform", + ], +) + +py_library( + name = "layers", + srcs = ["python/layers/layers.py"], + srcs_version = "PY2AND3", + deps = [ + ":core_layers", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/layers:layers_py", + "//third_party/py/numpy", + ], +) + +py_test( + name = "layers_test", + size = "small", + srcs = ["python/layers/layers_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":layers", + "//tensorflow/python:client_testlib", + ], +) + +py_library( + name = "learning", + srcs = ["python/learning.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/slim", + ], +) + +py_library( + name = "rnn_cells", + srcs = ["python/layers/rnn_cells.py"], + srcs_version = "PY2AND3", + deps = [ + ":core_layers", + ], +) + +py_library( + name = "pruning", + srcs = ["python/pruning.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":core_layers", + "//tensorflow/contrib/training:training_py", + "//tensorflow/python:platform", + "//third_party/py/numpy", + ], +) + +py_test( + name = "pruning_test", + size = "small", + srcs = ["python/pruning_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pruning", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "rnn_cells_test", + size = "small", + srcs = ["python/layers/rnn_cells_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pruning", + ":rnn_cells", + "//tensorflow/python:client_testlib", + ], +) + +py_library( + name = "init_py", + srcs = ["__init__.py"], + srcs_version = "PY2AND3", +) + +# Top-level library +py_library( + name = "model_pruning", + srcs_version = "PY2AND3", + deps = [ + ":init_py", + ":layers", + ":learning", + ":pruning", + ":rnn_cells", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md new file mode 100644 index 0000000000000000000000000000000000000000..a8427e60144445c008d032ff2cbfd801d294974c --- /dev/null +++ b/tensorflow/contrib/model_pruning/README.md @@ -0,0 +1,195 @@ +# Model pruning: Training tensorflow models to have masked connections + +This document describes the API that facilitates magnitude-based pruning of +neural network's weight tensors. The API helps inject necessary tensorflow op +into the training graph so the model can be pruned while it is being trained. + +### Model creation + +The first step involves adding mask and threshold variables to the layers that +need to undergo pruning. The variable mask is the same shape as the layer's +weight tensor and determines which of the weights participate in the forward +execution of the graph. This can be achieved by wrapping the weight tensor of +the layer with the `apply_mask` function provided in +[pruning.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/pruning.py). +For example: + +```python +conv = tf.nn.conv2d(images, pruning.apply_mask(weights), stride, padding) +``` + +This creates a convolutional layer with additional variables mask and threshold +as shown below: ![Convolutional layer with mask and +threshold](./mask.png "Convolutional layer with mask and threshold") + +Alternatively, the API also provides variant of tensorflow layers with these +auxiliary variables built-in (see +[layers](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers)) +. Layers currently supported: + +* [layers.masked_conv2d](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/layers.py?l=83) + +* [layers.masked_fully_connected](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/layers.py?l=241) + +* [rnn_cells.MaskedLSTMCell](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py?l=154) + +### Adding pruning ops to the training graph + +The pruning library allows for specification of the following hyper parameters: + +| Hyperparameter | Type | Default | Description | +| ---------------------------- | ------- | ------------- | -------------- | +| name | string | model_pruning | Name of the | +: : : : pruning : +: : : : specification. : +: : : : Used for : +: : : : adding : +: : : : summaries and : +: : : : ops under a : +: : : : common : +: : : : tensorflow : +: : : : name_scope : +| begin_pruning_step | integer | 0 | The global | +: : : : step at which : +: : : : to begin : +: : : : pruning : +| end_pruning_step | integer | -1 | The global | +: : : : step at which : +: : : : to terminate : +: : : : pruning. : +: : : : Defaults to -1 : +: : : : implying that : +: : : : pruning : +: : : : continues till : +: : : : the training : +: : : : stops : +| do_not_prune | list of | [""] | list of layers | +: : strings : : that are not : +: : : : pruned : +| threshold_decay | float | 0.9 | The decay | +: : : : factor to use : +: : : : for : +: : : : exponential : +: : : : decay of the : +: : : : thresholds : +| pruning_frequency | integer | 10 | How often | +: : : : should the : +: : : : masks be : +: : : : updated? (in # : +: : : : of : +: : : : global_steps). : +| nbins | integer | 255 | Number of bins | +: : : : to use for : +: : : : histogram : +: : : : computation : +| initial_sparsity | float | 0.0 | Initial | +: : : : sparsity value : +| target_sparsity | float | 0.5 | Target | +: : : : sparsity value : +| sparsity_function_begin_step | integer | 0 | The global | +: : : : step at this : +: : : : which the : +: : : : gradual : +: : : : sparsity : +: : : : function : +: : : : begins to take : +: : : : effect : +| sparsity_function_end_step | integer | 100 | The global | +: : : : step used as : +: : : : the end point : +: : : : for the : +: : : : gradual : +: : : : sparsity : +: : : : function : +| sparsity_function_exponent | float | 3.0 | exponent = 1 | +: : : : is linearly : +: : : : varying : +: : : : sparsity : +: : : : between : +: : : : initial and : +: : : : final. : +: : : : exponent > 1 : +: : : : varies more : +: : : : slowly towards : +: : : : the end than : +: : : : the beginning : + +The sparsity $$s_t$$ at global step $$t$$ is given by: + +$$ s_{t}=s_{f}+\left(s_{i}-s_{f}\right)\left(1-\frac{t-t_{0}}{n\Delta t}\right)^{3} $$ + +The interval between sparsity_function_begin_step and sparsity_function_end_step +is divided into $$n$$ intervals of size equal to the pruning_frequency ($$\Delta +t$$). $$s_f$$ is the target_sparsity, $$s_i$$ is the initial_sparsity, $$t_0$$ +is the sparsity_function_begin_step. In this equation, the +sparsity_function_exponent is set to 3. +### Adding pruning ops to the training graph + +The final step involves adding ops to the training graph that monitors the +distribution of the layer's weight magnitudes and determines the layer threshold +such masking all the weights below this threshold achieves the sparsity level +desired for the current training step. This can be achieved as follows: + +```python +tf.app.flags.DEFINE_string( + 'pruning_hparams', '', + """Comma separated list of pruning-related hyperparameters""") + +with tf.graph.as_default(): + + # Create global step variable + global_step = tf.train.get_global_step() + + # Parse pruning hyperparameters + pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams) + + # Create a pruning object using the pruning specification + p = pruning.Pruning(pruning_hparams, global_step=global_step) + + # Add conditional mask update op. Executing this op will update all + # the masks in the graph if the current global step is in the range + # [begin_pruning_step, end_pruning_step] as specified by the pruning spec + mask_update_op = p.conditional_mask_update_op() + + # Add summaries to keep track of the sparsity in different layers during training + p.add_pruning_summaries() + + with tf.train.MonitoredTrainingSession(...) as mon_sess: + # Run the usual training op in the tf session + mon_sess.run(train_op) + + # Update the masks by running the mask_update_op + mon_sess.run(mask_update_op) + +``` + +## Example: Pruning and training deep CNNs on the cifar10 dataset + +Please see https://www.tensorflow.org/tutorials/deep_cnn for details on neural +network architecture, setting up inputs etc. The additional changes needed to +incorporate pruning are captured in the following: + +* [cifar10_pruning.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py) + creates a deep CNN with the same architecture, but adds mask and threshold + variables for each of the weight tensors in the convolutional and + locally-connected layers. + +* [cifar10_train.py](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py) + add pruning ops to the training graph as described above. + +To train the pruned version of cifar10: + +```bash +$ examples_dir=contrib/model_pruning/examples +$ bazel build -c opt $examples_dir/cifar10:cifar10_{train,eval} +$ bazel-bin/$examples_dir/cifar10/cifar10_train --pruning_hparams=name=cifar10_pruning,begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000 +``` + +Eval: + +```shell +$ bazel-bin/$examples_dir/cifar10/cifar10_eval --run_once +``` + +TODO(suyoggupta): Add figures showing the sparsity function, sparsity for +different layers etc. diff --git a/tensorflow/contrib/model_pruning/__init__.py b/tensorflow/contrib/model_pruning/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d32bedbcd6b63bc8e473a9e9d1c8e0753877e6f8 --- /dev/null +++ b/tensorflow/contrib/model_pruning/__init__.py @@ -0,0 +1,47 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Model pruning implementation in tensorflow.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import +from tensorflow.contrib.model_pruning.python.layers.layers import masked_conv2d +from tensorflow.contrib.model_pruning.python.layers.layers import masked_convolution +from tensorflow.contrib.model_pruning.python.layers.layers import masked_fully_connected +from tensorflow.contrib.model_pruning.python.layers.rnn_cells import MaskedBasicLSTMCell +from tensorflow.contrib.model_pruning.python.layers.rnn_cells import MaskedLSTMCell +from tensorflow.contrib.model_pruning.python.learning import train +from tensorflow.contrib.model_pruning.python.pruning import apply_mask +from tensorflow.contrib.model_pruning.python.pruning import get_masked_weights +from tensorflow.contrib.model_pruning.python.pruning import get_masks +from tensorflow.contrib.model_pruning.python.pruning import get_pruning_hparams +from tensorflow.contrib.model_pruning.python.pruning import get_thresholds +from tensorflow.contrib.model_pruning.python.pruning import get_weight_sparsity +from tensorflow.contrib.model_pruning.python.pruning import get_weights +from tensorflow.contrib.model_pruning.python.pruning import Pruning +# pylint: enable=unused-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'masked_convolution', 'masked_conv2d', 'masked_fully_connected', + 'MaskedBasicLSTMCell', 'MaskedLSTMCell', 'train', 'apply_mask', + 'get_masked_weights', 'get_masks', 'get_pruning_hparams', 'get_thresholds', + 'get_weights', 'get_weight_sparsity', 'Pruning' +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/BUILD b/tensorflow/contrib/model_pruning/examples/cifar10/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..299278ae7556253fd3c22724e51dd14963a873e2 --- /dev/null +++ b/tensorflow/contrib/model_pruning/examples/cifar10/BUILD @@ -0,0 +1,77 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Description: +# Example TensorFlow models for CIFAR-10 + +package( + default_visibility = [ + "//tensorflow:internal", + ], +) + +licenses(["notice"]) # Apache 2.0 + +py_library( + name = "cifar10_input", + srcs = ["cifar10_input.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "cifar10_pruning", + srcs = ["cifar10_pruning.py"], + srcs_version = "PY2AND3", + deps = [ + ":cifar10_input", + "//tensorflow:tensorflow_py", + ], +) + +py_binary( + name = "cifar10_eval", + srcs = [ + "cifar10_eval.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":cifar10_pruning", + ], +) + +py_binary( + name = "cifar10_train", + srcs = [ + "cifar10_train.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":cifar10_pruning", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_eval.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_eval.py new file mode 100644 index 0000000000000000000000000000000000000000..d72b2a1dca5de26b59c81c082ff7a42e9a4f4357 --- /dev/null +++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_eval.py @@ -0,0 +1,178 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Evaluation for CIFAR-10. + +Accuracy: +cifar10_train.py achieves 83.0% accuracy after 100K steps (256 epochs +of data) as judged by cifar10_eval.py. + +Speed: +On a single Tesla K40, cifar10_train.py processes a single batch of 128 images +in 0.25-0.35 sec (i.e. 350 - 600 images /sec). The model reaches ~86% +accuracy after 100K steps in 8 hours of training time. + +Usage: +Please see the tutorial and website for how to download the CIFAR-10 +data set, compile the program and train the model. + +http://tensorflow.org/tutorials/deep_cnn/ +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import datetime +import math +import sys +import time + +import numpy as np +import tensorflow as tf + +from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_pruning as cifar10 + +FLAGS = None + + +def eval_once(saver, summary_writer, top_k_op, summary_op): + """Run Eval once. + + Args: + saver: Saver. + summary_writer: Summary writer. + top_k_op: Top K op. + summary_op: Summary op. + """ + with tf.Session() as sess: + ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir) + if ckpt and ckpt.model_checkpoint_path: + # Restores from checkpoint + saver.restore(sess, ckpt.model_checkpoint_path) + # Assuming model_checkpoint_path looks something like: + # /my-favorite-path/cifar10_train/model.ckpt-0, + # extract global_step from it. + global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] + else: + print('No checkpoint file found') + return + + # Start the queue runners. + coord = tf.train.Coordinator() + try: + threads = [] + for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS): + threads.extend(qr.create_threads(sess, coord=coord, daemon=True, + start=True)) + + num_iter = int(math.ceil(FLAGS.num_examples / 128)) + true_count = 0 # Counts the number of correct predictions. + total_sample_count = num_iter * 128 + step = 0 + while step < num_iter and not coord.should_stop(): + predictions = sess.run([top_k_op]) + true_count += np.sum(predictions) + step += 1 + + # Compute precision @ 1. + precision = true_count / total_sample_count + print('%s: precision @ 1 = %.3f' % (datetime.datetime.now(), precision)) + + summary = tf.Summary() + summary.ParseFromString(sess.run(summary_op)) + summary.value.add(tag='Precision @ 1', simple_value=precision) + summary_writer.add_summary(summary, global_step) + except Exception as e: # pylint: disable=broad-except + coord.request_stop(e) + + coord.request_stop() + coord.join(threads, stop_grace_period_secs=10) + + +def evaluate(): + """Eval CIFAR-10 for a number of steps.""" + with tf.Graph().as_default() as g: + # Get images and labels for CIFAR-10. + eval_data = FLAGS.eval_data == 'test' + images, labels = cifar10.inputs(eval_data=eval_data) + + # Build a Graph that computes the logits predictions from the + # inference model. + logits = cifar10.inference(images) + + # Calculate predictions. + top_k_op = tf.nn.in_top_k(logits, labels, 1) + + # Restore the moving average version of the learned variables for eval. + variable_averages = tf.train.ExponentialMovingAverage( + cifar10.MOVING_AVERAGE_DECAY) + variables_to_restore = variable_averages.variables_to_restore() + saver = tf.train.Saver(variables_to_restore) + + # Build the summary operation based on the TF collection of Summaries. + summary_op = tf.summary.merge_all() + + summary_writer = tf.summary.FileWriter(FLAGS.eval_dir, g) + + while True: + eval_once(saver, summary_writer, top_k_op, summary_op) + if FLAGS.run_once: + break + time.sleep(FLAGS.eval_interval_secs) + + +def main(argv=None): # pylint: disable=unused-argument + cifar10.maybe_download_and_extract() + if tf.gfile.Exists(FLAGS.eval_dir): + tf.gfile.DeleteRecursively(FLAGS.eval_dir) + tf.gfile.MakeDirs(FLAGS.eval_dir) + evaluate() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--eval_dir', + type=str, + default='/tmp/cifar10_eval', + help='Directory where to write event logs.') + parser.add_argument( + '--eval_data', + type=str, + default='test', + help="""Either 'test' or 'train_eval'.""") + parser.add_argument( + '--checkpoint_dir', + type=str, + default='/tmp/cifar10_train', + help="""Directory where to read model checkpoints.""") + parser.add_argument( + '--eval_interval_secs', + type=int, + default=60 * 5, + help='How often to run the eval.') + parser.add_argument( + '--num_examples', + type=int, + default=10000, + help='Number of examples to run.') + parser.add_argument( + '--run_once', + type=bool, + default=False, + help='Whether to run eval only once.') + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py new file mode 100644 index 0000000000000000000000000000000000000000..d07fece4bc668612d517e8dcaab1a35451a0238e --- /dev/null +++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_input.py @@ -0,0 +1,256 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Routine for decoding the CIFAR-10 binary file format.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from six.moves import xrange # pylint: disable=redefined-builtin +import tensorflow as tf + +# Process images of this size. Note that this differs from the original CIFAR +# image size of 32 x 32. If one alters this number, then the entire model +# architecture will change and any model would need to be retrained. +IMAGE_SIZE = 24 + +# Global constants describing the CIFAR-10 data set. +NUM_CLASSES = 10 +NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = 50000 +NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = 10000 + + +def read_cifar10(filename_queue): + """Reads and parses examples from CIFAR10 data files. + + Recommendation: if you want N-way read parallelism, call this function + N times. This will give you N independent Readers reading different + files & positions within those files, which will give better mixing of + examples. + + Args: + filename_queue: A queue of strings with the filenames to read from. + + Returns: + An object representing a single example, with the following fields: + height: number of rows in the result (32) + width: number of columns in the result (32) + depth: number of color channels in the result (3) + key: a scalar string Tensor describing the filename & record number + for this example. + label: an int32 Tensor with the label in the range 0..9. + uint8image: a [height, width, depth] uint8 Tensor with the image data + """ + + class CIFAR10Record(object): + pass + result = CIFAR10Record() + + # Dimensions of the images in the CIFAR-10 dataset. + # See http://www.cs.toronto.edu/~kriz/cifar.html for a description of the + # input format. + label_bytes = 1 # 2 for CIFAR-100 + result.height = 32 + result.width = 32 + result.depth = 3 + image_bytes = result.height * result.width * result.depth + # Every record consists of a label followed by the image, with a + # fixed number of bytes for each. + record_bytes = label_bytes + image_bytes + + # Read a record, getting filenames from the filename_queue. No + # header or footer in the CIFAR-10 format, so we leave header_bytes + # and footer_bytes at their default of 0. + reader = tf.FixedLengthRecordReader(record_bytes=record_bytes) + result.key, value = reader.read(filename_queue) + + # Convert from a string to a vector of uint8 that is record_bytes long. + record_bytes = tf.decode_raw(value, tf.uint8) + + # The first bytes represent the label, which we convert from uint8->int32. + result.label = tf.cast( + tf.strided_slice(record_bytes, [0], [label_bytes]), tf.int32) + + # The remaining bytes after the label represent the image, which we reshape + # from [depth * height * width] to [depth, height, width]. + depth_major = tf.reshape( + tf.strided_slice(record_bytes, [label_bytes], + [label_bytes + image_bytes]), + [result.depth, result.height, result.width]) + # Convert from [depth, height, width] to [height, width, depth]. + result.uint8image = tf.transpose(depth_major, [1, 2, 0]) + + return result + + +def _generate_image_and_label_batch(image, label, min_queue_examples, + batch_size, shuffle): + """Construct a queued batch of images and labels. + + Args: + image: 3-D Tensor of [height, width, 3] of type.float32. + label: 1-D Tensor of type.int32 + min_queue_examples: int32, minimum number of samples to retain + in the queue that provides of batches of examples. + batch_size: Number of images per batch. + shuffle: boolean indicating whether to use a shuffling queue. + + Returns: + images: Images. 4D tensor of [batch_size, height, width, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + """ + # Create a queue that shuffles the examples, and then + # read 'batch_size' images + labels from the example queue. + num_preprocess_threads = 16 + if shuffle: + images, label_batch = tf.train.shuffle_batch( + [image, label], + batch_size=batch_size, + num_threads=num_preprocess_threads, + capacity=min_queue_examples + 3 * batch_size, + min_after_dequeue=min_queue_examples) + else: + images, label_batch = tf.train.batch( + [image, label], + batch_size=batch_size, + num_threads=num_preprocess_threads, + capacity=min_queue_examples + 3 * batch_size) + + # Display the training images in the visualizer. + tf.summary.image('images', images) + + return images, tf.reshape(label_batch, [batch_size]) + + +def distorted_inputs(data_dir, batch_size): + """Construct distorted input for CIFAR training using the Reader ops. + + Args: + data_dir: Path to the CIFAR-10 data directory. + batch_size: Number of images per batch. + + Returns: + images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + """ + filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) + for i in xrange(1, 6)] + for f in filenames: + if not tf.gfile.Exists(f): + raise ValueError('Failed to find file: ' + f) + + # Create a queue that produces the filenames to read. + filename_queue = tf.train.string_input_producer(filenames) + + # Read examples from files in the filename queue. + read_input = read_cifar10(filename_queue) + reshaped_image = tf.cast(read_input.uint8image, tf.float32) + + height = IMAGE_SIZE + width = IMAGE_SIZE + + # Image processing for training the network. Note the many random + # distortions applied to the image. + + # Randomly crop a [height, width] section of the image. + distorted_image = tf.random_crop(reshaped_image, [height, width, 3]) + + # Randomly flip the image horizontally. + distorted_image = tf.image.random_flip_left_right(distorted_image) + + # Because these operations are not commutative, consider randomizing + # the order their operation. + distorted_image = tf.image.random_brightness(distorted_image, + max_delta=63) + distorted_image = tf.image.random_contrast(distorted_image, + lower=0.2, upper=1.8) + + # Subtract off the mean and divide by the variance of the pixels. + float_image = tf.image.per_image_standardization(distorted_image) + + # Set the shapes of tensors. + float_image.set_shape([height, width, 3]) + read_input.label.set_shape([1]) + + # Ensure that the random shuffling has good mixing properties. + min_fraction_of_examples_in_queue = 0.4 + min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN * + min_fraction_of_examples_in_queue) + print ('Filling queue with %d CIFAR images before starting to train. ' + 'This will take a few minutes.' % min_queue_examples) + + # Generate a batch of images and labels by building up a queue of examples. + return _generate_image_and_label_batch(float_image, read_input.label, + min_queue_examples, batch_size, + shuffle=True) + + +def inputs(eval_data, data_dir, batch_size): + """Construct input for CIFAR evaluation using the Reader ops. + + Args: + eval_data: bool, indicating if one should use the train or eval data set. + data_dir: Path to the CIFAR-10 data directory. + batch_size: Number of images per batch. + + Returns: + images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + """ + if not eval_data: + filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) + for i in xrange(1, 6)] + num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN + else: + filenames = [os.path.join(data_dir, 'test_batch.bin')] + num_examples_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_EVAL + + for f in filenames: + if not tf.gfile.Exists(f): + raise ValueError('Failed to find file: ' + f) + + # Create a queue that produces the filenames to read. + filename_queue = tf.train.string_input_producer(filenames) + + # Read examples from files in the filename queue. + read_input = read_cifar10(filename_queue) + reshaped_image = tf.cast(read_input.uint8image, tf.float32) + + height = IMAGE_SIZE + width = IMAGE_SIZE + + # Image processing for evaluation. + # Crop the central [height, width] of the image. + resized_image = tf.image.resize_image_with_crop_or_pad(reshaped_image, + width, height) + + # Subtract off the mean and divide by the variance of the pixels. + float_image = tf.image.per_image_standardization(resized_image) + + # Set the shapes of tensors. + float_image.set_shape([height, width, 3]) + read_input.label.set_shape([1]) + + # Ensure that the random shuffling has good mixing properties. + min_fraction_of_examples_in_queue = 0.4 + min_queue_examples = int(num_examples_per_epoch * + min_fraction_of_examples_in_queue) + + # Generate a batch of images and labels by building up a queue of examples. + return _generate_image_and_label_batch(float_image, read_input.label, + min_queue_examples, batch_size, + shuffle=False) diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py new file mode 100644 index 0000000000000000000000000000000000000000..0d1de869f6ef91791a235cfe545b3b3a9b734e72 --- /dev/null +++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_pruning.py @@ -0,0 +1,395 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Builds the CIFAR-10 network with additional variables to support pruning. + +Summary of available functions: + + # Compute input images and labels for training. If you would like to run + # evaluations, use inputs() instead. + inputs, labels = distorted_inputs() + + # Compute inference on the model inputs to make a prediction. + predictions = inference(inputs) + + # Compute the total loss of the prediction with respect to the labels. + loss = loss(predictions, labels) + + # Create a graph to run one step of training with respect to the loss. + train_op = train(loss, global_step) +""" +# pylint: disable=missing-docstring +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import re +import sys +import tarfile + +from six.moves import urllib +import tensorflow as tf + +from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_input +from tensorflow.contrib.model_pruning.python import pruning + +# Global constants describing the CIFAR-10 data set. +IMAGE_SIZE = cifar10_input.IMAGE_SIZE +NUM_CLASSES = cifar10_input.NUM_CLASSES +NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN +NUM_EXAMPLES_PER_EPOCH_FOR_EVAL = cifar10_input.NUM_EXAMPLES_PER_EPOCH_FOR_EVAL +BATCH_SIZE = 128 +DATA_DIR = '/tmp/cifar10_data' + +# Constants describing the training process. +MOVING_AVERAGE_DECAY = 0.9999 # The decay to use for the moving average. +NUM_EPOCHS_PER_DECAY = 350.0 # Epochs after which learning rate decays. +LEARNING_RATE_DECAY_FACTOR = 0.1 # Learning rate decay factor. +INITIAL_LEARNING_RATE = 0.1 # Initial learning rate. + +# If a model is trained with multiple GPUs, prefix all Op names with tower_name +# to differentiate the operations. Note that this prefix is removed from the +# names of the summaries when visualizing a model. +TOWER_NAME = 'tower' + +DATA_URL = 'http://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz' + + +def _activation_summary(x): + """Helper to create summaries for activations. + + Creates a summary that provides a histogram of activations. + Creates a summary that measures the sparsity of activations. + + Args: + x: Tensor + Returns: + nothing + """ + # Remove 'tower_[0-9]/' from the name in case this is a multi-GPU training + # session. This helps the clarity of presentation on tensorboard. + tensor_name = re.sub('%s_[0-9]*/' % TOWER_NAME, '', x.op.name) + tf.summary.histogram(tensor_name + '/activations', x) + tf.summary.scalar(tensor_name + '/sparsity', + tf.nn.zero_fraction(x)) + + +def _variable_on_cpu(name, shape, initializer): + """Helper to create a Variable stored on CPU memory. + + Args: + name: name of the variable + shape: list of ints + initializer: initializer for Variable + + Returns: + Variable Tensor + """ + with tf.device('/cpu:0'): + dtype = tf.float32 + var = tf.get_variable(name, shape, initializer=initializer, dtype=dtype) + return var + + +def _variable_with_weight_decay(name, shape, stddev, wd): + """Helper to create an initialized Variable with weight decay. + + Note that the Variable is initialized with a truncated normal distribution. + A weight decay is added only if one is specified. + + Args: + name: name of the variable + shape: list of ints + stddev: standard deviation of a truncated Gaussian + wd: add L2Loss weight decay multiplied by this float. If None, weight + decay is not added for this Variable. + + Returns: + Variable Tensor + """ + dtype = tf.float32 + var = _variable_on_cpu( + name, + shape, + tf.truncated_normal_initializer(stddev=stddev, dtype=dtype)) + if wd is not None: + weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name='weight_loss') + tf.add_to_collection('losses', weight_decay) + return var + + +def distorted_inputs(): + """Construct distorted input for CIFAR training using the Reader ops. + + Returns: + images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + + Raises: + ValueError: If no data_dir + """ + if not DATA_DIR: + raise ValueError('Please supply a data_dir') + data_dir = os.path.join(DATA_DIR, 'cifar-10-batches-bin') + images, labels = cifar10_input.distorted_inputs( + data_dir=data_dir, batch_size=BATCH_SIZE) + return images, labels + + +def inputs(eval_data): + """Construct input for CIFAR evaluation using the Reader ops. + + Args: + eval_data: bool, indicating if one should use the train or eval data set. + + Returns: + images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size. + labels: Labels. 1D tensor of [batch_size] size. + + Raises: + ValueError: If no data_dir + """ + if not DATA_DIR: + raise ValueError('Please supply a data_dir') + data_dir = os.path.join(DATA_DIR, 'cifar-10-batches-bin') + images, labels = cifar10_input.inputs( + eval_data=eval_data, data_dir=data_dir, batch_size=BATCH_SIZE) + return images, labels + + +def inference(images): + """Build the CIFAR-10 model. + + Args: + images: Images returned from distorted_inputs() or inputs(). + + Returns: + Logits. + """ + # We instantiate all variables using tf.get_variable() instead of + # tf.Variable() in order to share variables across multiple GPU training runs. + # If we only ran this model on a single GPU, we could simplify this function + # by replacing all instances of tf.get_variable() with tf.Variable(). + # + # While instantiating conv and local layers, we add mask and threshold + # variables to the layer by calling the pruning.apply_mask() function. + # Note that the masks are applied only to the weight tensors + # conv1 + with tf.variable_scope('conv1') as scope: + kernel = _variable_with_weight_decay('weights', + shape=[5, 5, 3, 64], + stddev=5e-2, + wd=0.0) + + conv = tf.nn.conv2d( + images, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME') + biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.0)) + pre_activation = tf.nn.bias_add(conv, biases) + conv1 = tf.nn.relu(pre_activation, name=scope.name) + _activation_summary(conv1) + + # pool1 + pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], + padding='SAME', name='pool1') + # norm1 + norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, + name='norm1') + + # conv2 + with tf.variable_scope('conv2') as scope: + kernel = _variable_with_weight_decay('weights', + shape=[5, 5, 64, 64], + stddev=5e-2, + wd=0.0) + conv = tf.nn.conv2d( + norm1, pruning.apply_mask(kernel, scope), [1, 1, 1, 1], padding='SAME') + biases = _variable_on_cpu('biases', [64], tf.constant_initializer(0.1)) + pre_activation = tf.nn.bias_add(conv, biases) + conv2 = tf.nn.relu(pre_activation, name=scope.name) + _activation_summary(conv2) + + # norm2 + norm2 = tf.nn.lrn(conv2, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, + name='norm2') + # pool2 + pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], + strides=[1, 2, 2, 1], padding='SAME', name='pool2') + + # local3 + with tf.variable_scope('local3') as scope: + # Move everything into depth so we can perform a single matrix multiply. + reshape = tf.reshape(pool2, [BATCH_SIZE, -1]) + dim = reshape.get_shape()[1].value + weights = _variable_with_weight_decay('weights', shape=[dim, 384], + stddev=0.04, wd=0.004) + biases = _variable_on_cpu('biases', [384], tf.constant_initializer(0.1)) + local3 = tf.nn.relu( + tf.matmul(reshape, pruning.apply_mask(weights, scope)) + biases, + name=scope.name) + _activation_summary(local3) + + # local4 + with tf.variable_scope('local4') as scope: + weights = _variable_with_weight_decay('weights', shape=[384, 192], + stddev=0.04, wd=0.004) + biases = _variable_on_cpu('biases', [192], tf.constant_initializer(0.1)) + local4 = tf.nn.relu( + tf.matmul(local3, pruning.apply_mask(weights, scope)) + biases, + name=scope.name) + _activation_summary(local4) + + # linear layer(WX + b), + # We don't apply softmax here because + # tf.nn.sparse_softmax_cross_entropy_with_logits accepts the unscaled logits + # and performs the softmax internally for efficiency. + with tf.variable_scope('softmax_linear') as scope: + weights = _variable_with_weight_decay('weights', [192, NUM_CLASSES], + stddev=1/192.0, wd=0.0) + biases = _variable_on_cpu('biases', [NUM_CLASSES], + tf.constant_initializer(0.0)) + softmax_linear = tf.add( + tf.matmul(local4, pruning.apply_mask(weights, scope)), + biases, + name=scope.name) + _activation_summary(softmax_linear) + + return softmax_linear + + +def loss(logits, labels): + """Add L2Loss to all the trainable variables. + + Add summary for "Loss" and "Loss/avg". + Args: + logits: Logits from inference(). + labels: Labels from distorted_inputs or inputs(). 1-D tensor + of shape [batch_size] + + Returns: + Loss tensor of type float. + """ + # Calculate the average cross entropy loss across the batch. + labels = tf.cast(labels, tf.int64) + cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits( + labels=labels, logits=logits, name='cross_entropy_per_example') + cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy') + tf.add_to_collection('losses', cross_entropy_mean) + + # The total loss is defined as the cross entropy loss plus all of the weight + # decay terms (L2 loss). + return tf.add_n(tf.get_collection('losses'), name='total_loss') + + +def _add_loss_summaries(total_loss): + """Add summaries for losses in CIFAR-10 model. + + Generates moving average for all losses and associated summaries for + visualizing the performance of the network. + + Args: + total_loss: Total loss from loss(). + Returns: + loss_averages_op: op for generating moving averages of losses. + """ + # Compute the moving average of all individual losses and the total loss. + loss_averages = tf.train.ExponentialMovingAverage(0.9, name='avg') + losses = tf.get_collection('losses') + loss_averages_op = loss_averages.apply(losses + [total_loss]) + + # Attach a scalar summary to all individual losses and the total loss; do the + # same for the averaged version of the losses. + for l in losses + [total_loss]: + # Name each loss as '(raw)' and name the moving average version of the loss + # as the original loss name. + tf.summary.scalar(l.op.name + ' (raw)', l) + tf.summary.scalar(l.op.name, loss_averages.average(l)) + + return loss_averages_op + + +def train(total_loss, global_step): + """Train CIFAR-10 model. + + Create an optimizer and apply to all trainable variables. Add moving + average for all trainable variables. + + Args: + total_loss: Total loss from loss(). + global_step: Integer Variable counting the number of training steps + processed. + Returns: + train_op: op for training. + """ + # Variables that affect learning rate. + num_batches_per_epoch = NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN / BATCH_SIZE + decay_steps = int(num_batches_per_epoch * NUM_EPOCHS_PER_DECAY) + + # Decay the learning rate exponentially based on the number of steps. + lr = tf.train.exponential_decay(INITIAL_LEARNING_RATE, + global_step, + decay_steps, + LEARNING_RATE_DECAY_FACTOR, + staircase=True) + tf.summary.scalar('learning_rate', lr) + + # Generate moving averages of all losses and associated summaries. + loss_averages_op = _add_loss_summaries(total_loss) + + # Compute gradients. + with tf.control_dependencies([loss_averages_op]): + opt = tf.train.GradientDescentOptimizer(lr) + grads = opt.compute_gradients(total_loss) + + # Apply gradients. + apply_gradient_op = opt.apply_gradients(grads, global_step=global_step) + + # Add histograms for trainable variables. + for var in tf.trainable_variables(): + tf.summary.histogram(var.op.name, var) + + # Add histograms for gradients. + for grad, var in grads: + if grad is not None: + tf.summary.histogram(var.op.name + '/gradients', grad) + + # Track the moving averages of all trainable variables. + variable_averages = tf.train.ExponentialMovingAverage( + MOVING_AVERAGE_DECAY, global_step) + variables_averages_op = variable_averages.apply(tf.trainable_variables()) + + with tf.control_dependencies([apply_gradient_op, variables_averages_op]): + train_op = tf.no_op(name='train') + + return train_op + + +def maybe_download_and_extract(): + """Download and extract the tarball from Alex's website.""" + dest_directory = DATA_DIR + if not os.path.exists(dest_directory): + os.makedirs(dest_directory) + filename = DATA_URL.split('/')[-1] + filepath = os.path.join(dest_directory, filename) + if not os.path.exists(filepath): + def _progress(count, block_size, total_size): + sys.stdout.write('\r>> Downloading %s %.1f%%' % (filename, + float(count * block_size) / float(total_size) * 100.0)) + sys.stdout.flush() + filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) + print() + statinfo = os.stat(filepath) + print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') + + tarfile.open(filepath, 'r:gz').extractall(dest_directory) diff --git a/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py new file mode 100644 index 0000000000000000000000000000000000000000..a1064a3b6abe90f463184e977efb4de173e175cd --- /dev/null +++ b/tensorflow/contrib/model_pruning/examples/cifar10/cifar10_train.py @@ -0,0 +1,159 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A binary to train pruned CIFAR-10 using a single GPU. + +Accuracy: +cifar10_train.py achieves ~86% accuracy after 100K steps (256 epochs of +data) as judged by cifar10_eval.py when target sparsity in +cifar10_pruning_spec.pbtxt is set to zero + +Results: +Sparsity | Accuracy after 150K steps +-------- | ------------------------- +0% | 86% +50% | 86% +75% | TODO(suyoggupta) +90% | TODO(suyoggupta) +95% | 77% + +Usage: +Please see the tutorial and website for how to download the CIFAR-10 +data set, compile the program and train the model. + + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import datetime +import sys +import time + + +import tensorflow as tf + +from tensorflow.contrib.model_pruning.examples.cifar10 import cifar10_pruning as cifar10 +from tensorflow.contrib.model_pruning.python import pruning + +FLAGS = None + + +def train(): + """Train CIFAR-10 for a number of steps.""" + with tf.Graph().as_default(): + global_step = tf.contrib.framework.get_or_create_global_step() + + # Get images and labels for CIFAR-10. + images, labels = cifar10.distorted_inputs() + + # Build a Graph that computes the logits predictions from the + # inference model. + logits = cifar10.inference(images) + + # Calculate loss. + loss = cifar10.loss(logits, labels) + + # Build a Graph that trains the model with one batch of examples and + # updates the model parameters. + train_op = cifar10.train(loss, global_step) + + # Parse pruning hyperparameters + pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams) + + # Create a pruning object using the pruning hyperparameters + pruning_obj = pruning.Pruning(pruning_hparams, global_step=global_step) + + # Use the pruning_obj to add ops to the training graph to update the masks + # The conditional_mask_update_op will update the masks only when the + # training step is in [begin_pruning_step, end_pruning_step] specified in + # the pruning spec proto + mask_update_op = pruning_obj.conditional_mask_update_op() + + # Use the pruning_obj to add summaries to the graph to track the sparsity + # of each of the layers + pruning_obj.add_pruning_summaries() + + class _LoggerHook(tf.train.SessionRunHook): + """Logs loss and runtime.""" + + def begin(self): + self._step = -1 + + def before_run(self, run_context): + self._step += 1 + self._start_time = time.time() + return tf.train.SessionRunArgs(loss) # Asks for loss value. + + def after_run(self, run_context, run_values): + duration = time.time() - self._start_time + loss_value = run_values.results + if self._step % 10 == 0: + num_examples_per_step = 128 + examples_per_sec = num_examples_per_step / duration + sec_per_batch = float(duration) + + format_str = ('%s: step %d, loss = %.2f (%.1f examples/sec; %.3f ' + 'sec/batch)') + print(format_str % (datetime.datetime.now(), self._step, loss_value, + examples_per_sec, sec_per_batch)) + + with tf.train.MonitoredTrainingSession( + checkpoint_dir=FLAGS.train_dir, + hooks=[tf.train.StopAtStepHook(last_step=FLAGS.max_steps), + tf.train.NanTensorHook(loss), + _LoggerHook()], + config=tf.ConfigProto( + log_device_placement=FLAGS.log_device_placement)) as mon_sess: + while not mon_sess.should_stop(): + mon_sess.run(train_op) + # Update the masks + mon_sess.run(mask_update_op) + + +def main(argv=None): # pylint: disable=unused-argument + cifar10.maybe_download_and_extract() + if tf.gfile.Exists(FLAGS.train_dir): + tf.gfile.DeleteRecursively(FLAGS.train_dir) + tf.gfile.MakeDirs(FLAGS.train_dir) + train() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument( + '--train_dir', + type=str, + default='/tmp/cifar10_train', + help='Directory where to write event logs and checkpoint.') + parser.add_argument( + '--pruning_hparams', + type=str, + default='', + help="""Comma separated list of pruning-related hyperparameters""") + parser.add_argument( + '--max_steps', + type=int, + default=1000000, + help='Number of batches to run.') + parser.add_argument( + '--log_device_placement', + type=bool, + default=False, + help='Whether to log device placement.') + + FLAGS, unparsed = parser.parse_known_args() + tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) diff --git a/tensorflow/contrib/model_pruning/python/layers/core_layers.py b/tensorflow/contrib/model_pruning/python/layers/core_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..ae60d8b1e189335ec93e2b8e50edcf8b41bc6725 --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/layers/core_layers.py @@ -0,0 +1,477 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Contains the core layer classes for model pruning and its functional aliases. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.layers import base +from tensorflow.python.layers import utils +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import standard_ops + +MASK_COLLECTION = 'masks' +THRESHOLD_COLLECTION = 'thresholds' +MASKED_WEIGHT_COLLECTION = 'masked_weights' +WEIGHT_COLLECTION = 'kernel' +# The 'weights' part of the name is needed for the quantization library +# to recognize that the kernel should be quantized. +MASKED_WEIGHT_NAME = 'weights/masked_weight' + + +class _MaskedConv(base.Layer): + """Abstract nD convolution layer (private, used as implementation base). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. The weight tensor of this layer is masked. + If `use_bias` is True (and a `bias_initializer` is provided), + a bias vector is created and added to the outputs. Finally, if + `activation` is not `None`, it is applied to the outputs as well. + + Arguments: + rank: An integer, the rank of the convolution, e.g. "2" for 2D convolution. + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of n integers, specifying the + length of the convolution window. + strides: An integer or tuple/list of n integers, + specifying the stride length of the convolution. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, ..., channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, ...)`. + dilation_rate: An integer or tuple/list of n integers, specifying + the dilation rate to use for dilated convolution. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any `strides` value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: An initializer for the convolution kernel. + bias_initializer: An initializer for the bias vector. If None, no bias will + be applied. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + """ + + def __init__(self, + rank, + filters, + kernel_size, + strides=1, + padding='valid', + data_format='channels_last', + dilation_rate=1, + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + trainable=True, + name=None, + **kwargs): + super(_MaskedConv, self).__init__( + trainable=trainable, + name=name, + activity_regularizer=activity_regularizer, + **kwargs) + self.rank = rank + self.filters = filters + self.kernel_size = utils.normalize_tuple(kernel_size, rank, 'kernel_size') + self.strides = utils.normalize_tuple(strides, rank, 'strides') + self.padding = utils.normalize_padding(padding) + self.data_format = utils.normalize_data_format(data_format) + self.dilation_rate = utils.normalize_tuple(dilation_rate, rank, + 'dilation_rate') + self.activation = activation + self.use_bias = use_bias + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + self.kernel_regularizer = kernel_regularizer + self.bias_regularizer = bias_regularizer + self.input_spec = base.InputSpec(ndim=self.rank + 2) + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + channel_axis = 1 if self.data_format == 'channels_first' else -1 + if input_shape[channel_axis].value is None: + raise ValueError('The channel dimension of the inputs ' + 'should be defined. Found `None`.') + input_dim = input_shape[channel_axis].value + kernel_shape = self.kernel_size + (input_dim, self.filters) + self.mask = self.add_variable( + name='mask', + shape=kernel_shape, + initializer=init_ops.ones_initializer(), + trainable=False, + dtype=self.dtype) + + self.kernel = self.add_variable( + name='kernel', + shape=kernel_shape, + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + trainable=True, + dtype=self.dtype) + + self.threshold = self.add_variable( + name='threshold', + shape=[], + initializer=init_ops.zeros_initializer(), + trainable=False, + dtype=self.dtype) + + # Add masked_weights in the weights namescope so as to make it easier + # for the quantization library to add quant ops. + self.masked_kernel = math_ops.multiply(self.mask, self.kernel, + MASKED_WEIGHT_NAME) + + ops.add_to_collection(MASK_COLLECTION, self.mask) + ops.add_to_collection(MASKED_WEIGHT_COLLECTION, self.masked_kernel) + ops.add_to_collection(THRESHOLD_COLLECTION, self.threshold) + ops.add_to_collection(WEIGHT_COLLECTION, self.kernel) + + if self.use_bias: + self.bias = self.add_variable( + name='bias', + shape=(self.filters,), + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + trainable=True, + dtype=self.dtype) + else: + self.bias = None + self.input_spec = base.InputSpec( + ndim=self.rank + 2, axes={channel_axis: input_dim}) + self.built = True + + def call(self, inputs): + outputs = nn.convolution( + input=inputs, + filter=self.masked_kernel, + dilation_rate=self.dilation_rate, + strides=self.strides, + padding=self.padding.upper(), + data_format=utils.convert_data_format(self.data_format, self.rank + 2)) + + if self.bias is not None: + if self.data_format == 'channels_first': + if self.rank == 1: + # nn.bias_add does not accept a 1D input tensor. + bias = array_ops.reshape(self.bias, (1, self.filters, 1)) + outputs += bias + if self.rank == 2: + outputs = nn.bias_add(outputs, self.bias, data_format='NCHW') + if self.rank == 3: + # As of Mar 2017, direct addition is significantly slower than + # bias_add when computing gradients. To use bias_add, we collapse Z + # and Y into a single dimension to obtain a 4D input tensor. + outputs_shape = outputs.shape.as_list() + outputs_4d = array_ops.reshape(outputs, [ + outputs_shape[0], outputs_shape[1], + outputs_shape[2] * outputs_shape[3], outputs_shape[4] + ]) + outputs_4d = nn.bias_add(outputs_4d, self.bias, data_format='NCHW') + outputs = array_ops.reshape(outputs_4d, outputs_shape) + else: + outputs = nn.bias_add(outputs, self.bias, data_format='NHWC') + + if self.activation is not None: + return self.activation(outputs) + return outputs + + def _compute_output_shape(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape).as_list() + if self.data_format == 'channels_last': + space = input_shape[1:-1] + new_space = [] + for i in range(len(space)): + new_dim = utils.conv_output_length( + space[i], + self.kernel_size[i], + padding=self.padding, + stride=self.strides[i], + dilation=self.dilation_rate[i]) + new_space.append(new_dim) + return tensor_shape.TensorShape([input_shape[0]] + new_space + + [self.filters]) + else: + space = input_shape[2:] + new_space = [] + for i in range(len(space)): + new_dim = utils.conv_output_length( + space[i], + self.kernel_size[i], + padding=self.padding, + stride=self.strides[i], + dilation=self.dilation_rate[i]) + new_space.append(new_dim) + return tensor_shape.TensorShape([input_shape[0], self.filters] + + new_space) + + +class MaskedConv2D(_MaskedConv): + """2D convolution layer (e.g. spatial convolution over images). + + This layer creates a convolution kernel that is convolved + (actually cross-correlated) with the layer input to produce a tensor of + outputs. If `use_bias` is True (and a `bias_initializer` is provided), + a bias vector is created and added to the outputs. Finally, if + `activation` is not `None`, it is applied to the outputs as well. + + Arguments: + filters: Integer, the dimensionality of the output space (i.e. the number + of filters in the convolution). + kernel_size: An integer or tuple/list of 2 integers, specifying the + height and width of the 2D convolution window. + Can be a single integer to specify the same value for + all spatial dimensions. + strides: An integer or tuple/list of 2 integers, + specifying the strides of the convolution along the height and width. + Can be a single integer to specify the same value for + all spatial dimensions. + Specifying any stride value != 1 is incompatible with specifying + any `dilation_rate` value != 1. + padding: One of `"valid"` or `"same"` (case-insensitive). + data_format: A string, one of `channels_last` (default) or `channels_first`. + The ordering of the dimensions in the inputs. + `channels_last` corresponds to inputs with shape + `(batch, height, width, channels)` while `channels_first` corresponds to + inputs with shape `(batch, channels, height, width)`. + + dilation_rate: An integer or tuple/list of 2 integers, specifying + the dilation rate to use for dilated convolution. + Can be a single integer to specify the same value for + all spatial dimensions. + Currently, specifying any `dilation_rate` value != 1 is + incompatible with specifying any stride value != 1. + activation: Activation function. Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: An initializer for the convolution kernel. + bias_initializer: An initializer for the bias vector. If None, no bias will + be applied. + kernel_regularizer: Optional regularizer for the convolution kernel. + bias_regularizer: Optional regularizer for the bias vector. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: A string, the name of the layer. + """ + + def __init__(self, + filters, + kernel_size, + strides=(1, 1), + padding='valid', + data_format='channels_last', + dilation_rate=(1, 1), + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + trainable=True, + name=None, + **kwargs): + super(MaskedConv2D, self).__init__( + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + trainable=trainable, + name=name, + **kwargs) + + +class MaskedFullyConnected(base.Layer): + """Fully-connected layer class with masked weights. + + This layer implements the operation: + `outputs = activation(inputs.kernel + bias)` + Where `activation` is the activation function passed as the `activation` + argument (if not `None`), `kernel` is a weights matrix created by the layer, + and `bias` is a bias vector created by the layer + (only if `use_bias` is `True`). + + Note: if the input to the layer has a rank greater than 2, then it is + flattened prior to the initial matrix multiply by `kernel`. + + Arguments: + units: Integer or Long, dimensionality of the output space. + activation: Activation function (callable). Set it to None to maintain a + linear activation. + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: Initializer function for the weight matrix. + bias_initializer: Initializer function for the bias. + kernel_regularizer: Regularizer function for the weight matrix. + bias_regularizer: Regularizer function for the bias. + activity_regularizer: Regularizer function for the output. + trainable: Boolean, if `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + name: String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require reuse=True in such cases. + reuse: Boolean, whether to reuse the weights of a previous layer + by the same name. + + Properties: + units: Python integer, dimensionality of the output space. + activation: Activation function (callable). + use_bias: Boolean, whether the layer uses a bias. + kernel_initializer: Initializer instance (or name) for the weight matrix. + bias_initializer: Initializer instance (or name) for the bias. + kernel_regularizer: Regularizer instance for the weight matrix (callable) + bias_regularizer: Regularizer instance for the bias (callable). + activity_regularizer: Regularizer instance for the output (callable) + kernel: Weight matrix (TensorFlow variable or tensor). + bias: Bias vector, if applicable (TensorFlow variable or tensor). + """ + + def __init__(self, + units, + activation=None, + use_bias=True, + kernel_initializer=None, + bias_initializer=init_ops.zeros_initializer(), + kernel_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + trainable=True, + name=None, + **kwargs): + super(MaskedFullyConnected, self).__init__( + trainable=trainable, + name=name, + activity_regularizer=activity_regularizer, + **kwargs) + self.units = units + self.activation = activation + self.use_bias = use_bias + self.kernel_initializer = kernel_initializer + self.bias_initializer = bias_initializer + self.kernel_regularizer = kernel_regularizer + self.bias_regularizer = bias_regularizer + self.input_spec = base.InputSpec(min_ndim=2) + + def build(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + if input_shape[-1].value is None: + raise ValueError('The last dimension of the inputs to `Dense` ' + 'should be defined. Found `None`.') + self.input_spec = base.InputSpec( + min_ndim=2, axes={-1: input_shape[-1].value}) + + self.kernel = self.add_variable( + 'kernel', + shape=[input_shape[-1].value, self.units], + initializer=self.kernel_initializer, + regularizer=self.kernel_regularizer, + dtype=self.dtype, + trainable=True) + + self.mask = self.add_variable( + name='mask', + shape=[input_shape[-1].value, self.units], + initializer=init_ops.ones_initializer(), + trainable=False, + dtype=self.dtype) + + self.threshold = self.add_variable( + name='threshold', + shape=[], + initializer=init_ops.zeros_initializer(), + trainable=False, + dtype=self.dtype) + + # Add masked_weights in the weights namescope so as to make it easier + # for the quantization library to add quant ops. + self.masked_kernel = math_ops.multiply(self.mask, self.kernel, + MASKED_WEIGHT_NAME) + + ops.add_to_collection(MASK_COLLECTION, self.mask) + ops.add_to_collection(MASKED_WEIGHT_COLLECTION, self.masked_kernel) + ops.add_to_collection(THRESHOLD_COLLECTION, self.threshold) + ops.add_to_collection(WEIGHT_COLLECTION, self.kernel) + + if self.use_bias: + self.bias = self.add_variable( + 'bias', + shape=[ + self.units, + ], + initializer=self.bias_initializer, + regularizer=self.bias_regularizer, + dtype=self.dtype, + trainable=True) + else: + self.bias = None + self.built = True + + def call(self, inputs): + inputs = ops.convert_to_tensor(inputs, dtype=self.dtype) + shape = inputs.get_shape().as_list() + output_shape = shape[:-1] + [self.units] + if len(output_shape) > 2: + # Broadcasting is required for the inputs. + outputs = standard_ops.tensordot(inputs, self.masked_kernel, + [[len(shape) - 1], [0]]) + # Reshape the output back to the original ndim of the input. + outputs.set_shape(output_shape) + else: + outputs = standard_ops.matmul(inputs, self.masked_kernel) + if self.use_bias: + outputs = nn.bias_add(outputs, self.bias) + if self.activation is not None: + return self.activation(outputs) # pylint: disable=not-callable + return outputs + + def _compute_output_shape(self, input_shape): + input_shape = tensor_shape.TensorShape(input_shape) + input_shape = input_shape.with_rank_at_least(2) + if input_shape[-1].value is None: + raise ValueError( + 'The innermost dimension of input_shape must be defined, but saw: %s' + % input_shape) + return input_shape[:-1].concatenate(self.units) diff --git a/tensorflow/contrib/model_pruning/python/layers/layers.py b/tensorflow/contrib/model_pruning/python/layers/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..dfebb9a6794056dd43b0699ccbcc5797f2f172f7 --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/layers/layers.py @@ -0,0 +1,364 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tensorflow layers with added variables for parameter masking. + +Branched from tensorflow/contrib/layers/python/layers/layers.py +""" +# pylint: disable=missing-docstring +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import six + +from tensorflow.contrib.framework.python.ops import add_arg_scope +from tensorflow.contrib.framework.python.ops import variables +from tensorflow.contrib.layers.python.layers import initializers +from tensorflow.contrib.layers.python.layers import utils +from tensorflow.contrib.model_pruning.python.layers import core_layers as core +from tensorflow.python.framework import ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables as tf_variables + + +def _model_variable_getter(getter, + name, + shape=None, + dtype=None, + initializer=None, + regularizer=None, + trainable=True, + collections=None, + caching_device=None, + partitioner=None, + rename=None, + use_resource=None, + **_): + """Getter that uses model_variable for compatibility with core layers.""" + short_name = name.split('/')[-1] + if rename and short_name in rename: + name_components = name.split('/') + name_components[-1] = rename[short_name] + name = '/'.join(name_components) + return variables.model_variable( + name, + shape=shape, + dtype=dtype, + initializer=initializer, + regularizer=regularizer, + collections=collections, + trainable=trainable, + caching_device=caching_device, + partitioner=partitioner, + custom_getter=getter, + use_resource=use_resource) + + +def _build_variable_getter(rename=None): + """Build a model variable getter that respects scope getter and renames.""" + + # VariableScope will nest the getters + def layer_variable_getter(getter, *args, **kwargs): + kwargs['rename'] = rename + return _model_variable_getter(getter, *args, **kwargs) + + return layer_variable_getter + + +def _add_variable_to_collections(variable, collections_set, collections_name): + """Adds variable (or all its parts) to all collections with that name.""" + collections = utils.get_variable_collections(collections_set, + collections_name) or [] + variables_list = [variable] + if isinstance(variable, tf_variables.PartitionedVariable): + variables_list = [v for v in variable] + for collection in collections: + for var in variables_list: + if var not in ops.get_collection(collection): + ops.add_to_collection(collection, var) + + +@add_arg_scope +def masked_convolution(inputs, + num_outputs, + kernel_size, + stride=1, + padding='SAME', + data_format=None, + rate=1, + activation_fn=nn.relu, + normalizer_fn=None, + normalizer_params=None, + weights_initializer=initializers.xavier_initializer(), + weights_regularizer=None, + biases_initializer=init_ops.zeros_initializer(), + biases_regularizer=None, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + scope=None): + """Adds an 2D convolution followed by an optional batch_norm layer. + The layer creates a mask variable on top of the weight variable. The input to + the convolution operation is the elementwise multiplication of the mask + variable and the weigh + + It is required that 1 <= N <= 3. + + `convolution` creates a variable called `weights`, representing the + convolutional kernel, that is convolved (actually cross-correlated) with the + `inputs` to produce a `Tensor` of activations. If a `normalizer_fn` is + provided (such as `batch_norm`), it is then applied. Otherwise, if + `normalizer_fn` is None and a `biases_initializer` is provided then a `biases` + variable would be created and added the activations. Finally, if + `activation_fn` is not `None`, it is applied to the activations as well. + + Performs atrous convolution with input stride/dilation rate equal to `rate` + if a value > 1 for any dimension of `rate` is specified. In this case + `stride` values != 1 are not supported. + + Args: + inputs: A Tensor of rank N+2 of shape + `[batch_size] + input_spatial_shape + [in_channels]` if data_format does + not start with "NC" (default), or + `[batch_size, in_channels] + input_spatial_shape` if data_format starts + with "NC". + num_outputs: Integer, the number of output filters. + kernel_size: A sequence of N positive integers specifying the spatial + dimensions of of the filters. Can be a single integer to specify the same + value for all spatial dimensions. + stride: A sequence of N positive integers specifying the stride at which to + compute output. Can be a single integer to specify the same value for all + spatial dimensions. Specifying any `stride` value != 1 is incompatible + with specifying any `rate` value != 1. + padding: One of `"VALID"` or `"SAME"`. + data_format: A string or None. Specifies whether the channel dimension of + the `input` and output is the last dimension (default, or if `data_format` + does not start with "NC"), or the second dimension (if `data_format` + starts with "NC"). For N=1, the valid values are "NWC" (default) and + "NCW". For N=2, the valid values are "NHWC" (default) and "NCHW". + For N=3, the valid values are "NDHWC" (default) and "NCDHW". + rate: A sequence of N positive integers specifying the dilation rate to use + for atrous convolution. Can be a single integer to specify the same + value for all spatial dimensions. Specifying any `rate` value != 1 is + incompatible with specifying any `stride` value != 1. + activation_fn: Activation function. The default value is a ReLU function. + Explicitly set it to None to skip it and maintain a linear activation. + normalizer_fn: Normalization function to use instead of `biases`. If + `normalizer_fn` is provided then `biases_initializer` and + `biases_regularizer` are ignored and `biases` are not created nor added. + default set to None for no normalizer function + normalizer_params: Normalization function parameters. + weights_initializer: An initializer for the weights. + weights_regularizer: Optional regularizer for the weights. + biases_initializer: An initializer for the biases. If None skip biases. + biases_regularizer: Optional regularizer for the biases. + reuse: Whether or not the layer and its variables should be reused. To be + able to reuse the layer scope must be given. + variables_collections: Optional list of collections for all the variables or + a dictionary containing a different list of collection per variable. + outputs_collections: Collection to add the outputs. + trainable: If `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). + scope: Optional scope for `variable_scope`. + + Returns: + A tensor representing the output of the operation. + + Raises: + ValueError: If `data_format` is invalid. + ValueError: Both 'rate' and `stride` are not uniformly 1. + """ + if data_format not in [None, 'NWC', 'NCW', 'NHWC', 'NCHW', 'NDHWC', 'NCDHW']: + raise ValueError('Invalid data_format: %r' % (data_format,)) + + layer_variable_getter = _build_variable_getter({ + 'bias': 'biases', + 'kernel': 'weights' + }) + + with variable_scope.variable_scope( + scope, 'Conv', [inputs], reuse=reuse, + custom_getter=layer_variable_getter) as sc: + inputs = ops.convert_to_tensor(inputs) + input_rank = inputs.get_shape().ndims + + if input_rank == 3: + raise ValueError('Sparse Convolution not supported for input with rank', + input_rank) + elif input_rank == 4: + layer_class = core.MaskedConv2D + elif input_rank == 5: + raise ValueError('Sparse Convolution not supported for input with rank', + input_rank) + else: + raise ValueError('Sparse Convolution not supported for input with rank', + input_rank) + + if data_format is None or data_format == 'NHWC': + df = 'channels_last' + elif data_format == 'NCHW': + df = 'channels_first' + else: + raise ValueError('Unsupported data fromat', data_format) + + layer = layer_class( + filters=num_outputs, + kernel_size=kernel_size, + strides=stride, + padding=padding, + data_format=df, + dilation_rate=rate, + activation=None, + use_bias=not normalizer_fn and biases_initializer, + kernel_initializer=weights_initializer, + bias_initializer=biases_initializer, + kernel_regularizer=weights_regularizer, + bias_regularizer=biases_regularizer, + activity_regularizer=None, + trainable=trainable, + name=sc.name, + dtype=inputs.dtype.base_dtype, + _scope=sc, + _reuse=reuse) + outputs = layer.apply(inputs) + + # Add variables to collections. + _add_variable_to_collections(layer.kernel, variables_collections, 'weights') + if layer.use_bias: + _add_variable_to_collections(layer.bias, variables_collections, 'biases') + + if normalizer_fn is not None: + normalizer_params = normalizer_params or {} + outputs = normalizer_fn(outputs, **normalizer_params) + + if activation_fn is not None: + outputs = activation_fn(outputs) + return utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, outputs) + + +masked_conv2d = masked_convolution + + +@add_arg_scope +def masked_fully_connected( + inputs, + num_outputs, + activation_fn=nn.relu, + normalizer_fn=None, + normalizer_params=None, + weights_initializer=initializers.xavier_initializer(), + weights_regularizer=None, + biases_initializer=init_ops.zeros_initializer(), + biases_regularizer=None, + reuse=None, + variables_collections=None, + outputs_collections=None, + trainable=True, + scope=None): + """Adds a sparse fully connected layer. The weight matrix is masked. + + `fully_connected` creates a variable called `weights`, representing a fully + connected weight matrix, which is multiplied by the `inputs` to produce a + `Tensor` of hidden units. If a `normalizer_fn` is provided (such as + `batch_norm`), it is then applied. Otherwise, if `normalizer_fn` is + None and a `biases_initializer` is provided then a `biases` variable would be + created and added the hidden units. Finally, if `activation_fn` is not `None`, + it is applied to the hidden units as well. + + Note: that if `inputs` have a rank greater than 2, then `inputs` is flattened + prior to the initial matrix multiply by `weights`. + + Args: + inputs: A tensor of at least rank 2 and static value for the last dimension; + i.e. `[batch_size, depth]`, `[None, None, None, channels]`. + num_outputs: Integer or long, the number of output units in the layer. + activation_fn: Activation function. The default value is a ReLU function. + Explicitly set it to None to skip it and maintain a linear activation. + normalizer_fn: Normalization function to use instead of `biases`. If + `normalizer_fn` is provided then `biases_initializer` and + `biases_regularizer` are ignored and `biases` are not created nor added. + default set to None for no normalizer function + normalizer_params: Normalization function parameters. + weights_initializer: An initializer for the weights. + weights_regularizer: Optional regularizer for the weights. + biases_initializer: An initializer for the biases. If None skip biases. + biases_regularizer: Optional regularizer for the biases. + reuse: Whether or not the layer and its variables should be reused. To be + able to reuse the layer scope must be given. + variables_collections: Optional list of collections for all the variables or + a dictionary containing a different list of collections per variable. + outputs_collections: Collection to add the outputs. + trainable: If `True` also add variables to the graph collection + `GraphKeys.TRAINABLE_VARIABLES` (see tf.Variable). + scope: Optional scope for variable_scope. + + Returns: + The tensor variable representing the result of the series of operations. + + Raises: + ValueError: If x has rank less than 2 or if its last dimension is not set. + """ + if not isinstance(num_outputs, six.integer_types): + raise ValueError('num_outputs should be int or long, got %s.' % + (num_outputs,)) + + layer_variable_getter = _build_variable_getter({ + 'bias': 'biases', + 'kernel': 'weights' + }) + + with variable_scope.variable_scope( + scope, + 'fully_connected', [inputs], + reuse=reuse, + custom_getter=layer_variable_getter) as sc: + inputs = ops.convert_to_tensor(inputs) + layer = core.MaskedFullyConnected( + units=num_outputs, + activation=None, + use_bias=not normalizer_fn and biases_initializer, + kernel_initializer=weights_initializer, + bias_initializer=biases_initializer, + kernel_regularizer=weights_regularizer, + bias_regularizer=biases_regularizer, + activity_regularizer=None, + trainable=trainable, + name=sc.name, + dtype=inputs.dtype.base_dtype, + _scope=sc, + _reuse=reuse) + outputs = layer.apply(inputs) + + # Add variables to collections. + _add_variable_to_collections(layer.kernel, variables_collections, 'weights') + if layer.bias is not None: + _add_variable_to_collections(layer.bias, variables_collections, 'biases') + + # Apply normalizer function / layer. + if normalizer_fn is not None: + if not normalizer_params: + normalizer_params = {} + outputs = normalizer_fn(outputs, **normalizer_params) + + if activation_fn is not None: + outputs = activation_fn(outputs) + + return utils.collect_named_outputs(outputs_collections, + sc.original_name_scope, outputs) diff --git a/tensorflow/contrib/model_pruning/python/layers/layers_test.py b/tensorflow/contrib/model_pruning/python/layers/layers_test.py new file mode 100644 index 0000000000000000000000000000000000000000..97a2c978509e79f837a20595811a903a02b6a5eb --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/layers/layers_test.py @@ -0,0 +1,139 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for imagingvision.intelligence.tensorflow.model_pruning.layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.model_pruning.python.layers import core_layers +from tensorflow.contrib.model_pruning.python.layers import layers +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import test + + +class MaskedConvolutionLayerTest(test.TestCase): + + def setUp(self): + super(MaskedConvolutionLayerTest, self).setUp() + self.height, self.width = 7, 9 + + def testInvalidRank3(self): + input_tensor = array_ops.ones((self.height, self.width, 3)) + with self.assertRaisesRegexp(ValueError, 'rank'): + layers.masked_conv2d(input_tensor, 32, 3) + + def testInvalidRank5(self): + input_tensor = array_ops.ones((8, 8, self.height, self.width, 3)) + with self.assertRaisesRegexp(ValueError, 'rank'): + layers.masked_conv2d(input_tensor, 32, 3) + + def testSingleConvMaskAdded(self): + kernel_size = 3 + input_depth, output_depth = 8, 32 + input_tensor = array_ops.ones((8, self.height, self.width, input_depth)) + layers.masked_conv2d(input_tensor, output_depth, kernel_size) + + masks = ops.get_collection(core_layers.MASK_COLLECTION) + self.assertEqual(len(masks), 1) + self.assertListEqual(masks[0].get_shape().as_list(), + [kernel_size, kernel_size, input_depth, output_depth]) + + masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION) + self.assertEqual(len(masked_weight), 1) + self.assertListEqual(masked_weight[0].get_shape().as_list(), + [kernel_size, kernel_size, input_depth, output_depth]) + + def testMultipleConvMaskAdded(self): + number_of_layers = 5 + + kernel_size = 3 + base_depth = 4 + depth_step = 7 + + input_tensor = array_ops.ones((8, self.height, self.width, base_depth)) + + top_layer = input_tensor + + for ix in range(number_of_layers): + top_layer = layers.masked_conv2d(top_layer, base_depth + + (ix + 1) * depth_step, kernel_size) + + masks = ops.get_collection(core_layers.MASK_COLLECTION) + self.assertEqual(len(masks), number_of_layers) + for ix in range(number_of_layers): + self.assertListEqual(masks[ix].get_shape().as_list(), [ + kernel_size, kernel_size, base_depth + ix * depth_step, + base_depth + (ix + 1) * depth_step + ]) + + masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION) + self.assertEqual(len(masked_weight), number_of_layers) + for ix in range(number_of_layers): + self.assertListEqual(masked_weight[ix].get_shape().as_list(), [ + kernel_size, kernel_size, base_depth + ix * depth_step, + base_depth + (ix + 1) * depth_step + ]) + + +class MaskedFullyConnectedLayerTest(test.TestCase): + + def testSingleFCMaskAdded(self): + input_depth, output_depth = 8, 32 + input_tensor = array_ops.ones((5, input_depth)) + layers.masked_fully_connected(input_tensor, output_depth) + + masks = ops.get_collection(core_layers.MASK_COLLECTION) + self.assertEqual(len(masks), 1) + self.assertListEqual(masks[0].get_shape().as_list(), + [input_depth, output_depth]) + + masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION) + self.assertEqual(len(masked_weight), 1) + self.assertListEqual(masked_weight[0].get_shape().as_list(), + [input_depth, output_depth]) + + def testMultipleConvMaskAdded(self): + number_of_layers = 5 + + base_depth = 4 + depth_step = 7 + + input_tensor = array_ops.ones((8, base_depth)) + + top_layer = input_tensor + + for ix in range(number_of_layers): + top_layer = layers.masked_fully_connected(top_layer, base_depth + + (ix + 1) * depth_step) + + masks = ops.get_collection(core_layers.MASK_COLLECTION) + self.assertEqual(len(masks), number_of_layers) + for ix in range(number_of_layers): + self.assertListEqual(masks[ix].get_shape().as_list(), [ + base_depth + ix * depth_step, base_depth + (ix + 1) * depth_step + ]) + + masked_weight = ops.get_collection(core_layers.MASKED_WEIGHT_COLLECTION) + self.assertEqual(len(masked_weight), number_of_layers) + for ix in range(number_of_layers): + self.assertListEqual(masked_weight[ix].get_shape().as_list(), [ + base_depth + ix * depth_step, base_depth + (ix + 1) * depth_step + ]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py b/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b050d25d00b298a20f7ce6abdda7c1d00db899 --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py @@ -0,0 +1,348 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Module implementing RNN Cells with pruning. + +This module implements BasicLSTMCell and LSTMCell with pruning. +Code adapted from third_party/tensorflow/python/ops/rnn_cell_impl.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.model_pruning.python.layers import core_layers +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import rnn_cell as tf_rnn + + +class MaskedBasicLSTMCell(tf_rnn.BasicLSTMCell): + """Basic LSTM recurrent network cell with pruning. + + Overrides the call method of tensorflow BasicLSTMCell and injects the weight + masks + + The implementation is based on: http://arxiv.org/abs/1409.2329. + + We add forget_bias (default: 1) to the biases of the forget gate in order to + reduce the scale of forgetting in the beginning of the training. + + It does not allow cell clipping, a projection layer, and does not + use peep-hole connections: it is the basic baseline. + + For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell} + that follows. + """ + + def __init__(self, + num_units, + forget_bias=1.0, + state_is_tuple=True, + activation=None, + reuse=None, + name=None): + """Initialize the basic LSTM cell with pruning. + + Args: + num_units: int, The number of units in the LSTM cell. + forget_bias: float, The bias added to forget gates (see above). + Must set to `0.0` manually when restoring from CudnnLSTM-trained + checkpoints. + state_is_tuple: If True, accepted and returned states are 2-tuples of + the `c_state` and `m_state`. If False, they are concatenated + along the column axis. The latter behavior will soon be deprecated. + activation: Activation function of the inner states. Default: `tanh`. + reuse: (optional) Python boolean describing whether to reuse variables + in an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + name: String, the name of the layer. Layers with the same name will + share weights, but to avoid mistakes we require reuse=True in such + cases. + + When restoring from CudnnLSTM-trained checkpoints, must use + CudnnCompatibleLSTMCell instead. + """ + super(MaskedBasicLSTMCell, self).__init__( + num_units, + forget_bias=forget_bias, + state_is_tuple=state_is_tuple, + activation=activation, + reuse=reuse, + name=name) + + def build(self, inputs_shape): + # Call the build method of the parent class. + super(MaskedBasicLSTMCell, self).build(inputs_shape) + + self.built = False + + input_depth = inputs_shape[1].value + h_depth = self._num_units + self._mask = self.add_variable( + name="mask", + shape=[input_depth + h_depth, 4 * h_depth], + initializer=init_ops.ones_initializer(), + trainable=False, + dtype=self.dtype) + self._threshold = self.add_variable( + name="threshold", + shape=[], + initializer=init_ops.zeros_initializer(), + trainable=False, + dtype=self.dtype) + # Add masked_weights in the weights namescope so as to make it easier + # for the quantization library to add quant ops. + self._masked_kernel = math_ops.multiply(self._mask, self._kernel, + core_layers.MASKED_WEIGHT_NAME) + if self._mask not in ops.get_collection_ref(core_layers.MASK_COLLECTION): + ops.add_to_collection(core_layers.MASK_COLLECTION, self._mask) + ops.add_to_collection(core_layers.MASKED_WEIGHT_COLLECTION, + self._masked_kernel) + ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold) + ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel) + + self.built = True + + def call(self, inputs, state): + """Long short-term memory cell (LSTM) with masks for pruning. + + Args: + inputs: `2-D` tensor with shape `[batch_size, input_size]`. + state: An `LSTMStateTuple` of state tensors, each shaped + `[batch_size, self.state_size]`, if `state_is_tuple` has been set to + `True`. Otherwise, a `Tensor` shaped + `[batch_size, 2 * self.state_size]`. + + Returns: + A pair containing the new hidden state, and the new state (either a + `LSTMStateTuple` or a concatenated state, depending on + `state_is_tuple`). + """ + sigmoid = math_ops.sigmoid + one = constant_op.constant(1, dtype=dtypes.int32) + # Parameters of gates are concatenated into one multiply for efficiency. + if self._state_is_tuple: + c, h = state + else: + c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one) + + gate_inputs = math_ops.matmul( + array_ops.concat([inputs, h], 1), self._masked_kernel) + gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + i, j, f, o = array_ops.split( + value=gate_inputs, num_or_size_splits=4, axis=one) + + forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype) + # Note that using `add` and `multiply` instead of `+` and `*` gives a + # performance improvement. So using those at the cost of readability. + add = math_ops.add + multiply = math_ops.multiply + new_c = add( + multiply(c, sigmoid(add(f, forget_bias_tensor))), + multiply(sigmoid(i), self._activation(j))) + new_h = multiply(self._activation(new_c), sigmoid(o)) + + if self._state_is_tuple: + new_state = tf_rnn.LSTMStateTuple(new_c, new_h) + else: + new_state = array_ops.concat([new_c, new_h], 1) + return new_h, new_state + + +class MaskedLSTMCell(tf_rnn.LSTMCell): + """LSTMCell with pruning. + + Overrides the call method of tensorflow LSTMCell and injects the weight masks. + Masks are applied to only the weight matrix of the LSTM and not the + projection matrix. + """ + + def __init__(self, + num_units, + use_peepholes=False, + cell_clip=None, + initializer=None, + num_proj=None, + proj_clip=None, + num_unit_shards=None, + num_proj_shards=None, + forget_bias=1.0, + state_is_tuple=True, + activation=None, + reuse=None): + """Initialize the parameters for an LSTM cell with masks for pruning. + + Args: + num_units: int, The number of units in the LSTM cell + use_peepholes: bool, set True to enable diagonal/peephole connections. + cell_clip: (optional) A float value, if provided the cell state is clipped + by this value prior to the cell output activation. + initializer: (optional) The initializer to use for the weight and + projection matrices. + num_proj: (optional) int, The output dimensionality for the projection + matrices. If None, no projection is performed. + proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is + provided, then the projected values are clipped elementwise to within + `[-proj_clip, proj_clip]`. + num_unit_shards: Deprecated, will be removed by Jan. 2017. + Use a variable_scope partitioner instead. + num_proj_shards: Deprecated, will be removed by Jan. 2017. + Use a variable_scope partitioner instead. + forget_bias: Biases of the forget gate are initialized by default to 1 + in order to reduce the scale of forgetting at the beginning of + the training. Must set it manually to `0.0` when restoring from + CudnnLSTM trained checkpoints. + state_is_tuple: If True, accepted and returned states are 2-tuples of + the `c_state` and `m_state`. If False, they are concatenated + along the column axis. This latter behavior will soon be deprecated. + activation: Activation function of the inner states. Default: `tanh`. + reuse: (optional) Python boolean describing whether to reuse variables + in an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + + When restoring from CudnnLSTM-trained checkpoints, must use + CudnnCompatibleLSTMCell instead. + """ + super(MaskedLSTMCell, self).__init__( + num_units, + use_peepholes=use_peepholes, + cell_clip=cell_clip, + initializer=initializer, + num_proj=num_proj, + proj_clip=proj_clip, + num_unit_shards=num_unit_shards, + num_proj_shards=num_proj_shards, + forget_bias=forget_bias, + state_is_tuple=state_is_tuple, + activation=activation, + reuse=reuse) + + def build(self, inputs_shape): + # Call the build method of the parent class. + super(MaskedLSTMCell, self).build(inputs_shape) + + self.built = False + + input_depth = inputs_shape[1].value + h_depth = self._num_units + self._mask = self.add_variable( + name="mask", + shape=[input_depth + h_depth, 4 * h_depth], + initializer=init_ops.ones_initializer(), + trainable=False, + dtype=self.dtype) + self._threshold = self.add_variable( + name="threshold", + shape=[], + initializer=init_ops.zeros_initializer(), + trainable=False, + dtype=self.dtype) + # Add masked_weights in the weights namescope so as to make it easier + # for the quantization library to add quant ops. + self._masked_kernel = math_ops.multiply(self._mask, self._kernel, + core_layers.MASKED_WEIGHT_NAME) + if self._mask not in ops.get_collection_ref(core_layers.MASK_COLLECTION): + ops.add_to_collection(core_layers.MASK_COLLECTION, self._mask) + ops.add_to_collection(core_layers.MASKED_WEIGHT_COLLECTION, + self._masked_kernel) + ops.add_to_collection(core_layers.THRESHOLD_COLLECTION, self._threshold) + ops.add_to_collection(core_layers.WEIGHT_COLLECTION, self._kernel) + + self.built = True + + def call(self, inputs, state): + """Run one step of LSTM. + + Args: + inputs: input Tensor, 2D, `[batch, num_units]. + state: if `state_is_tuple` is False, this must be a state Tensor, + `2-D, [batch, state_size]`. If `state_is_tuple` is True, this must be a + tuple of state Tensors, both `2-D`, with column sizes `c_state` and + `m_state`. + + Returns: + A tuple containing: + + - A `2-D, [batch, output_dim]`, Tensor representing the output of the + LSTM after reading `inputs` when previous state was `state`. + Here output_dim is: + num_proj if num_proj was set, + num_units otherwise. + - Tensor(s) representing the new state of LSTM after reading `inputs` when + the previous state was `state`. Same type and shape(s) as `state`. + + Raises: + ValueError: If input size cannot be inferred from inputs via + static shape inference. + """ + num_proj = self._num_units if self._num_proj is None else self._num_proj + sigmoid = math_ops.sigmoid + + if self._state_is_tuple: + (c_prev, m_prev) = state + else: + c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units]) + m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj]) + + input_size = inputs.get_shape().with_rank(2)[1] + if input_size.value is None: + raise ValueError("Could not infer input size from inputs.get_shape()[-1]") + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + lstm_matrix = math_ops.matmul( + array_ops.concat([inputs, m_prev], 1), self._masked_kernel) + lstm_matrix = nn_ops.bias_add(lstm_matrix, self._bias) + + i, j, f, o = array_ops.split( + value=lstm_matrix, num_or_size_splits=4, axis=1) + # Diagonal connections + if self._use_peepholes: + c = ( + sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev + + sigmoid(i + self._w_i_diag * c_prev) * self._activation(j)) + else: + c = ( + sigmoid(f + self._forget_bias) * c_prev + + sigmoid(i) * self._activation(j)) + + if self._cell_clip is not None: + # pylint: disable=invalid-unary-operand-type + c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) + # pylint: enable=invalid-unary-operand-type + if self._use_peepholes: + m = sigmoid(o + self._w_o_diag * c) * self._activation(c) + else: + m = sigmoid(o) * self._activation(c) + + if self._num_proj is not None: + m = math_ops.matmul(m, self._proj_kernel) + + if self._proj_clip is not None: + # pylint: disable=invalid-unary-operand-type + m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) + # pylint: enable=invalid-unary-operand-type + + new_state = ( + tf_rnn.LSTMStateTuple(c, m) + if self._state_is_tuple else array_ops.concat([c, m], 1)) + return m, new_state diff --git a/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py b/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e85ae7b22a545045ec42ba86e9aed9cd7e6103f7 --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/layers/rnn_cells_test.py @@ -0,0 +1,85 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for creating different number of masks in rnn_cells.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.model_pruning.python import pruning +from tensorflow.contrib.model_pruning.python.layers import rnn_cells +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import rnn_cell as tf_rnn_cells +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class RnnCellsTest(test.TestCase): + + def setUp(self): + super(RnnCellsTest, self).setUp() + self.batch_size = 8 + self.dim = 10 + + def testMaskedBasicLSTMCell(self): + expected_num_masks = 1 + expected_num_rows = 2 * self.dim + expected_num_cols = 4 * self.dim + with self.test_session(): + inputs = variables.Variable( + random_ops.random_normal([self.batch_size, self.dim])) + c = variables.Variable( + random_ops.random_normal([self.batch_size, self.dim])) + h = variables.Variable( + random_ops.random_normal([self.batch_size, self.dim])) + state = tf_rnn_cells.LSTMStateTuple(c, h) + lstm_cell = rnn_cells.MaskedBasicLSTMCell(self.dim) + lstm_cell(inputs, state) + self.assertEqual(len(pruning.get_masks()), expected_num_masks) + self.assertEqual(len(pruning.get_masked_weights()), expected_num_masks) + self.assertEqual(len(pruning.get_thresholds()), expected_num_masks) + self.assertEqual(len(pruning.get_weights()), expected_num_masks) + + for mask in pruning.get_masks(): + self.assertEqual(mask.shape, (expected_num_rows, expected_num_cols)) + for weight in pruning.get_weights(): + self.assertEqual(weight.shape, (expected_num_rows, expected_num_cols)) + + def testMaskedLSTMCell(self): + expected_num_masks = 1 + expected_num_rows = 2 * self.dim + expected_num_cols = 4 * self.dim + with self.test_session(): + inputs = variables.Variable( + random_ops.random_normal([self.batch_size, self.dim])) + c = variables.Variable( + random_ops.random_normal([self.batch_size, self.dim])) + h = variables.Variable( + random_ops.random_normal([self.batch_size, self.dim])) + state = tf_rnn_cells.LSTMStateTuple(c, h) + lstm_cell = rnn_cells.MaskedLSTMCell(self.dim) + lstm_cell(inputs, state) + self.assertEqual(len(pruning.get_masks()), expected_num_masks) + self.assertEqual(len(pruning.get_masked_weights()), expected_num_masks) + self.assertEqual(len(pruning.get_thresholds()), expected_num_masks) + self.assertEqual(len(pruning.get_weights()), expected_num_masks) + + for mask in pruning.get_masks(): + self.assertEqual(mask.shape, (expected_num_rows, expected_num_cols)) + for weight in pruning.get_weights(): + self.assertEqual(weight.shape, (expected_num_rows, expected_num_cols)) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/model_pruning/python/learning.py b/tensorflow/contrib/model_pruning/python/learning.py new file mode 100644 index 0000000000000000000000000000000000000000..2b79c23cefe961b1c4056d41b5fcc0a0521efec6 --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/learning.py @@ -0,0 +1,188 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Wrapper around tf-slim's training code contrib/slim/python/slim/learning.py +to support training of pruned models + +******************************************************************* +* A simple working training script with support for model pruning * +******************************************************************* + + # Load data and create the model: + images, labels = LoadData(...) + predictions = MyModel(images) + + # Define the loss: + slim.losses.log_loss(predictions, labels) + total_loss = slim.losses.get_total_loss() + + # Define the optimizer: + optimizer = tf.train.MomentumOptimizer(FLAGS.learning_rate, FLAGS.momentum) + + # Create the train_op + train_op = slim.learning.create_train_op(total_loss, optimizer) + + # Set up sparsity + sparsity = pruning.setup_gradual_sparsity(self.global_step) + + # Create mask update op + mask_update_op = pruning.add_mask_update_ip(sparsity) + + # Run training. + learning.train(train_op, + my_log_dir, + mask_update_op) + see contrib/slim/python/slim/learning.py for additional examples +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib import slim as _slim + +_USE_DEFAULT = 0 +train_step = _slim.learning.train_step + + +def train(train_op, + logdir, + mask_update_op, + train_step_fn=train_step, + train_step_kwargs=_USE_DEFAULT, + log_every_n_steps=1, + graph=None, + master='', + is_chief=True, + global_step=None, + number_of_steps=None, + init_op=_USE_DEFAULT, + init_feed_dict=None, + local_init_op=_USE_DEFAULT, + init_fn=None, + ready_op=_USE_DEFAULT, + summary_op=_USE_DEFAULT, + save_summaries_secs=600, + summary_writer=_USE_DEFAULT, + startup_delay_steps=0, + saver=None, + save_interval_secs=600, + sync_optimizer=None, + session_config=None, + trace_every_n_steps=None): + """Wrapper around tf-slim's train function. + + Runs a training loop using a TensorFlow supervisor. + When the sync_optimizer is supplied, gradient updates are applied + synchronously. Otherwise, gradient updates are applied asynchronous. + + Args: + train_op: A `Tensor` that, when executed, will apply the gradients and + return the loss value. + logdir: The directory where training logs are written to. If None, model + checkpoints and summaries will not be written. + mask_update_op: Operation that upon execution updates the weight masks and + thresholds. + train_step_fn: The function to call in order to execute a single gradient + step. The function must have take exactly four arguments: the current + session, the `train_op` `Tensor`, a global step `Tensor` and a dictionary. + train_step_kwargs: A dictionary which is passed to the `train_step_fn`. By + default, two `Boolean`, scalar ops called "should_stop" and "should_log" + are provided. + log_every_n_steps: The frequency, in terms of global steps, that the loss + and global step and logged. + graph: The graph to pass to the supervisor. If no graph is supplied the + default graph is used. + master: The address of the tensorflow master. + is_chief: Specifies whether or not the training is being run by the primary + replica during replica training. + global_step: The `Tensor` representing the global step. If left as `None`, + then slim.variables.get_or_create_global_step() is used. + number_of_steps: The max number of gradient steps to take during training, + as measured by 'global_step': training will stop if global_step is + greater than 'number_of_steps'. If the value is left as None, training + proceeds indefinitely. + init_op: The initialization operation. If left to its default value, then + the session is initialized by calling `tf.global_variables_initializer()`. + init_feed_dict: A feed dictionary to use when executing the `init_op`. + local_init_op: The local initialization operation. If left to its default + value, then the session is initialized by calling + `tf.local_variables_initializer()` and `tf.tables_initializer()`. + init_fn: An optional callable to be executed after `init_op` is called. The + callable must accept one argument, the session being initialized. + ready_op: Operation to check if the model is ready to use. If left to its + default value, then the session checks for readiness by calling + `tf.report_uninitialized_variables()`. + summary_op: The summary operation. + save_summaries_secs: How often, in seconds, to save summaries. + summary_writer: `SummaryWriter` to use. Can be `None` + to indicate that no summaries should be written. If unset, we + create a SummaryWriter. + startup_delay_steps: The number of steps to wait for before beginning. Note + that this must be 0 if a sync_optimizer is supplied. + saver: Saver to save checkpoints. If None, a default one will be created + and used. + save_interval_secs: How often, in seconds, to save the model to `logdir`. + sync_optimizer: an instance of tf.train.SyncReplicasOptimizer, or a list of + them. If the argument is supplied, gradient updates will be synchronous. + If left as `None`, gradient updates will be asynchronous. + session_config: An instance of `tf.ConfigProto` that will be used to + configure the `Session`. If left as `None`, the default will be used. + trace_every_n_steps: produce and save a `Timeline` in Chrome trace format + and add it to the summaries every `trace_every_n_steps`. If None, no trace + information will be produced or saved. + + Returns: + the value of the loss function after training. + + Raises: + ValueError: if `train_op` is empty or if `startup_delay_steps` is + non-zero when `sync_optimizer` is supplied, if `number_of_steps` is + negative, or if `trace_every_n_steps` is not `None` and no `logdir` is + provided. + """ + + def train_step_with_pruning_fn(sess, train_op, global_step, + train_step_kwargs): + total_loss, should_stop = train_step_fn(sess, train_op, global_step, + train_step_kwargs) + sess.run(mask_update_op) + return total_loss, should_stop + + total_loss, _ = _slim.learning.train( + train_op, + logdir, + train_step_fn=train_step_with_pruning_fn, + train_step_kwargs=train_step_kwargs, + log_every_n_steps=log_every_n_steps, + graph=graph, + master=master, + is_chief=is_chief, + global_step=global_step, + number_of_steps=number_of_steps, + init_op=init_op, + init_feed_dict=init_feed_dict, + local_init_op=local_init_op, + init_fn=init_fn, + ready_op=ready_op, + summary_op=summary_op, + save_summaries_secs=save_summaries_secs, + summary_writer=summary_writer, + startup_delay_steps=startup_delay_steps, + saver=saver, + save_interval_secs=save_interval_secs, + sync_optimizer=sync_optimizer, + session_config=session_config, + trace_every_n_steps=trace_every_n_steps) + + return total_loss diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py new file mode 100644 index 0000000000000000000000000000000000000000..42d91a71fde41d8681d7a0c439d6c49325730418 --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -0,0 +1,585 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Helper functions to add support for magnitude-based model pruning. + + # Adds variables and ops to the graph to enable + # elementwise masking of weights + apply_mask(weights) + + # Returns a list containing the sparsity of each of the weight tensors + get_weight_sparsity() + + # Returns a list of all the masked weight tensorflow variables + get_masked_weights() + + # Returns a list of all the mask tensorflow variables + get_masks() + + # Returns a list of all the thresholds + get_thresholds() + + # Returns a list of all the weight tensors that have been masked + get_weights() + + The Pruning class uses a proto (defined in pruning.proto) to set up the + parameters for a pruning specification. Here's a typical usage: + + # Initialize a pruning spec from a proto + pruning_spec = '/tmp/pruning.pb' + p = Pruning(pruning_spec) + + # Add mask update ops to the graph + mask_update_op = p.conditional_mask_update_op() + + # Add the summaries + p.add_pruning_summaries() + + # Run the op + session.run(mask_update_op) + + # An object of the pruning also accepts externally defined sparsity: + sparsity = tf.Variable(0.5, name = "ConstantSparsity") + pruning_spec = '/tmp/pruning.pb' + p = Pruning(pruning_spec, sparsity=sparsity) + +""" +# pylint: disable=missing-docstring +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.model_pruning.python.layers import core_layers as core +from tensorflow.contrib.training.python.training import hparam +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_impl +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.summary import summary +from tensorflow.python.training import training_util + +_MASK_COLLECTION = core.MASK_COLLECTION +_THRESHOLD_COLLECTION = core.THRESHOLD_COLLECTION +_MASKED_WEIGHT_COLLECTION = core.MASKED_WEIGHT_COLLECTION +_WEIGHT_COLLECTION = core.WEIGHT_COLLECTION +_MASKED_WEIGHT_NAME = core.MASKED_WEIGHT_NAME + + +def _weight_mask_variable(var, scope): + """Create a mask for the weights. + + This function adds a variable 'mask' to the graph. + + Args: + var: the weight variable that needs to be masked + scope: The variable scope of the variable var + + Returns: + the mask variable of the same size and shape as var, initialized to all 1s. + """ + with variable_scope.variable_scope(scope): + mask = variable_scope.get_variable( + 'mask', + var.get_shape(), + initializer=init_ops.ones_initializer(), + trainable=False, + dtype=var.dtype) + return mask + + +def _weight_threshold_variable(var, scope): + """Create a scalar threshold for the weights. + + This function adds a variable + 'threshold' to the graph. + + Args: + var: The weight variable that needs to be masked + scope: The variable scope of the variable var + + Returns: + a scalar threshold variable initialized to 0. + """ + with variable_scope.variable_scope(scope): + threshold = variable_scope.get_variable( + 'threshold', [], + initializer=init_ops.zeros_initializer(), + trainable=False, + dtype=var.dtype) + return threshold + + +def _histogram(values, value_range, nbins=100, dtype=np.int32, name=None): + """Return histogram of values. + + Given the tensor `values`, this operation returns a rank 1 histogram counting + the number of entries in `values` that fell into every bin. The bins are + equal width and determined by the arguments `value_range` and `nbins`. + + Args: + values: Numeric `Tensor`. + value_range: Shape [2] `Tensor` of same `dtype` as `values`. + values <= value_range[0] will be mapped to hist[0], + values >= value_range[1] will be mapped to hist[-1]. + nbins: Scalar `int32 Tensor`. Number of histogram bins. + dtype: dtype for returned histogram. + name: A name for this operation (defaults to 'histogram'). + + Returns: + A 1-D `Tensor` holding histogram of values. + + """ + with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope: + values = ops.convert_to_tensor(values, name='values') + values = gen_array_ops.reshape(values, [-1]) + value_range = ops.convert_to_tensor(value_range, name='value_range') + nbins = ops.convert_to_tensor(nbins, dtype=np.int32, name='nbins') + nbins_float = math_ops.cast(nbins, values.dtype) + + # Map tensor values that fall within value_range to [0, 1]. + scaled_values = math_ops.truediv( + values - value_range[0], + value_range[1] - value_range[0], + name='scaled_values') + + # map tensor values within the open interval value_range to {0,.., nbins-1}, + # values outside the open interval will be zero or less, or nbins or more. + indices = math_ops.floor(nbins_float * scaled_values, name='indices') + + # Clip edge cases (e.g. value = value_range[1]) or "outliers." + indices = math_ops.cast( + clip_ops.clip_by_value(indices, 0, nbins_float - 1), np.int32) + + return math_ops.unsorted_segment_sum( + array_ops.ones_like(indices, dtype=dtype), indices, nbins, name=scope) + + +def _determine_partitioned_axis(partitioned_variable): + partitioned_axis = 0 + concatenated_variable_shape = partitioned_variable.get_shape() + for partition in partitioned_variable: + partition_shape = partition.get_shape() + maybe_partitioned_axis = np.less(partition_shape, + concatenated_variable_shape) + # Sanity check: make sure number of partitioned axis == 1 + if np.count_nonzero(maybe_partitioned_axis) != 1: + raise ValueError('Number of partitioned axes %s not equal to 1' % + np.count_nonzero(maybe_partitioned_axis)) + partitioned_axis = np.where(maybe_partitioned_axis)[0][0] + return partitioned_axis + + +def _variable_assign(var, new_value): + return state_ops.assign(var, new_value, name=var.op.name + '_assign') + + +def _partitioned_variable_assign(partitioned_var, new_value): + """Assign op for partitioned variables. + + Args: + partitioned_var: A partitioned tensotflow variable + new_value: Value to be assigned to the variable var + + Returns: + A tensorflow op that groups the assign ops for each of the variable slices + """ + # Determine which axis was used to partition the variable. Currently + # tensorflow allows partitioning variable only along 1 axis. + axis = 0 if len(partitioned_var) == 1 else _determine_partitioned_axis( + partitioned_var) + + partition_sizes = np.array( + [partition.get_shape()[axis] for partition in partitioned_var]) + new_partitioned_values = array_ops.split( + new_value, + ops.convert_to_tensor(partition_sizes, dtype=np.int32), + axis=axis) + op_list = [] + for partition in partitioned_var: + op_list.append( + _variable_assign(partition, new_partitioned_values[len(op_list)])) + return control_flow_ops.group( + *op_list, name=partitioned_var.name + '_group_assign') + + +def apply_mask(x, scope=''): + """Apply mask to a given weight tensor. + + Args: + x: Input weight tensor + scope: The current variable scope. Defaults to "" + Returns: + Tensor representing masked_weights + """ + + mask = _weight_mask_variable(x, scope) + threshold = _weight_threshold_variable(x, scope) + # Add masked_weights in the weights namescope so as to make it easier + # for the quantization library to add quant ops. + masked_weights = math_ops.multiply(mask, x, _MASKED_WEIGHT_NAME) + + # Make sure the mask for a given variable are not added multiple times to the + # collection. This is particularly important when applying mask to RNN's + # weight variables + if mask not in ops.get_collection_ref(_MASK_COLLECTION): + ops.add_to_collection(_THRESHOLD_COLLECTION, threshold) + ops.add_to_collection(_MASK_COLLECTION, mask) + ops.add_to_collection(_MASKED_WEIGHT_COLLECTION, masked_weights) + ops.add_to_collection(_WEIGHT_COLLECTION, x) + return masked_weights + + +def get_masked_weights(): + return ops.get_collection(_MASKED_WEIGHT_COLLECTION) + + +def get_masks(): + return ops.get_collection(_MASK_COLLECTION) + + +def get_thresholds(): + return ops.get_collection(_THRESHOLD_COLLECTION) + + +def get_weights(): + return ops.get_collection(_WEIGHT_COLLECTION) + + +def get_weight_sparsity(): + """Get sparsity of the weights. + + Args: + None + + Returns: + A list containing the sparsity of each of the weight tensors + """ + masks = get_masks() + return [nn_impl.zero_fraction(mask) for mask in masks] + + +def get_pruning_hparams(): + """Get a tf.HParams object with the default values for the hyperparameters. + + name: string + name of the pruning specification. Used for adding summaries and ops under + a common tensorflow name_scope + begin_pruning_step: integer + the global step at which to begin pruning + end_pruning_step: integer + the global step at which to terminate pruning. Defaults to -1 implying + that pruning continues till the training stops + do_not_prune: list of strings + list of layers that are not pruned + threshold_decay: float + the decay factor to use for exponential decay of the thresholds + pruning_frequency: integer + How often should the masks be updated? (in # of global_steps) + nbins: integer + number of bins to use for histogram computation + initial_sparsity: float + initial sparsity value + target_sparsity: float + target sparsity value + sparsity_function_begin_step: integer + the global step at this which the gradual sparsity function begins to + take effect + sparsity_function_end_step: integer + the global step used as the end point for the gradual sparsity function + sparsity_function_exponent: float + exponent = 1 is linearly varying sparsity between initial and final. + exponent > 1 varies more slowly towards the end than the beginning + + We use the following sparsity function: + + num_steps = (sparsity_function_end_step - + sparsity_function_begin_step)/pruning_frequency + sparsity(step) = (initial_sparsity - target_sparsity)* + [1-step/(num_steps -1)]**exponent + target_sparsity + + Args: + None + + Returns: + tf.HParams object initialized to default values + + """ + return hparam.HParams( + name='model_pruning', + begin_pruning_step=0, + end_pruning_step=-1, + do_not_prune=[''], + threshold_decay=0.9, + pruning_frequency=10, + nbins=255, + initial_sparsity=0, + target_sparsity=0.5, + sparsity_function_begin_step=0, + sparsity_function_end_step=100, + sparsity_function_exponent=3) + + +class Pruning(object): + + def __init__(self, + spec=None, + global_step=None, + sparsity=None, + partitioner=None): + """Set up the specification for model pruning. + + If a spec is provided, the sparsity is set up based on the sparsity_function + in the spec. The effect of sparsity_function is overridden if the sparsity + variable is passed to the constructor. This enables setting up arbitrary + sparsity profiles externally and passing it to this pruning functions. + + Args: + spec: Pruning spec as defined in pruning.proto + global_step: A tensorflow variable that is used while setting up the + sparsity function + sparsity: A tensorflow scalar variable storing the sparsity + partitioner: The tensorflow partitioner function used to distribute + parameters across shards + """ + # Pruning specification + self._spec = spec if spec else get_pruning_hparams() + + # A tensorflow variable that tracks the sparsity function. + # If not provided as input, the graph must already contain the global_step + # variable before calling this constructor. + self._global_step = self._setup_global_step(global_step) + + # Stores the tensorflow sparsity variable. + # Built using self._setup_sparsity() or provided externally + self._sparsity = sparsity if sparsity else self._setup_sparsity() + + # Stores the partitioner function uses to partition variables across tasks/ + self._partitioner = partitioner + + # List of tensorflow assignments ops for new masks and thresholds + self._assign_ops = [] + + # Tensorflow variable keeping track of the last global step when the masks + # were updated + self._last_update_step = self._setup_last_update_step() + + def _setup_global_step(self, global_step): + graph_global_step = global_step + if graph_global_step is None: + graph_global_step = training_util.get_global_step() + + return math_ops.cast(graph_global_step, np.int32) + + def _setup_sparsity(self): + begin_step = self._spec.sparsity_function_begin_step + end_step = self._spec.sparsity_function_end_step + initial_sparsity = self._spec.initial_sparsity + target_sparsity = self._spec.target_sparsity + exponent = self._spec.sparsity_function_exponent + + if begin_step >= end_step: + raise ValueError( + 'Pruning must begin before it can end. begin_step=%d, end_step=%d' % + (begin_step, end_step)) + + with ops.name_scope(self._spec.name): + p = math_ops.minimum(1.0, + math_ops.maximum( + 0.0, + math_ops.div( + math_ops.cast(self._global_step - begin_step, + np.float32), + end_step - begin_step))) + sparsity = math_ops.add( + math_ops.multiply(initial_sparsity - target_sparsity, + math_ops.pow(1 - p, exponent)), + target_sparsity, + name='sparsity') + + return sparsity + + def _setup_last_update_step(self): + with variable_scope.variable_scope(self._spec.name) as scope: + try: + last_update_step = variable_scope.get_variable( + 'last_mask_update_step', [], + initializer=init_ops.zeros_initializer(), + trainable=False, + dtype=np.int32) + except ValueError: + scope.reuse_variables() + last_update_step = variable_scope.get_variable( + 'last_mask_update_step', dtype=np.int32) + return last_update_step + + def _exists_in_do_not_prune_list(self, tensor_name): + do_not_prune_list = self._spec.do_not_prune + if not do_not_prune_list[0]: + return False + for layer_name in do_not_prune_list: + if tensor_name.find(layer_name) != -1: + return True + + return False + + def _update_mask(self, weights, threshold): + """Updates the mask for a given weight tensor. + + This functions first computes the cdf of the weight tensor, and estimates + the threshold value such that 'desired_sparsity' fraction of weights + have magnitude less than the threshold. + + Args: + weights: The weight tensor that needs to be masked. + threshold: The current threshold value. The function will compute a new + threshold and return the exponential moving average using the current + value of threshold + + Returns: + new_threshold: The new value of the threshold based on weights, and + desired_sparsity + new_mask: A n-D numpy array containing 0 or 1 to indicate which of the + values in weights falls below the threshold + + Raises: + ValueError: if sparsity is not defined + """ + if self._sparsity is None: + raise ValueError('Sparsity variable undefined') + + with ops.name_scope(weights.op.name + '_pruning_ops'): + abs_weights = math_ops.abs(weights) + max_value = math_ops.reduce_max(abs_weights) + histogram = _histogram( + abs_weights, [0.0, max_value], + nbins=self._spec.nbins, + dtype=np.float32) + + cdf = math_ops.cumsum(histogram) + norm_cdf = math_ops.div(cdf, math_ops.reduce_sum(histogram)) + current_threshold = math_ops.multiply( + math_ops.div( + math_ops.reduce_sum( + math_ops.cast( + math_ops.less(norm_cdf, self._sparsity), np.float32)), + float(self._spec.nbins)), max_value) + + smoothed_threshold = math_ops.add_n([ + math_ops.multiply(current_threshold, 1 - self._spec.threshold_decay), + math_ops.multiply(threshold, self._spec.threshold_decay) + ]) + new_mask = math_ops.cast( + math_ops.greater(abs_weights, smoothed_threshold), np.float32) + return smoothed_threshold, new_mask + + def _get_mask_assign_ops(self): + # Make sure the assignment ops have not already been added to the list + if self._assign_ops: + raise ValueError( + 'Assign op list not empty. _get_mask_assign_ops() called twice?') + + masks = get_masks() + weights = get_weights() + thresholds = get_thresholds() + + if len(masks) != len(thresholds): + raise ValueError( + 'Number of masks %s and number of thresholds %s mismatch' % + (len(masks), len(thresholds))) + + for index, mask in enumerate(masks): + threshold = thresholds[index] + weight = weights[index] if self._partitioner is None else weights[ + index].as_tensor() + + if self._spec.do_not_prune: + if self._exists_in_do_not_prune_list(mask.name): + continue + + new_threshold, new_mask = self._update_mask(weight, threshold) + self._assign_ops.append(_variable_assign(threshold, new_threshold)) + self._assign_ops.append( + _variable_assign(mask, new_mask) if self._partitioner is None else + _partitioned_variable_assign(mask, new_mask)) + + def mask_update_op(self): + with ops.name_scope(self._spec.name): + if not self._assign_ops: + self._get_mask_assign_ops() + with ops.control_dependencies([ + state_ops.assign( + self._last_update_step, + self._global_step, + name='last_mask_update_step_assign') + ]): + with ops.control_dependencies(self._assign_ops): + logging.info('Updating masks.') + return control_flow_ops.no_op('mask_update') + + def conditional_mask_update_op(self): + + def maybe_update_masks(): + with ops.name_scope(self._spec.name): + is_step_within_pruning_range = math_ops.logical_and( + math_ops.greater_equal(self._global_step, + self._spec.begin_pruning_step), + # If end_pruning_step is negative, keep pruning forever! + math_ops.logical_or( + math_ops.less_equal(self._global_step, + self._spec.end_pruning_step), + math_ops.less(self._spec.end_pruning_step, 0))) + is_pruning_step = math_ops.less_equal( + math_ops.add(self._last_update_step, self._spec.pruning_frequency), + self._global_step) + return math_ops.logical_and(is_step_within_pruning_range, + is_pruning_step) + + def mask_update_op(): + return self.mask_update_op() + + def no_update_op(): + return control_flow_ops.no_op() + + return control_flow_ops.cond(maybe_update_masks(), mask_update_op, + no_update_op) + + def add_pruning_summaries(self): + """Adds summaries for this pruning spec. + + Args: none + + Returns: none + """ + with ops.name_scope(self._spec.name + '_summaries'): + summary.scalar('sparsity', self._sparsity) + summary.scalar('last_mask_update_step', self._last_update_step) + masks = get_masks() + thresholds = get_thresholds() + for index, mask in enumerate(masks): + if not self._exists_in_do_not_prune_list(mask.name): + summary.scalar(mask.name + '/sparsity', nn_impl.zero_fraction(mask)) + summary.scalar(thresholds[index].op.name + '/threshold', + thresholds[index]) + + def print_hparams(self): + logging.info(self._spec.to_json()) diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c23fd649ce1fc72a2e8d516bfa3750b7ced1b111 --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/pruning_test.py @@ -0,0 +1,162 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for the key functions in pruning library.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.model_pruning.python import pruning +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import partitioned_variables +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import training_util + + +class PruningHParamsTest(test.TestCase): + PARAM_LIST = [ + "name=test", "threshold_decay=0.9", "pruning_frequency=10", + "do_not_prune=[conv1,conv2]", "sparsity_function_end_step=100", + "target_sparsity=0.9" + ] + TEST_HPARAMS = ",".join(PARAM_LIST) + + def setUp(self): + super(PruningHParamsTest, self).setUp() + # Add global step variable to the graph + self.global_step = training_util.get_or_create_global_step() + # Add sparsity + self.sparsity = variables.Variable(0.5, name="sparsity") + # Parse hparams + self.pruning_hparams = pruning.get_pruning_hparams().parse( + self.TEST_HPARAMS) + + def testInit(self): + p = pruning.Pruning(self.pruning_hparams) + self.assertEqual(p._spec.name, "test") + self.assertAlmostEqual(p._spec.threshold_decay, 0.9) + self.assertEqual(p._spec.pruning_frequency, 10) + self.assertAllEqual(p._spec.do_not_prune, ["conv1", "conv2"]) + self.assertEqual(p._spec.sparsity_function_end_step, 100) + self.assertAlmostEqual(p._spec.target_sparsity, 0.9) + + def testInitWithExternalSparsity(self): + with self.test_session(): + p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity) + variables.global_variables_initializer().run() + sparsity = p._sparsity.eval() + self.assertAlmostEqual(sparsity, 0.5) + + def testInitWithVariableReuse(self): + with self.test_session(): + p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity) + p_copy = pruning.Pruning( + spec=self.pruning_hparams, sparsity=self.sparsity) + variables.global_variables_initializer().run() + sparsity = p._sparsity.eval() + self.assertAlmostEqual(sparsity, 0.5) + self.assertEqual(p._sparsity.eval(), p_copy._sparsity.eval()) + + +class PruningTest(test.TestCase): + + def setUp(self): + super(PruningTest, self).setUp() + self.global_step = training_util.get_or_create_global_step() + + def testCreateMask2D(self): + width = 10 + height = 20 + with self.test_session(): + weights = variables.Variable( + random_ops.random_normal([width, height], stddev=1), name="weights") + masked_weights = pruning.apply_mask(weights, + variable_scope.get_variable_scope()) + variables.global_variables_initializer().run() + weights_val = weights.eval() + masked_weights_val = masked_weights.eval() + self.assertAllEqual(weights_val, masked_weights_val) + + def testUpdateSingleMask(self): + with self.test_session() as session: + weights = variables.Variable( + math_ops.linspace(1.0, 100.0, 100), name="weights") + masked_weights = pruning.apply_mask(weights) + sparsity = variables.Variable(0.5, name="sparsity") + p = pruning.Pruning(sparsity=sparsity) + p._spec.threshold_decay = 0.0 + mask_update_op = p.mask_update_op() + variables.global_variables_initializer().run() + masked_weights_val = masked_weights.eval() + self.assertAllEqual(np.count_nonzero(masked_weights_val), 100) + session.run(mask_update_op) + masked_weights_val = masked_weights.eval() + self.assertAllEqual(np.count_nonzero(masked_weights_val), 51) + + def testPartitionedVariableMasking(self): + partitioner = partitioned_variables.variable_axis_size_partitioner(40) + with self.test_session() as session: + with variable_scope.variable_scope("", partitioner=partitioner): + sparsity = variables.Variable(0.5, name="Sparsity") + weights = variable_scope.get_variable( + "weights", initializer=math_ops.linspace(1.0, 100.0, 100)) + masked_weights = pruning.apply_mask( + weights, scope=variable_scope.get_variable_scope()) + p = pruning.Pruning(sparsity=sparsity, partitioner=partitioner) + p._spec.threshold_decay = 0.0 + mask_update_op = p.mask_update_op() + variables.global_variables_initializer().run() + masked_weights_val = masked_weights.eval() + session.run(mask_update_op) + masked_weights_val = masked_weights.eval() + self.assertAllEqual(np.count_nonzero(masked_weights_val), 51) + + def testConditionalMaskUpdate(self): + param_list = [ + "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6" + ] + test_spec = ",".join(param_list) + pruning_hparams = pruning.get_pruning_hparams().parse(test_spec) + weights = variables.Variable( + math_ops.linspace(1.0, 100.0, 100), name="weights") + masked_weights = pruning.apply_mask(weights) + sparsity = variables.Variable(0.00, name="sparsity") + # Set up pruning + p = pruning.Pruning(pruning_hparams, sparsity=sparsity) + p._spec.threshold_decay = 0.0 + mask_update_op = p.conditional_mask_update_op() + sparsity_val = math_ops.linspace(0.0, 0.9, 10) + increment_global_step = state_ops.assign_add(self.global_step, 1) + non_zero_count = [] + with self.test_session() as session: + variables.global_variables_initializer().run() + for i in range(10): + session.run(state_ops.assign(sparsity, sparsity_val[i])) + session.run(mask_update_op) + session.run(increment_global_step) + non_zero_count.append(np.count_nonzero(masked_weights.eval())) + # Weights pruned at steps 0,2,4,and,6 + expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40] + self.assertAllEqual(expected_non_zero_count, non_zero_count) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/mpi_collectives/__init__.py b/tensorflow/contrib/mpi_collectives/__init__.py index b94f7b0a353c4c3c698a927d8718bb5b490872cb..9ed16a6f078a506b60fd14f4356ff65a0a692203 100644 --- a/tensorflow/contrib/mpi_collectives/__init__.py +++ b/tensorflow/contrib/mpi_collectives/__init__.py @@ -194,7 +194,7 @@ class DistributedOptimizer(tf.train.Optimizer): See Optimizer.compute_gradients() for more info. - In DistributedOptimizer, compute_gradients() is overriden to also + In DistributedOptimizer, compute_gradients() is overridden to also allreduce the gradients before returning them. """ gradients = (super(DistributedOptimizer, self) diff --git a/tensorflow/contrib/nccl/BUILD b/tensorflow/contrib/nccl/BUILD index d6508362b8bf01468a43b26d6a0d0c9807b5967e..ed9fb64b954cc3dfec06936b479226a7def90008 100644 --- a/tensorflow/contrib/nccl/BUILD +++ b/tensorflow/contrib/nccl/BUILD @@ -71,10 +71,15 @@ tf_kernel_library( "kernels/nccl_manager.cc", "kernels/nccl_manager.h", "kernels/nccl_ops.cc", + "kernels/nccl_rewrite.cc", ], deps = [ + "//tensorflow/core:core_cpu", "//tensorflow/core:framework", "//tensorflow/core:gpu_headers_lib", + "//tensorflow/core:lib", + "//tensorflow/core:proto_text", + "//tensorflow/core:stream_executor", "@nccl_archive//:nccl", ], alwayslink = 1, @@ -110,7 +115,11 @@ tf_custom_op_py_library( deps = [ ":nccl_ops", "//tensorflow/contrib/util:util_py", + "//tensorflow/python:device", + "//tensorflow/python:framework_ops", "//tensorflow/python:platform", + "//tensorflow/python:util", + "//tensorflow/python/eager:context", ], ) diff --git a/tensorflow/contrib/nccl/kernels/nccl_ops.cc b/tensorflow/contrib/nccl/kernels/nccl_ops.cc index 4eb52492dbcc386941029709631314634c1c9be1..266d4f6f0de0274dca2bfc9022503f09b0ca7d42 100644 --- a/tensorflow/contrib/nccl/kernels/nccl_ops.cc +++ b/tensorflow/contrib/nccl/kernels/nccl_ops.cc @@ -15,8 +15,6 @@ limitations under the License. #if GOOGLE_CUDA -#include -#include #include #include "src/nccl.h" @@ -24,6 +22,7 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { +namespace { // Base class for all communicator ops that use nccl. // @@ -134,7 +133,7 @@ class NcclReduceSendKernel : public NcclReduceOpBase { compute_stream, &c->input(0), std::move(actual_done)); } }; -REGISTER_KERNEL_BUILDER(Name("NcclReduceSend").Device(DEVICE_GPU), +REGISTER_KERNEL_BUILDER(Name("_NcclReduceSend").Device(DEVICE_GPU), NcclReduceSendKernel); // To execute a single reduce, this kernel is called once for one devices, and @@ -166,7 +165,7 @@ class NcclReduceRecvKernel : public NcclReduceOpBase { private: ncclRedOp_t reduction_op_; }; -REGISTER_KERNEL_BUILDER(Name("NcclReduceRecv").Device(DEVICE_GPU), +REGISTER_KERNEL_BUILDER(Name("_NcclReduceRecv").Device(DEVICE_GPU), NcclReduceRecvKernel); // To execute a single broadcast, this kernel is called once for one device, and @@ -191,7 +190,7 @@ class NcclBroadcastSendKernel : public NcclAsyncOpBase { std::move(actual_done)); } }; -REGISTER_KERNEL_BUILDER(Name("NcclBroadcastSend").Device(DEVICE_GPU), +REGISTER_KERNEL_BUILDER(Name("_NcclBroadcastSend").Device(DEVICE_GPU), NcclBroadcastSendKernel); // To execute a single broadcast, this kernel is called once for all but one of @@ -206,7 +205,7 @@ class NcclBroadcastRecvKernel : public NcclAsyncOpBase { const Tensor& shape_t = c->input(0); TensorShape shape; OP_REQUIRES_OK_ASYNC( - c, TensorShapeUtils::MakeShape(shape_t.vec(), &shape), done); + c, TensorShapeUtils::MakeShape(shape_t.vec(), &shape), done); Tensor* out_t; OP_REQUIRES_OK_ASYNC(c, c->allocate_output(0, shape, &out_t), done); @@ -224,9 +223,24 @@ class NcclBroadcastRecvKernel : public NcclAsyncOpBase { } }; REGISTER_KERNEL_BUILDER( - Name("NcclBroadcastRecv").Device(DEVICE_GPU).HostMemory("shape"), + Name("_NcclBroadcastRecv").Device(DEVICE_GPU).HostMemory("shape"), NcclBroadcastRecvKernel); +// Define stub kernels for the ops that get replaced post placement. +class NcclStubKernel : public AsyncOpKernel { + public: + explicit NcclStubKernel(OpKernelConstruction* c) : AsyncOpKernel(c) {} + void ComputeAsync(OpKernelContext* c, DoneCallback done) override { + c->SetStatus(errors::Unimplemented( + "This op should be replaced during graph optimization.")); + done(); + } +}; +REGISTER_KERNEL_BUILDER(Name("NcclBroadcast").Device(DEVICE_GPU), + NcclStubKernel); +REGISTER_KERNEL_BUILDER(Name("NcclReduce").Device(DEVICE_GPU), NcclStubKernel); + +} // namespace } // namespace tensorflow #endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc new file mode 100644 index 0000000000000000000000000000000000000000..a4de46a93fab1dfe93b47f2789cc533bc447e43a --- /dev/null +++ b/tensorflow/contrib/nccl/kernels/nccl_rewrite.cc @@ -0,0 +1,276 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#if GOOGLE_CUDA + +#include +#include + +#include "tensorflow/core/common_runtime/optimization_registry.h" +#include "tensorflow/core/graph/node_builder.h" + +namespace tensorflow { +namespace { + +// Replaces NcclReduce node with _NcclReduceRecv reusing one input of same +// device, adds one _NcclReduceSend for each other input. +Status ReplaceReduce(Graph* graph, Node* node) { + string reduction; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "reduction", &reduction)); + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype)); + int num_devices = node->num_inputs(); + string shared_name = node->name(); + auto make_builder = [&](StringPiece op_name, StringPiece suffix) { + return NodeBuilder(strings::StrCat(shared_name, suffix), op_name) + .Attr("reduction", reduction) + .Attr("num_devices", num_devices) + .Attr("shared_name", shared_name) + .Attr("T", dtype); + }; + std::vector control_inputs; + for (const auto& edge : node->in_edges()) { + if (edge->IsControlEdge()) { + control_inputs.push_back(edge->src()); + } + } + std::vector out_nodes; + for (const auto& edge : node->out_edges()) { + out_nodes.emplace_back(edge->dst(), edge->dst_input()); + } + int recv_dev = node->assigned_device_name_index(); + NodeBuilder recv_builder = + make_builder("_NcclReduceRecv", "Recv").ControlInputs(control_inputs); + bool recv_input_set = false; + int send_counter = 0; + for (const auto& edge : node->in_edges()) { + Node* src_node = edge->src(); + if (edge->IsControlEdge()) { + continue; + } + int send_dev = src_node->assigned_device_name_index(); + if (!recv_input_set && send_dev == recv_dev) { + recv_builder.Input(src_node); + recv_input_set = true; + continue; + } + auto send_builder = make_builder("_NcclReduceSend", + strings::StrCat("Send_", ++send_counter)) + .Input(src_node) + .ControlInputs(control_inputs); + Node* send_node = nullptr; + TF_RETURN_IF_ERROR(send_builder.Finalize(graph, &send_node)); + send_node->set_assigned_device_name_index(send_dev); + // Send nodes don't have any outputs and therefore have no data dependencies + // to the outputs of the graph. We add a control dependency to the receive + // node so that those 'dangling' nodes are run. + // TODO(b/67027412): Avoid these cross-device control edges. + for (const auto& out_node : out_nodes) { + graph->AddControlEdge(send_node, out_node.node); + } + } + if (!recv_input_set) { + return errors::InvalidArgument( + "No input tensor uses the same device as the NcclReduce op"); + } + Node* recv_node = nullptr; + TF_RETURN_IF_ERROR(recv_builder.Finalize(graph, &recv_node)); + recv_node->set_assigned_device_name_index(recv_dev); + graph->RemoveNode(node); + for (const auto& out_node : out_nodes) { + if (out_node.index == Graph::kControlSlot) { + graph->AddControlEdge(recv_node, out_node.node); + } else { + graph->AddEdge(recv_node, 0, out_node.node, out_node.index); + } + } + return Status::OK(); +} + +TensorProto TensorFromShape(const TensorShapeProto& shape) { + TensorProto result; + result.set_dtype(DT_INT32); + for (const auto& dim : shape.dim()) { + result.add_int_val(dim.size()); + } + result.mutable_tensor_shape()->add_dim()->set_size(shape.dim_size()); + return result; +} + +// Replaces NcclBroadcast node with _NcclBroadcastSend, connects the input to +// all outputs of same device, adds one _NcclBroadcastRecv for each other output +// device. +Status ReplaceBroadcast(Graph* graph, Node* node) { + DataType dtype; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "T", &dtype)); + int send_dev = node->assigned_device_name_index(); + int num_devices = 0; // Number of distinct devices, incremented below. + std::vector recv_index_map; // Map device name index to stable index. + + // Map device name index to nodes that take the broadcast as input. + std::vector> out_nodes_map; + for (const auto& edge : node->out_edges()) { + int dst_dev = edge->IsControlEdge() + ? send_dev + : edge->dst()->assigned_device_name_index(); + if (out_nodes_map.size() <= dst_dev) { + out_nodes_map.resize(dst_dev + 1); + recv_index_map.resize(dst_dev + 1); + } + auto it = out_nodes_map.begin() + dst_dev; + if (it->empty()) { + recv_index_map[dst_dev] = num_devices; + ++num_devices; + } + it->emplace_front(NodeBuilder::NodeOut(edge->dst(), edge->dst_input())); + } + + if (num_devices <= 1) { + // Only one participating device, skip NCCL op. + const Edge* in_edge = nullptr; + TF_RETURN_IF_ERROR(node->input_edge(0, &in_edge)); + Node* in_node = in_edge->src(); + int in_index = in_edge->src_output(); + graph->RemoveNode(node); + for (const auto& out_nodes : out_nodes_map) { + for (const auto& out_node : out_nodes) { + if (out_node.index == Graph::kControlSlot) { + graph->AddControlEdge(in_node, out_node.node); + } else { + graph->AddEdge(in_node, in_index, out_node.node, out_node.index); + } + } + } + return Status::OK(); + } + + string shared_name = node->name(); + auto make_builder = [&](StringPiece op_name, StringPiece suffix) { + return NodeBuilder(strings::StrCat(shared_name, suffix), op_name) + .Attr("num_devices", num_devices) + .Attr("shared_name", shared_name) + .Attr("T", dtype); + }; + + // Create broadcast send node and replace the original broadcast node. + NodeBuilder::NodeOut in_node; + NodeBuilder send_builder = make_builder("_NcclBroadcastSend", "Send"); + for (const auto& edge : node->in_edges()) { + if (edge->IsControlEdge()) { + send_builder.ControlInput(edge->src()); + } else { + in_node = NodeBuilder::NodeOut(edge->src(), edge->src_output()); + send_builder.Input(in_node); + } + } + Node* send_node = nullptr; + TF_RETURN_IF_ERROR(send_builder.Finalize(graph, &send_node)); + send_node->set_assigned_device_name_index(send_dev); + + TensorShapeProto shape_proto; + TF_RETURN_IF_ERROR(GetNodeAttr(node->attrs(), "shape", &shape_proto)); + + // Delete the original node before reconnecting to outputs. + graph->RemoveNode(node); + + // Connect all outputs on the device of broadcast send. + for (const auto& out_node : out_nodes_map[send_dev]) { + if (out_node.index == Graph::kControlSlot) { + graph->AddControlEdge(send_node, out_node.node); + } else { + graph->AddEdge(in_node.node, in_node.index, out_node.node, + out_node.index); + // Add control edge so send node is run. + graph->AddControlEdge(send_node, out_node.node); + } + } + out_nodes_map[send_dev].clear(); + + TensorProto tensor_proto = TensorFromShape(shape_proto); + bool is_fully_defined = TensorShape(shape_proto).IsFullyDefined(); + string shape_name = strings::StrCat(in_node.node->name(), "/Shape"); + Node* shape_node = nullptr; + if (!is_fully_defined) { + NodeBuilder shape_builder(shape_name, "Shape"); + shape_builder.Input(in_node).Attr("out_type", DT_INT32).Attr("T", dtype); + TF_RETURN_IF_ERROR(shape_builder.Finalize(graph, &shape_node)); + shape_node->set_assigned_device_name_index(send_dev); + } + + // For all other devices, create a broadcast receive and connect outputs. + for (int recv_dev = 0; recv_dev < out_nodes_map.size(); ++recv_dev) { + if (out_nodes_map[recv_dev].empty()) { + continue; + } + int recv_index = recv_index_map[recv_dev]; + if (is_fully_defined) { + // If the shape is fully defined, define one const node per device. + NodeBuilder shape_builder(strings::StrCat(shape_name, recv_index), + "Const"); + shape_builder.Attr("value", tensor_proto).Attr("dtype", DT_INT32); + TF_RETURN_IF_ERROR(shape_builder.Finalize(graph, &shape_node)); + shape_node->set_assigned_device_name_index(recv_dev); + } + Node* recv_node; + TF_RETURN_IF_ERROR( + make_builder("_NcclBroadcastRecv", strings::StrCat("Recv_", recv_index)) + .Input(shape_node) + .Finalize(graph, &recv_node)); + recv_node->set_assigned_device_name_index(recv_dev); + for (const auto& out_node : out_nodes_map[recv_dev]) { + graph->AddEdge(recv_node, 0, out_node.node, out_node.index); + } + } + + return Status::OK(); +} + +// Replaces occurrences of Nccl{Reduce, Broadcast}Input/Output with their +// _Nccl...Send/Recv counterparts and removes data dependencies between them. +class NcclReplacePass : public GraphOptimizationPass { + public: + Status Run(const GraphOptimizationPassOptions& options) override { + if (options.graph == nullptr) { + return Status::OK(); + } + Graph* graph = options.graph->get(); + if (graph == nullptr) { + return errors::Internal( + "NCCL replacement should happen before partitioning and a " + "graph should be available."); + } + // Find reduction and broadcast ops and replace them with Send/Recv ops. + for (Node* node : graph->op_nodes()) { + StringPiece type = node->type_string(); + if (!type.starts_with("Nccl")) { + continue; + } + if (type == "NcclReduce") { + TF_RETURN_IF_ERROR(ReplaceReduce(graph, node)); + } + if (type == "NcclBroadcast") { + TF_RETURN_IF_ERROR(ReplaceBroadcast(graph, node)); + } + } + return Status::OK(); + } +}; +REGISTER_OPTIMIZATION(OptimizationPassRegistry::POST_PLACEMENT, 0, + NcclReplacePass); + +} // namespace +} // namespace tensorflow + +#endif // GOOGLE_CUDA diff --git a/tensorflow/contrib/nccl/ops/nccl_ops.cc b/tensorflow/contrib/nccl/ops/nccl_ops.cc index 532c79c24cc9596af580ee3faf463aecbc59bb07..8eb804c2e988f313ba1b340217cae20f1f5502c7 100644 --- a/tensorflow/contrib/nccl/ops/nccl_ops.cc +++ b/tensorflow/contrib/nccl/ops/nccl_ops.cc @@ -45,7 +45,28 @@ num_devices: The number of devices participating in this reduction. shared_name: Identifier that shared between ops of the same reduction. )doc"); -REGISTER_OP("NcclReduceSend") +// Note: This op has no kernel implementation, but is replaced by +// _NcclReduceSend and _NcclReduceRecv during graph optimization stage. +REGISTER_OP("NcclReduce") + .Input("input: num_devices * T") + .Output("data: T") + .Attr("reduction: {'min', 'max', 'prod', 'sum'}") + .Attr("T: {float, float64, int32, int64}") + .Attr("num_devices: int") + .SetIsStateful() + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Reduces `input` from `num_devices` using `reduction` to a single device. + +The graph should be constructed so that all inputs have a valid device +assignment, and the op itself is assigned one of these devices. + +input: The input to the reduction. +data: the value of the reduction across all `num_devices` devices. +reduction: the reduction operation to perform. + )doc"); + +REGISTER_OP("_NcclReduceSend") .Input("input: T") .Attr("reduction: {'min', 'max', 'prod', 'sum'}") .Attr("T: {float, float64, int32, int64}") @@ -54,19 +75,20 @@ REGISTER_OP("NcclReduceSend") .SetIsStateful() .SetShapeFn(shape_inference::NoOutputs) .Doc(R"doc( -Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`. +Replacement node for NcclReduce. +Reduces `input` to the NcclReduceRecv op registered in the same `shared_name`. The graph should be constructed so that 'num_devices-1' devices run -`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value +`_NcclReduceSend` and one device runs _NcclReduceRecv op with shared_name value `c`. Failure to do so will cause the graph execution to fail to complete. -input: The input to the reduction +input: The input to the reduction. reduction: the reduction operation to perform. num_devices: The number of devices participating in this reduction. shared_name: Identifier that is shared between ops of the same reduce. )doc"); -REGISTER_OP("NcclReduceRecv") +REGISTER_OP("_NcclReduceRecv") .Input("input: T") .Output("data: T") .Attr("reduction: {'min', 'max', 'prod', 'sum'}") @@ -76,21 +98,42 @@ REGISTER_OP("NcclReduceRecv") .SetIsStateful() .SetShapeFn(shape_inference::UnchangedShape) .Doc(R"doc( +Replacement node for NcclReduce. + Reduces 'input' from this op and the NcclReduceSend ops registered in the same `shared_name`. - The graph should be constructed so that 'num_devices-1' devices run -`NcclReduceSend` and one device runs NcclReduceRecv op with shared_name value +`_NcclReduceSend` and one device runs _NcclReduceRecv op with shared_name value `c`. Failure to do so will cause the graph execution to fail to complete. -input: The input to the reduction +input: The input to the reduction. data: The reduced data received from this op and the NcclReduceSend op. reduction: the reduction operation to perform. num_devices: The number of devices participating in this reduction. shared_name: Identifier that is shared between ops of the same reduce. )doc"); -REGISTER_OP("NcclBroadcastSend") +// Note: This op has no kernel implementation, but is replaced by +// _NcclBroadcastSend and _NcclBroadcastRecv during graph optimization stage. +REGISTER_OP("NcclBroadcast") + .Input("input: T") + .Output("output: T") + .Attr("T: {float, float64, int32, int64}") + .Attr("shape: shape") + .SetIsStateful() + .SetShapeFn(shape_inference::UnchangedShape) + .Doc(R"doc( +Sends `input` to all devices that are connected to the output. + +The graph should be constructed so that all ops connected to the output have a +valid device assignment, and the op itself is assigned one of these devices. + +input: The input to the broadcast. +output: The same as input. +shape: The shape of the input tensor. + )doc"); + +REGISTER_OP("_NcclBroadcastSend") .Input("input: T") .Attr("T: {float, float64, int32, int64}") .Attr("num_devices: int") @@ -98,19 +141,21 @@ REGISTER_OP("NcclBroadcastSend") .SetIsStateful() .SetShapeFn(shape_inference::NoOutputs) .Doc(R"doc( -Sends `input` to the NcclBroadcastRecv ops registered in the same `shared_name`. +Replacement node for NcclBroadcast. -The graph should be constructed so that one device runs `NcclBroadcastSend` and -`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`. +Sends `input` to the _NcclBroadcastRecv ops registered in the same +`shared_name`. +The graph should be constructed so that one device runs `_NcclBroadcastSend` and +`num_devices-1` devices run _NcclBroadcastRecv ops with shared_name value `c`. Failure to do so will cause the graph execution to fail to complete. -input: The input to the broadcast +input: The input to the broadcast. num_devices: The number of devices participating in this reduction. shared_name: Identifier that is shared between ops of the same broadcast. )doc"); -REGISTER_OP("NcclBroadcastRecv") - .Input("shape: int64") +REGISTER_OP("_NcclBroadcastRecv") + .Input("shape: int32") .Output("output: T") .Attr("T: {float, float64, int32, int64}") .Attr("num_devices: int") @@ -123,11 +168,12 @@ REGISTER_OP("NcclBroadcastRecv") return Status::OK(); }) .Doc(R"doc( -Sends data of shape `shape` from the NcclBroadcastSend op registered in the -same `shared_name`. +Replacement node for NcclBroadcast. -The graph should be constructed so that one device runs `NcclBroadcastSend` and -`num_devices-1` devices run NcclBroadcastRecv ops with shared_name value `c`. +Sends data of shape `shape` from the _NcclBroadcastSend op registered in the +same `shared_name`. +The graph should be constructed so that one device runs `_NcclBroadcastSend` and +`num_devices-1` devices run _NcclBroadcastRecv ops with shared_name value `c`. Failure to do so will cause the graph execution to fail to complete. shape: The shape of the output. diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops.py b/tensorflow/contrib/nccl/python/ops/nccl_ops.py index 906d9f948acf212dce1dbbbf9ec7c60c30f389b1..8dc038b9ac992de7db8b762e3697c6693099e192 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops.py @@ -23,9 +23,7 @@ from tensorflow.contrib.nccl.ops import gen_nccl_ops from tensorflow.contrib.util import loader from tensorflow.python.eager import context from tensorflow.python.framework import device -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops from tensorflow.python.platform import resource_loader _nccl_ops_so = loader.load_op_library( @@ -64,13 +62,13 @@ def _all_sum_grad(op, grad): LookupError: If `reduction` is not `sum`. """ if op.get_attr('reduction') != 'sum': - raise LookupError('No gradient defined for NcclAllReduce except all_sum.') + raise LookupError('No gradient defined for NcclAllReduce except sum.') - _check_device_assignment(grad) + _check_device(grad, expected=op.device) num_devices = op.get_attr('num_devices') shared_name = op.get_attr('shared_name') + '_grad' - with ops.device(grad.device): + with ops.device(op.device): return gen_nccl_ops.nccl_all_reduce( input=grad, reduction='sum', @@ -129,7 +127,7 @@ def all_max(tensors): return _apply_all_reduce('max', tensors) -def reduce_sum(tensors, dst_device): +def reduce_sum(tensors): """Returns a tensor with the reduce sum across `tensors`. The computation is done with a reduce operation, so only one tensor is @@ -138,54 +136,76 @@ def reduce_sum(tensors, dst_device): Args: tensors: The input tensors across which to sum; must be assigned to GPU devices. - dst_device: The device of the returned tensor. Returns: - A tensor containing the sum of the input tensors, with the device of the - tensor being `dst_device`. + A tensor containing the sum of the input tensors. + + Raises: + LookupError: If context is not currently using a GPU device. + """ + return _apply_reduce('sum', tensors) + + +@ops.RegisterGradient('NcclReduce') +def _reduce_sum_grad(op, grad): + """The gradients for input `Operation` of `reduce_sum`. + + Args: + op: The `sum send` `Operation` that we are differentiating. + grad: Gradient with respect to the output of the `reduce_sum` op. + + Returns: + The gradient with respect to the input of `reduce_sum` op. + + Raises: + LookupError: If the reduction attribute of op is not `sum`. """ - return _apply_reduce('sum', tensors, dst_device) + if op.get_attr('reduction') != 'sum': + raise LookupError('No gradient defined for NcclReduce except sum.') + _check_device(grad, expected=op.device) + with ops.device(op.device): + result = gen_nccl_ops.nccl_broadcast(input=grad, shape=grad.shape) -def broadcast(src_tensor, dst_devices): - """Returns a list of tensors on `dst_devices`, each with value `tensor`. + return [result] * len(op.inputs) - The computation is done with a broadcast nccl operation, so if only some of - the returned tensors and src_tensor are evaluated then the computation will - hang. + +def broadcast(tensor): + """Returns a tensor that can be efficiently transferred to other devices. Args: - src_tensor: The tensor to send; must be assigned to a GPU device. - dst_devices: The GPU devices to receive the sent tensor. + tensor: The tensor to send; must be assigned to a GPU device. Returns: - An `Operation` to send the `src_tensor`, and a list of tensors, each with - the value of `src_tensor`, where the device of tensor i is `dst_devices[i]`. + A tensor with the value of `src_tensor`, which can be used as input to + ops on other GPU devices. """ - if not dst_devices: - raise ValueError('Must pass >0 dst_devices to broadcast') _check_graph_mode() - _check_device_assignment(src_tensor) + _check_device(tensor) - shape = array_ops.shape(src_tensor, out_type=dtypes.int64) - num_devices = len(dst_devices) + 1 - shared_name = _get_shared_name() + with ops.device(tensor.device): + return gen_nccl_ops.nccl_broadcast(input=tensor, shape=tensor.shape) - with ops.device(src_tensor.device): - send = gen_nccl_ops.nccl_broadcast_send( - input=src_tensor, num_devices=num_devices, shared_name=shared_name) - - recvs = [] - for d in dst_devices: - with ops.device(d): - recvs.append( - gen_nccl_ops.nccl_broadcast_recv( - shape=shape, - T=src_tensor.dtype, - num_devices=num_devices, - shared_name=shared_name)) - return send, recvs +@ops.RegisterGradient('NcclBroadcast') +def _broadcast_grad(op, accumulated_grad): + """The gradients for input `Operation` of `broadcast`. + + Args: + op: The `broadcast send` `Operation` that we are differentiating. + accumulated_grad: Accumulated gradients with respect to the output of the + `broadcast` op. + + Returns: + Gradients with respect to the input of `broadcast`. + """ + # Grab inputs of accumulated_grad and replace accumulation with reduce_sum. + grads = [t for t in accumulated_grad.op.inputs] + for t in grads: + _check_device(t) + + with ops.device(op.device): + return gen_nccl_ops.nccl_reduce(input=grads, reduction='sum') def _apply_all_reduce(reduction, tensors): @@ -198,7 +218,7 @@ def _apply_all_reduce(reduction, tensors): res = [] for t in tensors: - _check_device_assignment(t) + _check_device(t) with ops.device(t.device): res.append( gen_nccl_ops.nccl_all_reduce( @@ -210,40 +230,20 @@ def _apply_all_reduce(reduction, tensors): return res -def _apply_reduce(reduction, tensors, dst_device): +def _apply_reduce(reduction, tensors): """Helper function for reduce_* functions.""" if not tensors: raise ValueError('Must pass >0 tensors to reduce operations') - if not dst_device: - raise ValueError('Must pass dst_device to reduce operations') _check_graph_mode() + for t in tensors: + _check_device(t) + result = gen_nccl_ops.nccl_reduce(input=tensors, reduction=reduction) try: - recv_index = next(i for i, t in enumerate(tensors) - if t.device == dst_device) + next(t for t in tensors if t.device == result.device) except StopIteration: - raise ValueError('One of the tensors must be assigned to dst_device') - shared_name = _get_shared_name() - - sends = [] - for t in tensors[:recv_index] + tensors[recv_index + 1:]: - _check_device_assignment(t) - with ops.device(t.device): - sends.append( - gen_nccl_ops.nccl_reduce_send( - input=t, - reduction=reduction, - num_devices=len(tensors), - shared_name=shared_name)) - - with ops.device(dst_device): - recv = gen_nccl_ops.nccl_reduce_recv( - input=tensors[recv_index], - reduction=reduction, - num_devices=len(tensors), - shared_name=shared_name) - - return recv, sends + raise ValueError('One input tensor must be assigned to current device') + return result _lock = threading.Lock() @@ -259,9 +259,11 @@ def _get_shared_name(): return 'c%s' % val -def _check_device_assignment(tensor): +def _check_device(tensor, expected=None): if not device.canonical_name(tensor.device): raise ValueError('Device assignment required for nccl collective ops') + if expected and expected != tensor.device: + raise ValueError('Expected device %s, got %s' % (expected, tensor.device)) def _check_graph_mode(): diff --git a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py index 96d67723a0ad197436a12924bd2b4ecb73eee4cb..0b13e3595e36b609468f459d9179f8e9f5c1e055 100644 --- a/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py +++ b/tensorflow/contrib/nccl/python/ops/nccl_ops_test.py @@ -22,8 +22,10 @@ from functools import partial import numpy as np from tensorflow.contrib import nccl +from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gradients from tensorflow.python.platform import test @@ -36,27 +38,30 @@ def _DeviceTensors(tensors, devices): def _NcclAllReduce(nccl_fun, tensors, devices): - return nccl_fun(_DeviceTensors(tensors, devices)), [] + return nccl_fun(_DeviceTensors(tensors, devices)) def _NcclReduce(nccl_fun, tensors, devices): - d_tensors = _DeviceTensors(tensors, devices) receiver = np.random.randint(0, len(devices)) - received_tensor, send_ops = nccl_fun(d_tensors, devices[receiver]) - return [received_tensor], send_ops + with ops.device(devices[receiver]): + return [nccl_fun(_DeviceTensors(tensors, devices))] def _NcclBroadcast(tensors, devices): sender = np.random.randint(0, len(devices)) - d_tensor = _DeviceTensors(tensors[0:1], devices[sender:sender + 1])[0] - other_devices = devices[:sender] + devices[sender + 1:] - send_op, received_tensors = nccl.broadcast(d_tensor, other_devices) - return received_tensors, [send_op] + with ops.device(devices[sender]): + tensor = array_ops.identity(tensors[0]) + broadcast = nccl.broadcast(tensor) + return _DeviceTensors([broadcast] * len(devices), devices) class NcclTestCase(test.TestCase): - def _Test(self, nccl_reduce, numpy_fn): + def _Test(self, + nccl_reduce, + numpy_fn, + device_sets=(['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'], + ['/device:GPU:1', '/device:GPU:0'])): """Tests that nccl_reduce does the same as reduction with numpy_fn. Args: @@ -65,6 +70,7 @@ class NcclTestCase(test.TestCase): reduction. numpy_fn: A function taking two tensors and returning the reduction of the two. + device_sets: Tuple of virtual devices to run test on. """ if not test.is_gpu_available(): return # Test requires access to a GPU @@ -74,26 +80,28 @@ class NcclTestCase(test.TestCase): # same communicator across multiple sessions. with self.test_session(use_gpu=True) as sess: - for devices in [['/device:GPU:1', '/device:GPU:2', '/device:GPU:0'], - ['/device:GPU:1', '/device:GPU:0']]: + for devices in device_sets: shape = (3, 4) random = (np.random.random_sample(shape) - .5) * 1024 - tensors = [random.astype(dtype)] * len(devices) + tensors = [] + for _ in devices: + tensors.append(random.astype(dtype)) np_ans = tensors[0] for t in tensors[1:]: np_ans = numpy_fn(np_ans, t) - reduce_tensors, reduce_ops = nccl_reduce(tensors, devices) + reduce_tensors = nccl_reduce(tensors, devices) self.assertNotEmpty(reduce_tensors) # Test shape inference. for r in reduce_tensors: self.assertEqual(shape, r.get_shape()) + result_tensors = [array_ops.identity(t) for t in reduce_tensors] + # Test execution and results. - nccl_results = sess.run(reduce_tensors + reduce_ops) - for r in nccl_results[:len(reduce_tensors)]: - self.assertAllClose(r, np_ans) + for t in sess.run(result_tensors): + self.assertAllClose(t, np_ans) def _TestGradient(self, nccl_reduce, numpy_fn): """Tests the gradient of nccl_reduce. @@ -106,14 +114,12 @@ class NcclTestCase(test.TestCase): reduction of the two. """ def _Gradient(tensors, devices): - reduce_tensors, _ = nccl_reduce(tensors, devices) - tensor_ops = [t.op for t in reduce_tensors] - d_tensors = _DeviceTensors(tensors, devices) - grad_tensors = [ - ops.get_gradient_function(op)(op, loss) - for op, loss in zip(tensor_ops, d_tensors) - ] - return grad_tensors, [] + inputs = [array_ops.placeholder(t.dtype, t.shape) for t in tensors] + reduce_tensors = nccl_reduce(inputs, devices) + losses = _DeviceTensors(tensors, [t.device for t in reduce_tensors]) + grads = gradients.gradients( + reduce_tensors, inputs, losses, colocate_gradients_with_ops=True) + return [g for g in grads if g is not None] self._Test(_Gradient, numpy_fn) @@ -142,27 +148,40 @@ class SingleReduceTest(NcclTestCase): def testSum(self): self._Test(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x + y) + def testSumGrad(self): + self._TestGradient(partial(_NcclReduce, nccl.reduce_sum), lambda x, y: x) + class BroadcastTest(NcclTestCase): def testBroadcast(self): self._Test(_NcclBroadcast, lambda x, y: x) + def testBroadcastSingleDevice(self): + # Broadcasts on a single device are removed completely during rewrite. + self._Test(_NcclBroadcast, lambda x, y: x, + (['/device:GPU:0', '/device:GPU:0'],)) + + def testBroadcastToCpuError(self): + # Broadcasts to CPU is not supported. + with self.assertRaisesRegexp( + errors.NotFoundError, + "No registered '_NcclBroadcastRecv' OpKernel for CPU devices"): + self._Test(_NcclBroadcast, lambda x, y: x, + (['/device:GPU:0', '/device:CPU:0'],)) + class CombinedTest(NcclTestCase): """Test all-reduce vs. single-reduce plus broadcast in one session.run.""" - def _combined(self, tensors, devices): - all_reduce_tensors = _NcclAllReduce(nccl.all_sum, tensors, devices)[0] - single_reduce_tensors, single_reduce_ops = _NcclReduce( - nccl.reduce_sum, tensors, devices) - broadcast_tensors, broadcast_ops = _NcclBroadcast(single_reduce_tensors, - devices) - all_tensors = all_reduce_tensors + single_reduce_tensors + broadcast_tensors - return all_tensors, single_reduce_ops + broadcast_ops + def _Combined(self, tensors, devices): + all_reduce_tensors = _NcclAllReduce(nccl.all_sum, tensors, devices) + single_reduce_tensors = _NcclReduce(nccl.reduce_sum, tensors, devices) + broadcast_tensors = _NcclBroadcast(single_reduce_tensors, devices) + return all_reduce_tensors + broadcast_tensors def testCombined(self): - self._Test(self._combined, lambda x, y: x + y) + self._Test(self._Combined, lambda x, y: x + y) if __name__ == '__main__': diff --git a/tensorflow/contrib/nearest_neighbor/BUILD b/tensorflow/contrib/nearest_neighbor/BUILD index 84d59cc4be87488ec55df54af16ae0b27a37fdd0..9500c18b1df9d772dfb827bc2b3d33e0a65974f6 100644 --- a/tensorflow/contrib/nearest_neighbor/BUILD +++ b/tensorflow/contrib/nearest_neighbor/BUILD @@ -41,18 +41,14 @@ tf_gen_op_wrapper_py( tf_custom_op_py_library( name = "nearest_neighbor_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]), - dso = [ - ":python/ops/_nearest_neighbor_ops.so", - ], - kernels = [ - ":nearest_neighbor_ops_kernels", - ], + dso = [":python/ops/_nearest_neighbor_ops.so"], + kernels = [":nearest_neighbor_ops_kernels"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ - ":nearest_neighbor_ops_pywrapper", "//tensorflow/contrib/util:util_py", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", ], ) @@ -70,9 +66,7 @@ tf_kernel_library( cc_library( name = "heap", - hdrs = [ - "kernels/heap.h", - ], + hdrs = ["kernels/heap.h"], ) tf_cc_test( @@ -81,17 +75,14 @@ tf_cc_test( srcs = ["kernels/heap_test.cc"], deps = [ ":heap", - "//tensorflow/core:test", "//tensorflow/core:test_main", - "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", ], ) cc_library( name = "hyperplane_lsh_probes", - hdrs = [ - "kernels/hyperplane_lsh_probes.h", - ], + hdrs = ["kernels/hyperplane_lsh_probes.h"], deps = [ ":heap", "//third_party/eigen3", @@ -107,6 +98,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", + "//tensorflow/core/kernels:ops_testutil", ], ) diff --git a/tensorflow/contrib/nn/BUILD b/tensorflow/contrib/nn/BUILD index 4b7288e235b566e13bfebe0425a6bbbe5efa0ae1..56a24ac77f0b9a87b6e4db48cddacdf35f4855d0 100644 --- a/tensorflow/contrib/nn/BUILD +++ b/tensorflow/contrib/nn/BUILD @@ -18,6 +18,7 @@ py_library( "python/ops/alpha_dropout.py", "python/ops/cross_entropy.py", "python/ops/sampling_ops.py", + "python/ops/scaled_softplus.py", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], @@ -26,8 +27,10 @@ py_library( "//tensorflow/python:dtypes", "//tensorflow/python:embedding_ops", "//tensorflow/python:framework_ops", + "//tensorflow/python:function", "//tensorflow/python:math_ops", "//tensorflow/python:nn", + "//tensorflow/python:nn_ops", "//tensorflow/python:random_ops", "//tensorflow/python:tensor_shape", "//tensorflow/python:tensor_util", @@ -36,34 +39,48 @@ py_library( ) py_test( - name = "sampling_ops_test", + name = "alpha_dropout_test", size = "small", - srcs = ["python/ops/sampling_ops_test.py"], + srcs = ["python/ops/alpha_dropout_test.py"], srcs_version = "PY2AND3", deps = [ ":nn_py", + "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:nn", + "//tensorflow/python:random_ops", ], ) py_test( - name = "alpha_dropout_test", + name = "sampling_ops_test", size = "small", - srcs = ["python/ops/alpha_dropout_test.py"], + srcs = ["python/ops/sampling_ops_test.py"], srcs_version = "PY2AND3", deps = [ ":nn_py", - "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:nn", - "//tensorflow/python:random_ops", + ], +) + +py_test( + name = "scaled_softplus_test", + size = "small", + srcs = ["python/ops/scaled_softplus_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":nn_py", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:gradient_checker", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/nn/__init__.py b/tensorflow/contrib/nn/__init__.py index 2cfeaa955ddc5aae9139205d4c6b3df3057a1613..3bf795d19aad73ec37c0485fe1900a7d8ac43137 100644 --- a/tensorflow/contrib/nn/__init__.py +++ b/tensorflow/contrib/nn/__init__.py @@ -18,7 +18,9 @@ @@deprecated_flipped_softmax_cross_entropy_with_logits @@deprecated_flipped_sparse_softmax_cross_entropy_with_logits @@deprecated_flipped_sigmoid_cross_entropy_with_logits +@@nth_element @@rank_sampled_softmax_loss +@@scaled_softplus """ from __future__ import absolute_import @@ -26,9 +28,11 @@ from __future__ import division from __future__ import print_function # pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.nn.python.ops.alpha_dropout import * from tensorflow.contrib.nn.python.ops.cross_entropy import * from tensorflow.contrib.nn.python.ops.sampling_ops import * -from tensorflow.contrib.nn.python.ops.alpha_dropout import * +from tensorflow.contrib.nn.python.ops.scaled_softplus import * +from tensorflow.python.ops.nn_ops import nth_element # pylint: enable=unused-import,wildcard-import from tensorflow.python.util.all_util import remove_undocumented diff --git a/tensorflow/contrib/nn/python/ops/scaled_softplus.py b/tensorflow/contrib/nn/python/ops/scaled_softplus.py new file mode 100644 index 0000000000000000000000000000000000000000..fcbfbc239ca5b8a1d4b17b403f99b7eb05db47b0 --- /dev/null +++ b/tensorflow/contrib/nn/python/ops/scaled_softplus.py @@ -0,0 +1,115 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Support for scaled softplus, a smoothed version of ReLU.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import function +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn + + +def _reduce_and_reshape_grad(g, t): + """Returns the gradient, sum-reduced and reshaped to `t`'s shape.""" + shape = array_ops.shape(t) + g_shape = array_ops.shape(g) + # pylint: disable=protected-access + bcast_dims, _ = gen_array_ops._broadcast_gradient_args(shape, g_shape) + # pylint: enable=protected-access + return array_ops.reshape(math_ops.reduce_sum(g, bcast_dims), shape) + + +def scaled_softplus(x, alpha, clip=None, name=None): + """Returns `y = alpha * ln(1 + exp(x / alpha))` or `min(y, clip)`. + + This can be seen as a softplus applied to the scaled input, with the output + appropriately scaled. As `alpha` tends to 0, `scaled_softplus(x, alpha)` tends + to `relu(x)`. The clipping is optional. As alpha->0, scaled_softplus(x, alpha) + tends to relu(x), and scaled_softplus(x, alpha, clip=6) tends to relu6(x). + + Note: the gradient for this operation is defined to depend on the backprop + inputs as well as the outputs of this operation. + + Args: + x: A `Tensor` of inputs. + alpha: A `Tensor`, indicating the amount of smoothness. The caller + must ensure that `alpha > 0`. + clip: (optional) A `Tensor`, the upper bound to clip the values. + name: A name for the scope of the operations (optional). + + Returns: + A tensor of the size and type determined by broadcasting of the inputs. + + """ + clipping = clip is not None + with ops.name_scope(name, 'scaled_softplus', + [x, alpha] + ([clip] if clipping else [])): + x = ops.convert_to_tensor(x, name='x') + dtype = x.dtype + alpha = ops.convert_to_tensor(alpha, dtype=dtype, name='alpha') + # Compute the forward value. + y = alpha * nn.softplus(x / alpha) + if clipping: + clip = ops.convert_to_tensor(clip, dtype=dtype, name='clip') + y = math_ops.minimum(y, clip) + + def _grad(op, g): + """Backprop for scaled softplus, with optional clipping.""" + y, x, alpha = op.inputs[:3] + # Prevent the memory-expensive computations from happening before g is + # available. + with ops.control_dependencies([g]): + y = array_ops.identity(y) + clip_grad = [] + if clipping: + clip = op.inputs[3] + unclipped = math_ops.cast(y < clip, g.dtype) + clip_grad = [_reduce_and_reshape_grad(g * (1. - unclipped), clip)] + g *= unclipped + y /= alpha + emy = math_ops.exp(-y) + dy_dx = 1. - emy + # The eps below avoids log(0). Note that t*log(t) -> 0 as t->0. + eps = 1e-8 + dy_dalpha = y * emy - dy_dx * math_ops.log(dy_dx + eps) + # Backprop to the actual inputs, but not to the output. + return [None, + _reduce_and_reshape_grad(g * dy_dx, x), + _reduce_and_reshape_grad(g * dy_dalpha, alpha)] + clip_grad + + if clipping: + @function.Defun(dtype, dtype, dtype, dtype, + func_name='ScaledSoftplusHelper_clip_%s' % dtype.name, + shape_func=lambda op: [op.inputs[0].shape], + python_grad_func=_grad) + def _forward_helper_clip(y, x, alpha, clip): + del x, alpha, clip # Unused. + return y + return _forward_helper_clip(y, x, alpha, clip) + # No clipping. + @function.Defun(dtype, dtype, dtype, + func_name='ScaledSoftplusHelper_%s' % dtype.name, + shape_func=lambda op: [op.inputs[0].shape], + python_grad_func=_grad) + def _forward_helper(y, x, alpha): + del x, alpha # Unused. + return y + return _forward_helper(y, x, alpha) + diff --git a/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py b/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py new file mode 100644 index 0000000000000000000000000000000000000000..b978343c6a79af856d833b0ab8002c256ce478e0 --- /dev/null +++ b/tensorflow/contrib/nn/python/ops/scaled_softplus_test.py @@ -0,0 +1,78 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for scaled_softplus.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.nn.python.ops.scaled_softplus import scaled_softplus +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import gradient_checker +from tensorflow.python.platform import test + + +class ScaledSoftplusTest(test.TestCase): + + def test(self): + np.random.seed(1) # Make it reproducible. + x = np.random.randn(3, 4).astype(np.float32) + x64 = np.random.randn(3, 4).astype(np.float64) + alpha = np.random.rand() + 0.01 + clip = np.float32(0.1) + y = np.minimum(alpha * np.log(1. + np.exp(x / alpha)), clip) + y64 = alpha * np.log(1. + np.exp(x64 / alpha)) + with self.test_session(use_gpu=True) as sess: + z = scaled_softplus(constant_op.constant(x), alpha, clip) + z64 = scaled_softplus(constant_op.constant(x64), alpha) + z, z64 = sess.run([z, z64]) + eps = 1e-6 + self.assertAllClose(y, z, eps) + self.assertAllClose(y64, z64, eps) + + def testGradient(self): + np.random.seed(1) # Make it reproducible. + x_shape = [5, 10] + x_np = np.random.randn(*x_shape).astype(np.float32) + alpha_np = np.float32(np.random.rand(1, x_shape[1]) + 0.01) + clip_np = np.float32(np.random.rand(x_shape[0], 1) * 5.) + with self.test_session(use_gpu=True): + x_tf = constant_op.constant(x_np) + alpha_tf = constant_op.constant(alpha_np) + clip_tf = constant_op.constant(clip_np) + y_tf = scaled_softplus(x_tf, alpha_tf) + z_tf = scaled_softplus(x_tf, alpha_tf, clip_tf * 0.1) + err = gradient_checker.compute_gradient_error([x_tf, alpha_tf], + [x_shape, alpha_np.shape], + y_tf, x_shape, + [x_np, alpha_np], + delta=0.002) + err_clip = gradient_checker.compute_gradient_error( + [x_tf, alpha_tf, clip_tf], + [x_shape, alpha_np.shape, clip_np.shape], + z_tf, x_shape, + [x_np, alpha_np, clip_np], + delta=0.002) + eps = 2e-4 + self.assertLess(err, eps) + self.assertLess(err_clip, eps) + + +if __name__ == '__main__': + test.main() + + diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD index b5a67206f3433ab3cf5ee5594557aadf8a09983b..096d2270e4c2d046a8dc8982bf03a648a195c667 100644 --- a/tensorflow/contrib/opt/BUILD +++ b/tensorflow/contrib/opt/BUILD @@ -86,9 +86,9 @@ py_test( ], deps = [ ":opt_py", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:variables", "//third_party/py/numpy", @@ -119,13 +119,13 @@ py_test( deps = [ ":opt_py", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:session", "//tensorflow/python:variables", "//third_party/py/numpy", ], @@ -139,12 +139,17 @@ tf_py_test( "//third_party/py/numpy", "//tensorflow/python:client", "//tensorflow/python:client_testlib", + "//tensorflow/python:data_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", + "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:variables", ], + tags = [ + "no_oss", # Flaky due to port collisions + ], ) filegroup( diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD index 745dc2f8366a319dc94246228a6cc3efc12a53b8..1bf40ab6b26c6ad1f9658a4b0ad93527fe609698 100644 --- a/tensorflow/contrib/predictor/BUILD +++ b/tensorflow/contrib/predictor/BUILD @@ -25,7 +25,10 @@ py_library( srcs = ["__init__.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], - deps = [":predictor_factories"], + deps = [ + ":predictor_factories", + "//tensorflow/python:util", + ], ) py_library( @@ -58,7 +61,6 @@ py_library( "//tensorflow/python:session", "//tensorflow/python/saved_model:loader", "//tensorflow/python/saved_model:signature_constants", - "//tensorflow/python/saved_model:signature_def_utils", ], ) diff --git a/tensorflow/contrib/predictor/predictor.py b/tensorflow/contrib/predictor/predictor.py index dbc0028259ebe50bdbe8dee9ef3ccff1aff5507c..28fa815684dd5e242f82d51968d856553315e8d5 100644 --- a/tensorflow/contrib/predictor/predictor.py +++ b/tensorflow/contrib/predictor/predictor.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Abstract base class for all predictors.""" from __future__ import absolute_import @@ -66,8 +65,9 @@ class Predictor(object): expected_keys = set(self.feed_tensors.keys()) unexpected_keys = input_keys - expected_keys if unexpected_keys: - raise ValueError('Got unexpected keys in input_dict: {}'.format( - unexpected_keys)) + raise ValueError( + 'Got unexpected keys in input_dict: {}\nexpected: {}'.format( + unexpected_keys, expected_keys)) feed_dict = {} for key in self.feed_tensors.keys(): diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..935af80e7a0cb94b9ccdc52b48a73cecc5beb299 --- /dev/null +++ b/tensorflow/contrib/quantize/BUILD @@ -0,0 +1,246 @@ +package(default_visibility = ["//tensorflow:__subpackages__"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +py_library( + name = "common", + srcs = ["python/common.py"], + srcs_version = "PY2AND3", + deps = [], +) + +py_library( + name = "graph_matcher", + srcs = [ + "python/graph_matcher.py", + ], + srcs_version = "PY2AND3", + deps = [], +) + +py_test( + name = "graph_matcher_test", + size = "small", + srcs = ["python/graph_matcher_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":graph_matcher", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + +py_library( + name = "input_to_ops", + srcs = ["python/input_to_ops.py"], + srcs_version = "PY2AND3", + deps = [ + ":common", + ], +) + +py_test( + name = "input_to_ops_test", + size = "small", + srcs = ["python/input_to_ops_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":input_to_ops", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + +py_library( + name = "fold_batch_norms", + srcs = ["python/fold_batch_norms.py"], + srcs_version = "PY2AND3", + deps = [ + ":common", + ":graph_matcher", + ":input_to_ops", + "//tensorflow/contrib/graph_editor:graph_editor_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn", + "//tensorflow/python:nn_ops", + ], +) + +py_test( + name = "fold_batch_norms_test", + srcs = ["python/fold_batch_norms_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":fold_batch_norms", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + +py_library( + name = "copy_graph", + srcs = ["python/copy_graph.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_ops", + "//tensorflow/python:training", + ], +) + +py_test( + name = "copy_graph_test", + size = "small", + srcs = ["python/copy_graph_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":copy_graph", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + +py_library( + name = "quant_ops", + srcs = ["python/quant_ops.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:state_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + ], +) + +py_library( + name = "quantize", + srcs = ["python/quantize.py"], + srcs_version = "PY2AND3", + deps = [ + ":common", + ":input_to_ops", + ":quant_ops", + "//tensorflow/contrib/graph_editor:graph_editor_py", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + ], +) + +py_test( + name = "quantize_test", + size = "small", + srcs = ["python/quantize_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":quantize", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + ], +) + +py_test( + name = "quantize_parameterized_test", + size = "large", + srcs = ["python/quantize_parameterized_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":fold_batch_norms", + ":quantize", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:training", + ], +) + +py_library( + name = "quantize_graph", + srcs = [ + "__init__.py", + "python/quantize_graph.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":copy_graph", + ":fold_batch_norms", + ":quantize", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python:variables", + ], +) + +py_test( + name = "quantize_graph_test", + size = "small", + srcs = ["python/quantize_graph_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":quantize_graph", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:constant_op", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:init_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:platform_test", + "//tensorflow/python:variables", + ], +) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/contrib/quantize/__init__.py b/tensorflow/contrib/quantize/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5d4e4575c935e0a888c6e5e4d0db640d93e1bd49 --- /dev/null +++ b/tensorflow/contrib/quantize/__init__.py @@ -0,0 +1,32 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Functions for rewriting graphs for quantized training.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import,wildcard-import,line-too-long +from tensorflow.contrib.quantize.python.quantize_graph import * +# pylint: enable=unused-import,wildcard-import,line-too-long + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + "create_eval_graph", + "create_training_graph", +] + +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py new file mode 100644 index 0000000000000000000000000000000000000000..d0b0674c31239ee903f5ab7ef9ae0262bb20d189 --- /dev/null +++ b/tensorflow/contrib/quantize/python/common.py @@ -0,0 +1,88 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Constants used across this package.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import re + +# Skip all operations that are backprop related or export summaries. +SKIPPED_PREFIXES = ( + 'gradients/', 'RMSProp/', 'Adagrad/', 'Const_', 'HistogramSummary', + 'ScalarSummary') + +# Valid activation ops for quantization end points. +_ACTIVATION_OP_SUFFIXES = ['/Relu6', '/Relu', '/Identity'] + +# Regular expression for recognizing nodes that are part of batch norm group. +_BATCHNORM_RE = re.compile(r'^(.*)/BatchNorm/batchnorm') + + +def BatchNormGroups(graph): + """Finds batch norm layers, returns their prefixes as a list of strings. + + Args: + graph: Graph to inspect. + + Returns: + List of strings, prefixes of batch norm group names found. + """ + bns = [] + for op in graph.get_operations(): + match = _BATCHNORM_RE.search(op.name) + if match: + bn = match.group(1) + if not bn.startswith(SKIPPED_PREFIXES): + bns.append(bn) + # Filter out duplicates. + return list(collections.OrderedDict.fromkeys(bns)) + + +def GetEndpointActivationOp(graph, prefix): + """Returns an Operation with the given prefix and a valid end point suffix. + + Args: + graph: Graph where to look for the operation. + prefix: String, prefix of Operation to return. + + Returns: + The Operation with the given prefix and a valid end point suffix or None if + there are no matching operations in the graph for any valid suffix + """ + for suffix in _ACTIVATION_OP_SUFFIXES: + activation = _GetOperationByNameDontThrow(graph, prefix + suffix) + if activation: + return activation + return None + + +def _GetOperationByNameDontThrow(graph, name): + """Returns an Operation with the given name. + + Args: + graph: Graph where to look for the operation. + name: String, name of Operation to return. + + Returns: + The Operation with the given name. None if the name does not correspond to + any operation in the graph + """ + try: + return graph.get_operation_by_name(name) + except KeyError: + return None diff --git a/tensorflow/contrib/quantize/python/copy_graph.py b/tensorflow/contrib/quantize/python/copy_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..0376fcba82b99feabdba3b683f9db9a32db51efb --- /dev/null +++ b/tensorflow/contrib/quantize/python/copy_graph.py @@ -0,0 +1,32 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility to copy a tf.Graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.training import saver as saver_lib + + +def CopyGraph(graph): + """Return a copy of graph.""" + meta_graph = saver_lib.export_meta_graph( + graph=graph, collection_list=graph.get_all_collection_keys()) + graph_copy = ops.Graph() + with graph_copy.as_default(): + _ = saver_lib.import_meta_graph(meta_graph) + return graph_copy diff --git a/tensorflow/contrib/quantize/python/copy_graph_test.py b/tensorflow/contrib/quantize/python/copy_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7ff9ad9f8412d7076bf12d6cf10772244444013f --- /dev/null +++ b/tensorflow/contrib/quantize/python/copy_graph_test.py @@ -0,0 +1,55 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for copy_graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.quantize.python import copy_graph +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + + +class CopyGraphTest(test_util.TensorFlowTestCase): + + def _CompareNodeInGraph(self, node, graph): + graph_node = graph.get_operation_by_name(node.name) + self.assertEqual(str(node.node_def), str(graph_node.node_def)) + + def testCopyGraph(self): + graph = ops.Graph() + with graph.as_default(): + a = constant_op.constant(1.0) + b = variables.Variable(2.0) + c = a + b + graph_copy = copy_graph.CopyGraph(graph) + # Ensure that the three original nodes are in the new graph. + # import_meta_graph also adds a saver node to the graph which we don't care + # about in this specific use case. + for tensor in [a, b, c]: + self._CompareNodeInGraph(tensor.op, graph_copy) + # Test that the graph collections are the same. + for key in graph.get_all_collection_keys(): + self.assertEqual( + len(graph.get_collection(key)), + len(graph_copy.get_collection(key)), 'Collection %s differs.') + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py new file mode 100644 index 0000000000000000000000000000000000000000..647d4044001f7be701037d07dc46db86c0aa3a0e --- /dev/null +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -0,0 +1,570 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Logic to fold batch norm into preceding convolution or FC layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +from tensorflow.contrib import graph_editor +from tensorflow.contrib.quantize.python import common +from tensorflow.contrib.quantize.python import graph_matcher +from tensorflow.contrib.quantize.python import input_to_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn +from tensorflow.python.ops import nn_ops + + +def FoldBatchNorms(graph): + """Finds batch norm layers and folds them into preceding layers. + + Folding only affects the following layers: Conv2D, fully connected, depthwise + convolution. + + Args: + graph: Graph to walk and modify. + + Raises: + ValueError: When batch norm folding fails. + """ + _FoldFusedBatchNorms(graph) + _FoldUnfusedBatchNorms(graph) + + +def _FoldFusedBatchNorms(graph): + """Finds fused batch norm layers and folds them into preceding layers. + + Folding only affects the following layers: Conv2D, fully connected, depthwise + convolution. + + Args: + graph: Graph to walk and modify. + + Raises: + ValueError: When batch norm folding fails. + """ + for match in _FindFusedBatchNorms(graph): + scope, sep, _ = match.layer_op.name.rpartition('/') + # Make sure new ops are added to `graph` and put on the same device as + # `bn_op`. The '/' (i.e. `sep`) ensures that we reuse the existing scope + # named `scope`. Otherwise, TF creates a unique scope whose name starts with + # `scope`. + with graph.as_default(), graph.name_scope(scope + sep), ops.device( + match.bn_op.device): + # new weights = old weights * gamma / sqrt(variance + epsilon) + # new biases = -mean * gamma / sqrt(variance + epsilon) + beta + multiplier_tensor = match.gamma_tensor * math_ops.rsqrt( + match.variance_tensor + match.bn_op.get_attr('epsilon')) + bias_tensor = math_ops.subtract( + match.beta_tensor, match.mean_tensor * multiplier_tensor, name='bias') + + # The shape of depthwise weights is different, so we need to reshape the + # multiplier_tensor to ensure that the scaled_weight_tensor has the + # expected shape. + if match.layer_op.type == 'DepthwiseConv2dNative': + new_shape = [ + match.weight_tensor.get_shape().as_list()[2], + match.weight_tensor.get_shape().as_list()[3] + ] + multiplier_tensor = array_ops.reshape( + multiplier_tensor, new_shape, name='scale_reshape') + + # TODO(suharshs): This naming of the following ops needs to carefully + # follow the naming expected by quantize.py. Generalize the quantize code + # to not require these delicate naming conventions. + scaled_weight_tensor = math_ops.multiply( + match.weight_tensor, multiplier_tensor, name='mul_fold') + + new_layer_tensor = _CloneWithNewOperands( + match.layer_op, match.input_tensor, scaled_weight_tensor) + + bias_add_tensor = math_ops.add( + new_layer_tensor, bias_tensor, name='add_fold') + + nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor, + match.output_tensor) + if nodes_modified_count != 1: + raise ValueError( + 'Unexpected inputs to op: %s' % match.output_tensor.name) + + +def _CloneWithNewOperands(layer_op, input_tensor, weight_tensor): + """Clones layer_op with input_tensor and weight_tensor as new inputs.""" + new_layer_name = layer_op.name.split('/')[-1] + '_Fold' + if layer_op.type == 'Conv2D': + return nn_ops.conv2d( + input_tensor, + weight_tensor, + strides=layer_op.get_attr('strides'), + padding=layer_op.get_attr('padding'), + use_cudnn_on_gpu=layer_op.get_attr('use_cudnn_on_gpu'), + data_format=layer_op.get_attr('data_format'), + name=new_layer_name) + elif layer_op.type == 'MatMul': + return math_ops.matmul( + input_tensor, + weight_tensor, + transpose_a=layer_op.get_attr('transpose_a'), + transpose_b=layer_op.get_attr('transpose_b'), + name=new_layer_name) + elif layer_op.type == 'DepthwiseConv2dNative': + return nn.depthwise_conv2d( + input_tensor, + weight_tensor, + strides=layer_op.get_attr('strides'), + padding=layer_op.get_attr('padding'), + name=new_layer_name) + else: + raise ValueError('Cannot handle operation of type: %s' % layer_op.type) + + +def _FindFusedBatchNorms(graph): + """Finds all ops and tensors related to found FusedBatchNorms. + + Args: + graph: Graph to inspect. + + Yields: + _FusedBatchNormMatches. + """ + input_pattern = graph_matcher.OpTypePattern('*') + weight_pattern = graph_matcher.OpTypePattern('*') + gamma_pattern = graph_matcher.OpTypePattern('*') + beta_pattern = graph_matcher.OpTypePattern('*') + mean_pattern = graph_matcher.OpTypePattern('*') + variance_pattern = graph_matcher.OpTypePattern('*') + + conv_pattern = graph_matcher.OpTypePattern( + 'Conv2D|DepthwiseConv2dNative', inputs=[input_pattern, weight_pattern]) + # MatMul has a Reshape between it and FusedBatchNorm. + matmul_pattern = graph_matcher.OpTypePattern( + 'MatMul', inputs=[input_pattern, weight_pattern]) + matmul_reshape_pattern = graph_matcher.OpTypePattern( + 'Reshape', inputs=[matmul_pattern, + graph_matcher.OpTypePattern('*')]) + + conv_batch_norm_pattern = graph_matcher.OpTypePattern( + 'FusedBatchNorm', + inputs=[ + conv_pattern, gamma_pattern, beta_pattern, mean_pattern, + variance_pattern + ]) + matmul_batch_norm_pattern = graph_matcher.OpTypePattern( + 'FusedBatchNorm', + inputs=[ + matmul_reshape_pattern, gamma_pattern, beta_pattern, mean_pattern, + variance_pattern + ]) + matmul_bn_output_reshape_pattern = graph_matcher.OpTypePattern( + 'Reshape', + inputs=[matmul_batch_norm_pattern, + graph_matcher.OpTypePattern('*')]) + + conv_matcher = graph_matcher.GraphMatcher(conv_batch_norm_pattern) + matmul_matcher = graph_matcher.GraphMatcher(matmul_bn_output_reshape_pattern) + + def _GetCommonTensors(match_result): + """Gets tensors needed for FusedBatchNormMatch from match_result.""" + input_tensor = match_result.get_tensor(input_pattern) + weight_tensor = match_result.get_tensor(weight_pattern) + gamma_tensor = match_result.get_tensor(gamma_pattern) + beta_tensor = match_result.get_tensor(beta_pattern) + # FusedBatchNorm in training is different from that in inference. It takes + # empty 'mean' and empty 'variance', and produces the mean and the variance + # of the batch. Therefore, when is_training is true, mean_tensor and + # variance_tensor point to 1st and 2nd (0-based) output of bn_op, + # respectively; when is_training is false, they point to bn_op's inputs. + is_training = bn_op.get_attr('is_training') + if is_training: + mean_tensor = bn_op.outputs[1] + variance_tensor = bn_op.outputs[2] + else: + mean_tensor = match_result.get_tensor(mean_pattern) + variance_tensor = match_result.get_tensor(variance_pattern) + return (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor) + + for match_result in conv_matcher.match_graph(graph): + layer_op = match_result.get_op(conv_pattern) + bn_op = match_result.get_op(conv_batch_norm_pattern) + # In the case of convolution the output_tensor is the output of bn_op. + output_tensor = bn_op.outputs[0] + + (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor) = _GetCommonTensors(match_result) + yield _FusedBatchNormMatch( + layer_op=layer_op, + bn_op=bn_op, + output_tensor=output_tensor, + input_tensor=input_tensor, + weight_tensor=weight_tensor, + gamma_tensor=gamma_tensor, + beta_tensor=beta_tensor, + mean_tensor=mean_tensor, + variance_tensor=variance_tensor) + + for match_result in matmul_matcher.match_graph(graph): + layer_op = match_result.get_op(matmul_pattern) + bn_op = match_result.get_op(matmul_batch_norm_pattern) + # In the MatMul case, the output of batch norm is reshaped back into a + # 2D tensor, so the output_tensor is the output of the Reshape op. + output_reshape_op = match_result.get_op(matmul_bn_output_reshape_pattern) + output_tensor = output_reshape_op.outputs[0] + + (input_tensor, weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor) = _GetCommonTensors(match_result) + yield _FusedBatchNormMatch( + layer_op=layer_op, + bn_op=bn_op, + output_tensor=output_tensor, + input_tensor=input_tensor, + weight_tensor=weight_tensor, + gamma_tensor=gamma_tensor, + beta_tensor=beta_tensor, + mean_tensor=mean_tensor, + variance_tensor=variance_tensor) + + +class _FusedBatchNormMatch(object): + """Contains all information related to a found FusedBatchNorm.""" + + def __init__(self, layer_op, bn_op, output_tensor, input_tensor, + weight_tensor, gamma_tensor, beta_tensor, mean_tensor, + variance_tensor): + self._layer_op = layer_op + self._bn_op = bn_op + self._output_tensor = output_tensor + self._input_tensor = input_tensor + self._weight_tensor = weight_tensor + self._gamma_tensor = gamma_tensor + self._beta_tensor = beta_tensor + self._mean_tensor = mean_tensor + self._variance_tensor = variance_tensor + + @property + def layer_op(self): + return self._layer_op + + @property + def bn_op(self): + return self._bn_op + + @property + def output_tensor(self): + return self._output_tensor + + @property + def input_tensor(self): + return self._input_tensor + + @property + def weight_tensor(self): + return self._weight_tensor + + @property + def gamma_tensor(self): + return self._gamma_tensor + + @property + def beta_tensor(self): + return self._beta_tensor + + @property + def mean_tensor(self): + return self._mean_tensor + + @property + def variance_tensor(self): + return self._variance_tensor + + +def _FoldUnfusedBatchNorms(graph): + """Finds unfused batch norm layers and folds them into preceding layers. + + Folding only affects the following layers: Conv2D, fully connected, depthwise + convolution. + + Args: + graph: Graph to walk and modify. + + Raises: + ValueError: When batch norm folding fails. + """ + input_to_ops_map = input_to_ops.InputToOps(graph) + + for bn in common.BatchNormGroups(graph): + has_scaling = _HasScaling(graph, input_to_ops_map, bn) + + # The mangling code intimately depends on BatchNorm node's internals. + original_op, folded_op = _CreateFoldedOp(graph, bn, has_scaling=has_scaling) + + activation = common.GetEndpointActivationOp(graph, bn) + if activation: + nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]], + [original_op.outputs[0]], + can_modify=[activation]) + if nodes_modified_count != 1: + raise ValueError('Unexpected inputs to op: %s' % activation.name) + continue + + # Treat consumer ops in bypass modules differently since they have Add + # operations instead of Relu* above. + add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1) + add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add') + nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]], + [original_op.outputs[0]], + can_modify=[add_bypass]) + if nodes_modified_count != 1: + raise ValueError('Unexpected inputs to op: %s' % add_bypass.name) + + +def _HasScaling(graph, input_to_ops_map, bn): + r"""Checks if batch norm has scaling enabled. + + Difference between batch norm with scaling and without is that with scaling: + + Rsqrt -> mul -> mul_1 + \-> mul_2 + + where + mul multiplies gamma by inverse square root of EMA of batch variance, + mul_1 multiplies output of mul with output from the base operation + (convolution, FC or depthwise convolution), + mul_2 multiplies output of mul with EMA of batch mean, + and without scaling: + + Rsqrt -> mul + \-> mul_1 + + where + mul multiplies the inverse square root of EMA of batch variance with output + from the base operation, + mul_1 multiplies inverse square root of EMA of batch variance with EMA + of batch mean. + + Args: + graph: Graph to inspect. + input_to_ops_map: InputToOps object containing mapping from tensor's name + to ops that take it as input. + bn: Batch norm layer prefix string. + + Returns: + A boolean indicating whether this batch norm layer has scaling enabled. + """ + rsqrt_op = graph.get_operation_by_name(bn + '/BatchNorm/batchnorm/Rsqrt') + rsqrt_consumers = input_to_ops_map.ConsumerOperations(rsqrt_op) + + return sum(1 for op in rsqrt_consumers if op.type == 'Mul') == 1 + + +def _CreateFoldedOp(graph, context, has_scaling): + """Folds in batch norm layer into preceding convolution or FC layer. + + Creates 3 new nodes, connects their inputs and adds them to the graph: + mul is cloned into mul_fold, Conv2D or MatMul, or DepthwiseConv2d is cloned + into respective *_Fold, add is cloned into add_fold. + + Args: + graph: Graph to modify. + context: String, batch norm context, i.e. node into which BatchNorm is + nested. + has_scaling: Whether the batch norm has scaling enabled. + + Raises: + ValueError: When operation type is not supported, or input and output tensor + shapes mismatch for created operations: mul_fold, add_fold. + + Returns: + A pair of Operations, the first is the original consumer node of the batch + norm (../BatchNorm/batchnorm/add_1), the second is the consumer node of + the folded graph (add_fold). + """ + mul_scale_name = 'mul_1' if has_scaling else 'mul' + mul_scale = graph.get_operation_by_name(context + + '/BatchNorm/batchnorm/' + + mul_scale_name) + op_below = mul_scale.inputs[0].op + weights = op_below.inputs[1] + + # Special handling for weights of depthwise convolution. + if op_below.type == 'DepthwiseConv2dNative': + new_shape = [weights.get_shape().as_list()[2], + weights.get_shape().as_list()[3]] + scale_name = 'mul' if has_scaling else 'Rsqrt' + scale = graph.get_operation_by_name(context + '/BatchNorm/batchnorm/' + + scale_name) + scale = array_ops.reshape(scale.outputs[0], new_shape, + context + '/scale_reshape') + mul_fold = _CloneOp(mul_scale, context + '/mul_fold', + [(0, weights), (1, scale)]) + elif op_below.type in ['Conv2D', 'MatMul']: + mul_fold = _CloneOp(mul_scale, context + '/mul_fold', [(0, weights)]) + else: + raise ValueError('Cannot handle operation of type: %s' % op_below.op) + _AssertShapesMatch('mul_fold', mul_fold.inputs[0], mul_fold.outputs[0]) + + conv_or_fc_folded = _CloneOp(op_below, op_below.name + '_Fold', + [(1, mul_fold.outputs[0])]) + + add_shift = graph.get_operation_by_name(context + + '/BatchNorm/batchnorm/add_1') + add_fold = _CloneOp(add_shift, context + '/add_fold', + [(0, conv_or_fc_folded.outputs[0])]) + _AssertShapesMatch('add_fold', add_fold.inputs[0], add_fold.outputs[0]) + return add_shift, add_fold + + +def _CloneOp(op, new_name, new_inputs): + """Clones a given op, replaces its name and some of its inputs. + + Args: + op: Operation to modify. + new_name: String, a new name to set on cloned op. + new_inputs: A list of tuples (idx, tensor), each input with corresponding + index will be replaced by the given Tensor in the cloned op. + + Returns: + Operation, the cloned op. + + Raises: + TypeError: When Operation type is not supported. + ValueError: When input shapes are incompatible. + """ + inputs = list(op.inputs) + for new_input in new_inputs: + inputs[new_input[0]] = new_input[1] + return _OP_CLONER.Clone(op, inputs, new_name) + + +class _OpCloner(object): + """Helper class that clones tf.Operations based on their type.""" + + def __init__(self): + self.op_type_to_action = { + 'Mul': self._CloneMul, + 'Add': self._CloneAdd, + 'Conv2D': self._CloneConv2d, + 'DepthwiseConv2dNative': self._CloneDepthwiseConv2d, + 'MatMul': self._CloneMatMul, + } + + def _CloneMul(self, op, inputs, new_name): + del op # Unused. + return math_ops.multiply(inputs[0], inputs[1], name=new_name).op + + def _CloneAdd(self, op, inputs, new_name): + del op # Unused. + return math_ops.add(inputs[0], inputs[1], name=new_name).op + + def _CloneConv2d(self, op, inputs, new_name): + input_tensor = inputs[0] + weights = inputs[1] + self._AssertConvShapes(op.name, input_tensor, weights) + return nn_ops.conv2d( + input_tensor, + weights, + strides=op.get_attr('strides'), + padding=op.get_attr('padding'), + use_cudnn_on_gpu=op.get_attr('use_cudnn_on_gpu'), + data_format=op.get_attr('data_format'), + name=new_name).op + + def _CloneDepthwiseConv2d(self, op, inputs, new_name): + input_tensor = inputs[0] + weights = inputs[1] + self._AssertConvShapes(op.name, input_tensor, weights) + return nn.depthwise_conv2d( + input_tensor, + weights, + strides=op.get_attr('strides'), + padding=op.get_attr('padding'), + name=new_name).op + + def _CloneMatMul(self, op, inputs, new_name): + weights = inputs[0] + input_tensor = inputs[1] + self._AssertFCShapes(op.name, weights, input_tensor) + return math_ops.matmul( + weights, + input_tensor, + transpose_a=op.get_attr('transpose_a'), + transpose_b=op.get_attr('transpose_b'), + name=new_name).op + + def Clone(self, op, inputs, new_name): + try: + return self.op_type_to_action[op.type](op, inputs, new_name) + except KeyError: + raise TypeError('Unsupported operation type: %s' % op.type) + + def _AssertConvShapes(self, op_name, input_tensor, weights): + """Makes sure that convolution inputs have compatible shapes. + + Args: + op_name: Operation name, only used in error message. + input_tensor: Input that is convolved. + weights: Weights of the convolution filter. + + Raises: + ValueError: When input shapes are incompatible. + """ + input_shape = input_tensor.get_shape() + weights_shape = weights.get_shape() + if (len(input_shape) != 4 or len(weights_shape) != 4 or + input_shape[3] != weights_shape[2]): + raise ValueError('Incompatible shapes for op %s inputs: %s and %s' % + (op_name, input_shape, weights_shape)) + + def _AssertFCShapes(self, op_name, weights, input_tensor): + """Makes sure that FC layer inputs have compatible shapes. + + Args: + op_name: Operation name, only used in error message. + weights: Weights used in FC layer. + input_tensor: Input into FC layer. + + Raises: + ValueError: When input shapes are incompatible. + """ + weights_shape = weights.get_shape() + input_shape = input_tensor.get_shape() + if (len(weights_shape) != 2 or len(input_shape) != 2 or + weights_shape[1] != input_shape[0]): + raise ValueError('Incompatible shapes for op %s inputs: %s and %s' % + (op_name, weights_shape, input_shape)) + +_OP_CLONER = _OpCloner() + + +def _AssertShapesMatch(op_name, in_tensor, out_tensor): + """Makes sure that shapes of input and output tensors are compatible. + + Args: + op_name: String, operation name, only used in error message. + in_tensor: Tensor, input tensor. + out_tensor: Tensor, output tensor. + + Raises: + ValueError: When input and output tensors have different shapes. + """ + in_shape = in_tensor.get_shape() + out_shape = out_tensor.get_shape() + + if not in_shape.is_compatible_with(out_shape): + raise ValueError('%s should not change tensor shape: input %s, ' + 'output %s' % (op_name, in_shape, out_shape)) diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py new file mode 100644 index 0000000000000000000000000000000000000000..2cecf6851467f82675bd67bf1fb108e9a39df1b0 --- /dev/null +++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py @@ -0,0 +1,375 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for folding batch norm layers.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.quantize.python import fold_batch_norms +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest + +batch_norm = layers.batch_norm +conv2d = layers.conv2d +fully_connected = layers.fully_connected +separable_conv2d = layers.separable_conv2d + + +# TODO(suharshs): Use parameterized test once OSS TF supports it. +class FoldBatchNormsTest(test_util.TensorFlowTestCase): + + def _RunTestOverParameters(self, test_fn): + parameters_list = [ + # (relu, relu_op_name, with_bypass, has_scaling, fused_batch_norm) + (nn_ops.relu6, 'Relu6', False, False, False), + (nn_ops.relu, 'Relu', False, False, False), + (nn_ops.relu6, 'Relu6', True, False, False), + (nn_ops.relu, 'Relu', True, False, False), + (nn_ops.relu6, 'Relu6', False, True, False), + (nn_ops.relu, 'Relu', False, True, False), + (nn_ops.relu6, 'Relu6', True, True, False), + (nn_ops.relu, 'Relu', True, True, False), + # Fused batch norm always has scaling enabled. + (nn_ops.relu6, 'Relu6', False, True, True), + (nn_ops.relu, 'Relu', False, True, True), + (nn_ops.relu6, 'Relu6', True, True, True), + (nn_ops.relu, 'Relu', True, True, True), + ] + for params in parameters_list: + test_fn(params[0], params[1], params[2], params[3], params[4]) + + def _TestFoldConv2d(self, relu, relu_op_name, with_bypass, has_scaling, + fused_batch_norm): + """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. + + Args: + relu: Callable that returns an Operation, a factory method for the Relu*. + relu_op_name: String, name of the Relu* operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. + """ + g = ops.Graph() + with g.as_default(): + batch_size, height, width = 5, 128, 128 + inputs = array_ops.zeros((batch_size, height, width, 3)) + out_depth = 3 if with_bypass else 32 + stride = 1 if with_bypass else 2 + activation_fn = None if with_bypass else relu + scope = 'test/test2' if with_bypass else 'test' + node = conv2d( + inputs, + out_depth, [5, 5], + stride=stride, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + relu(node, name='test/' + relu_op_name) + + fold_batch_norms.FoldBatchNorms(g) + + folded_mul = g.get_operation_by_name(scope + '/mul_fold') + self.assertEqual(folded_mul.type, 'Mul') + self._AssertInputOpsAre(folded_mul, [ + scope + '/weights/read', + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) + ]) + self._AssertOutputGoesToOps(folded_mul, g, [scope + '/Conv2D_Fold']) + + folded_conv = g.get_operation_by_name(scope + '/Conv2D_Fold') + self.assertEqual(folded_conv.type, 'Conv2D') + self._AssertInputOpsAre(folded_conv, + [scope + '/mul_fold', inputs.op.name]) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + + folded_add = g.get_operation_by_name(scope + '/add_fold') + self.assertEqual(folded_add.type, 'Add') + self._AssertInputOpsAre(folded_add, [ + scope + '/Conv2D_Fold', + self._BathNormBiasName(scope, fused_batch_norm) + ]) + output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] + self._AssertOutputGoesToOps(folded_add, g, output_op_names) + + def testFoldConv2d(self): + self._RunTestOverParameters(self._TestFoldConv2d) + + def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass, + has_scaling, fused_batch_norm): + """Tests folding cases: inputs -> Conv2d with batch norm -> Relu*. + + Tests that folding works even with an input shape where some dimensions are + not known (i.e. None). + + Args: + relu: Callable that returns an Operation, a factory method for the Relu*. + relu_op_name: String, name of the Relu* operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. + """ + g = ops.Graph() + with g.as_default(): + inputs = array_ops.placeholder(dtypes.float32, shape=(5, None, None, 3)) + out_depth = 3 if with_bypass else 32 + stride = 1 if with_bypass else 2 + activation_fn = None if with_bypass else relu + scope = 'test/test2' if with_bypass else 'test' + node = conv2d( + inputs, + out_depth, [5, 5], + stride=stride, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + relu(node, name='test/' + relu_op_name) + + fold_batch_norms.FoldBatchNorms(g) + + folded_mul = g.get_operation_by_name(scope + '/mul_fold') + self.assertEqual(folded_mul.type, 'Mul') + self._AssertInputOpsAre(folded_mul, [ + scope + '/weights/read', + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) + ]) + self._AssertOutputGoesToOps(folded_mul, g, [scope + '/Conv2D_Fold']) + + folded_conv = g.get_operation_by_name(scope + '/Conv2D_Fold') + self.assertEqual(folded_conv.type, 'Conv2D') + self._AssertInputOpsAre(folded_conv, [scope + '/mul_fold', inputs.op.name]) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + + folded_add = g.get_operation_by_name(scope + '/add_fold') + self.assertEqual(folded_add.type, 'Add') + self._AssertInputOpsAre(folded_add, [ + scope + '/Conv2D_Fold', + self._BathNormBiasName(scope, fused_batch_norm) + ]) + output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] + self._AssertOutputGoesToOps(folded_add, g, output_op_names) + + def testFoldConv2dUnknownShape(self): + self._RunTestOverParameters(self._TestFoldConv2dUnknownShape) + + def _TestFoldFullyConnectedLayer(self, relu, relu_op_name, with_bypass, + has_scaling, fused_batch_norm): + """Tests folding cases: inputs -> FC with batch norm -> Relu*. + + Args: + relu: Callable that returns an Operation, a factory method for the Relu*. + relu_op_name: String, name of the Relu* operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. + """ + g = ops.Graph() + with g.as_default(): + batch_size, depth = 5, 256 + inputs = array_ops.zeros((batch_size, depth)) + out_depth = 256 if with_bypass else 128 + activation_fn = None if with_bypass else relu + scope = 'test/test2' if with_bypass else 'test' + node = fully_connected( + inputs, + out_depth, + weights_initializer=self._WeightInit(0.03), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + relu(node, name='test/' + relu_op_name) + + fold_batch_norms.FoldBatchNorms(g) + + folded_mul = g.get_operation_by_name(scope + '/mul_fold') + self.assertEqual(folded_mul.type, 'Mul') + self._AssertInputOpsAre(folded_mul, [ + scope + '/weights/read', + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm) + ]) + self._AssertOutputGoesToOps(folded_mul, g, [scope + '/MatMul_Fold']) + + folded_conv = g.get_operation_by_name(scope + '/MatMul_Fold') + self.assertEqual(folded_conv.type, 'MatMul') + self._AssertInputOpsAre(folded_conv, + [scope + '/mul_fold', inputs.op.name]) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + + folded_add = g.get_operation_by_name(scope + '/add_fold') + self.assertEqual(folded_add.type, 'Add') + self._AssertInputOpsAre(folded_add, [ + scope + '/MatMul_Fold', + self._BathNormBiasName(scope, fused_batch_norm) + ]) + output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] + self._AssertOutputGoesToOps(folded_add, g, output_op_names) + + def testFoldFullyConnectedLayer(self): + self._RunTestOverParameters(self._TestFoldFullyConnectedLayer) + + def _TestFoldDepthwiseConv2d(self, relu, relu_op_name, with_bypass, + has_scaling, fused_batch_norm): + """Tests folding: inputs -> DepthwiseConv2d with batch norm -> Relu*. + + Args: + relu: Callable that returns an Operation, a factory method for the Relu*. + relu_op_name: String, name of the Relu* operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Relu*. + has_scaling: Bool, when true the batch norm has scaling. + fused_batch_norm: Bool, when true the batch norm is fused. + """ + g = ops.Graph() + with g.as_default(): + batch_size, height, width = 5, 128, 128 + inputs = array_ops.zeros((batch_size, height, width, 3)) + stride = 1 if with_bypass else 2 + activation_fn = None if with_bypass else relu + scope = 'test/test2' if with_bypass else 'test' + node = separable_conv2d( + inputs, + None, [5, 5], + stride=stride, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams( + scale=has_scaling, fused=fused_batch_norm), + scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + relu(node, name='test/' + relu_op_name) + + fold_batch_norms.FoldBatchNorms(g) + + folded_mul = g.get_operation_by_name(scope + '/mul_fold') + self.assertEqual(folded_mul.type, 'Mul') + self._AssertInputOpsAre(folded_mul, + [scope + '/depthwise_weights/read', + scope + '/scale_reshape']) + self._AssertOutputGoesToOps(folded_mul, g, [scope + '/depthwise_Fold']) + + scale_reshape = g.get_operation_by_name(scope + '/scale_reshape') + self.assertEqual(scale_reshape.type, 'Reshape') + self._AssertInputOpsAre(scale_reshape, [ + self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm), + scope + '/scale_reshape/shape' + ]) + self._AssertOutputGoesToOps(scale_reshape, g, [scope + '/mul_fold']) + + folded_conv = g.get_operation_by_name(scope + '/depthwise_Fold') + self.assertEqual(folded_conv.type, 'DepthwiseConv2dNative') + self._AssertInputOpsAre(folded_conv, + [scope + '/mul_fold', inputs.op.name]) + self._AssertOutputGoesToOps(folded_conv, g, [scope + '/add_fold']) + + folded_add = g.get_operation_by_name(scope + '/add_fold') + self.assertEqual(folded_add.type, 'Add') + self._AssertInputOpsAre(folded_add, [ + scope + '/depthwise_Fold', + self._BathNormBiasName(scope, fused_batch_norm) + ]) + output_op_names = ['test/Add' if with_bypass else 'test/' + relu_op_name] + self._AssertOutputGoesToOps(folded_add, g, output_op_names) + + def testFoldDepthwiseConv2d(self): + self._RunTestOverParameters(self._TestFoldDepthwiseConv2d) + + def _BatchNormParams(self, scale=True, fused=False): + return { + 'center': True, + 'scale': scale, + 'decay': 1.0 - 0.003, + 'fused': fused + } + + def _BatchNormMultiplierName(self, scope, has_scaling, fused): + if has_scaling: + if fused: + return scope + '/mul' + return scope + '/BatchNorm/batchnorm/mul' + return scope + '/BatchNorm/batchnorm/Rsqrt' + + def _BathNormBiasName(self, scope, fused): + if fused: + return scope + '/bias' + return scope + '/BatchNorm/batchnorm/sub' + + def _WeightInit(self, stddev): + """Returns a truncated normal variable initializer. + + Function is defined purely to shorten the name so that it stops wrapping. + + Args: + stddev: Standard deviation of normal variable. + + Returns: + An initializer that initializes with a truncated normal variable. + """ + return init_ops.truncated_normal_initializer(stddev=stddev) + + def _AssertInputOpsAre(self, op, in_op_names): + """Asserts that all inputs to op come from in_op_names (disregarding order). + + Args: + op: Operation to check inputs for. + in_op_names: List of strings, operations where all op's inputs should + come from. + """ + expected_inputs = [in_op_name + ':0' for in_op_name in in_op_names] + self.assertItemsEqual([t.name for t in op.inputs], expected_inputs) + + def _AssertOutputGoesToOps(self, op, graph, out_op_names): + """Asserts that outputs from op go to out_op_names (and perhaps others). + + Args: + op: Operation to check outputs for. + graph: Graph where output operations are located. + out_op_names: List of strings, operations where op's outputs should go. + """ + for out_op_name in out_op_names: + out_op = graph.get_operation_by_name(out_op_name) + self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs]) + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/quantize/python/graph_matcher.py b/tensorflow/contrib/quantize/python/graph_matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..e3581cc55905a0af7d0464bc0ec673d3ed7f0363 --- /dev/null +++ b/tensorflow/contrib/quantize/python/graph_matcher.py @@ -0,0 +1,200 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utilities that match patterns in a tf.Graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class OpTypePattern(object): + """A tree pattern that matches TF expressions with certain op types.""" + + def __init__(self, op_type, name=None, inputs=None): + """Initializes an OpTypePattern. + + Args: + op_type: string that specifies the allowed types of the root. It can be + (1) an op type, e.g. 'Conv2D', + (2) '*', i.e. wildcard, or + (3) multiple op types separated by '|', e.g., 'Relu|Relu6'. + We could use regex strings, which might be worthwhile when we have many + similar TF op types. + name: Optional string. The name of the pattern that can be looked up in + MatchResult. + inputs: Optional list of `OpTypePattern`s or strings that specify the + patterns for the inputs of a matching op. If None, this pattern accepts + any inputs of a matching op. + """ + self._op_type = op_type + self._name = name + if inputs is None: + inputs = [] + self._inputs = [ + input_pattern if isinstance(input_pattern, OpTypePattern) else + OpTypePattern(input_pattern) for input_pattern in inputs + ] + + @property + def op_type(self): + return self._op_type + + @property + def inputs(self): + return self._inputs + + @property + def name(self): + return self._name + + +class MatchResult(object): + r"""Encapsulates the result of a match done by GraphMatcher. + + MatchResult contains a map from OpTypePattern to the matching op and tensor. + When the matching op has multiple output tensors, the matching tensor is the + output tensor used by the matching op of the parent pattern. E.g., when we + match graph + + - + + / \y0 y1/ \ + x split z + | + y (nodes are ops; edges are going up) + + against add_pattern defined as + + y1_pattern = OpTypePattern('*') + z_pattern = OpTypePattern('*') + add_pattern = OpTypePattern('+', inputs=[y1_pattern, z_pattern]) + + the matching op of `y1_pattern` is `split`, and the matching tensor of + `y1_pattern` + is `y1` not `y0`. + """ + + def __init__(self): + self._pattern_to_op_tensor = {} + self._name_to_pattern = {} + + def add(self, pattern, op, tensor): + self._pattern_to_op_tensor[pattern] = op, tensor + if pattern.name is not None: + if pattern.name in self._name_to_pattern: + raise ValueError( + 'Name %s is already bound to another pattern' % pattern.name) + self._name_to_pattern[pattern.name] = pattern + + def _to_pattern(self, pattern_or_name): + if isinstance(pattern_or_name, OpTypePattern): + return pattern_or_name + + if isinstance(pattern_or_name, str): + return self._name_to_pattern[pattern_or_name] + + raise ValueError('pattern_or_name has type %s. Expect OpTypePattern or str.' + % type(pattern_or_name)) + + def get_op(self, pattern_or_name): + return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][0] + + def get_tensor(self, pattern_or_name): + return self._pattern_to_op_tensor[self._to_pattern(pattern_or_name)][1] + + +class GraphMatcher(object): + """Checks if a particular subgraph matches a given pattern.""" + + def __init__(self, pattern): + """Initializes a GraphMatcher. + + Args: + pattern: The `OpTypePattern` against which `GraphMatcher` matches + subgraphs. + """ + self._pattern = pattern + + def _match_pattern(self, pattern, op, tensor): + """Returns whether an TF expression rooted at `op` matches `pattern`. + + If there is a match, adds to `self._match_result` the matching op and tensor + with key `pattern`. + + Args: + pattern: An `OpTypePattern`. + op: A `tf.Operation` to match against the pattern. + tensor: the output `tf.Tensor` of `op` that is used by the matching op of + `pattern`'s parent. Can be None if `pattern` is already the root of the + pattern tree. + + Returns: + True if an TF expression rooted at `op` matches `pattern`. + """ + if pattern.op_type != '*': + if op.type not in pattern.op_type.split('|'): + return False + + self._match_result.add(pattern, op, tensor) + + if not pattern.inputs: + # If pattern.inputs is empty, skips the rest and accepts all the inputs. + return True + + return len(op.inputs) == len(pattern.inputs) and all([ + self._match_pattern(input_pattern, input_tensor.op, input_tensor) + for input_tensor, input_pattern in zip(op.inputs, pattern.inputs) + ]) + + def match_op(self, op): + """Matches `op` against `self._pattern`. + + Args: + op: `tf.Operation` to match against the pattern. + + Returns: + Returns a `MatchResult` if `op` matches the pattern; otherwise, returns + None. + """ + self._match_result = MatchResult() + if not self._match_pattern(self._pattern, op, tensor=None): + return None + return self._match_result + + def match_ops(self, ops): + """Matches each operation in `ops` against `self._pattern`. + + Args: + ops: collection of `tf.Operation` to match against the pattern. + + Yields: + `MatchResult` for each `tf.Operation` that matches the pattern. + """ + for op in ops: + match_result = self.match_op(op) + if match_result: + yield match_result + + def match_graph(self, graph): + """Matches each operation in `graph` against `self._pattern`. + + Args: + graph: `tf.Graph` containing operations to match. + + Yields: + `MatchResult` for each `tf.Operation` in `graph` that matches the pattern. + """ + # Python 3.3.2+ implements `yield from`, but for now: + for match_result in self.match_ops(graph.get_operations()): + yield match_result diff --git a/tensorflow/contrib/quantize/python/graph_matcher_test.py b/tensorflow/contrib/quantize/python/graph_matcher_test.py new file mode 100644 index 0000000000000000000000000000000000000000..e1572865e423e569ee3b280036c0e02b71b70648 --- /dev/null +++ b/tensorflow/contrib/quantize/python/graph_matcher_test.py @@ -0,0 +1,130 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for graph_matcher.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python import ops as contrib_ops +from tensorflow.contrib.layers.python.layers import initializers +from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.quantize.python import graph_matcher +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest + + +class GraphMatcherTest(test_util.TensorFlowTestCase): + + def test_conv_layer(self): + g = ops.Graph() + with g.as_default(): + inputs = array_ops.placeholder(dtypes.float32, shape=[8, 5, 5, 3]) + + with contrib_ops.arg_scope( + [layers.batch_norm], fused=True, is_training=True, trainable=True): + return layers.convolution( + inputs, + num_outputs=16, + kernel_size=3, + stride=1, + padding='VALID', + activation_fn=nn_ops.relu, + normalizer_fn=layers.batch_norm, + normalizer_params={}, + weights_initializer=initializers.xavier_initializer(), + weights_regularizer=None, + biases_initializer=init_ops.zeros_initializer(), + biases_regularizer=None, + reuse=None, + trainable=True, + scope=None) + + inputs_pattern = graph_matcher.OpTypePattern('*', name='inputs') + relu_pattern = graph_matcher.OpTypePattern( + 'Relu', + name='relu', + inputs=[ + graph_matcher.OpTypePattern( + 'FusedBatchNorm', + inputs=[ + graph_matcher.OpTypePattern( + 'Conv2D', inputs=[inputs_pattern, '*']), '*', '*', '*', + '*' + ]) + ]) + matcher = graph_matcher.GraphMatcher(relu_pattern) + match_results = list(matcher.match_graph(g)) + self.assertEqual(1, len(match_results)) + match_result = match_results[0] + self.assertEqual(match_result.get_tensor(inputs_pattern), inputs) + self.assertEqual(match_result.get_tensor('inputs'), inputs) + + def test_multiple_outputs(self): + # - + + # / \y0 y1/ \ + # x split z + # | + # y (nodes are ops; edges are going up) + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtypes.float32, shape=[1], name='x') + y = array_ops.placeholder(dtypes.float32, shape=[2], name='y') + y0, y1 = array_ops.split(y, num_or_size_splits=2, axis=0) + z = array_ops.placeholder(dtypes.float32, shape=[1], name='z') + math_ops.add(x, y0) + math_ops.subtract(y1, z) + + y1_pattern = graph_matcher.OpTypePattern('*') + minus_pattern = graph_matcher.OpTypePattern('Sub', inputs=[y1_pattern, '*']) + matcher = graph_matcher.GraphMatcher(minus_pattern) + + match_results = list(matcher.match_graph(g)) + self.assertEqual(1, len(match_results)) + match_result = match_results[0] + + self.assertEqual(y0.op, y1.op) + self.assertEqual(match_result.get_op(y1_pattern), y1.op) + self.assertEqual(match_result.get_tensor(y1_pattern), y1) + + def test_oneof_pattern(self): + # - + + # / \ / \ + # x y z + g = ops.Graph() + with g.as_default(): + x = array_ops.placeholder(dtypes.float32, shape=[], name='x') + y = array_ops.placeholder(dtypes.float32, shape=[], name='y') + z = array_ops.placeholder(dtypes.float32, shape=[], name='z') + plus = x + y + minus = y - z + + add_or_sub_pattern = graph_matcher.OpTypePattern( + 'Add|Sub', inputs=['*', '*']) + matcher = graph_matcher.GraphMatcher(add_or_sub_pattern) + self.assertEqual([ + match_result.get_op(add_or_sub_pattern) + for match_result in matcher.match_graph(g) + ], [plus.op, minus.op]) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/quantize/python/input_to_ops.py b/tensorflow/contrib/quantize/python/input_to_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..98755607771ff725023fdf1abbcad8e95e851e23 --- /dev/null +++ b/tensorflow/contrib/quantize/python/input_to_ops.py @@ -0,0 +1,61 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Logic to update a Tensorflow model graph with quantization operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +from tensorflow.contrib.quantize.python import common + + +class InputToOps(object): + """Holds a mapping from tensor's name to ops that take it as input.""" + + def __init__(self, graph): + """Initializes mapping from tensor's name to ops that take it. + + Helps find edges between ops faster and avoids iterating over the whole + graph. The mapping is of type Dict[str, Set[tf.Operation]]. + + Note: while inserting operations into the graph, we do not update the + mapping, assuming that insertion points in the graph are never adjacent. + With that restriction, an out of date mapping still works fine. + + Args: + graph: Graph to process. + """ + self.mapping = collections.defaultdict(set) + for op in (op for op in graph.get_operations()): + if op.name.startswith(common.SKIPPED_PREFIXES): + continue + for op_input in op.inputs: + self.mapping[op_input].add(op) + + def ConsumerOperations(self, producer_op): + """Looks through outputs of producer_op, finds ops that take them as input. + + Args: + producer_op: Operation containing outputs to process. + + Returns: + A Set[Operation] containing all operations taking input from producer_op + outputs. + """ + result = set() + for inp in producer_op.outputs: + result.update(self.mapping[inp]) + return result diff --git a/tensorflow/contrib/quantize/python/input_to_ops_test.py b/tensorflow/contrib/quantize/python/input_to_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..9dbd1eb711831558b94a2c5793311d5c3e85963e --- /dev/null +++ b/tensorflow/contrib/quantize/python/input_to_ops_test.py @@ -0,0 +1,68 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for InputToOps class.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.quantize.python import input_to_ops +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest + + +class InputToOpsTest(test_util.TensorFlowTestCase): + + def testNoConsumerOperations(self): + graph = ops.Graph() + with graph.as_default(): + input_tensor = array_ops.zeros((1, 2, 3, 4)) + + input_to_ops_map = input_to_ops.InputToOps(graph) + consumer_operations = input_to_ops_map.ConsumerOperations(input_tensor.op) + + self.assertEqual(0, len(consumer_operations)) + + def testOneConsumerOperation(self): + graph = ops.Graph() + with graph.as_default(): + input_tensor = array_ops.zeros((1, 2, 3, 4)) + output_tensor = nn_ops.relu6(input_tensor) + + input_to_ops_map = input_to_ops.InputToOps(graph) + consumer_operations = input_to_ops_map.ConsumerOperations(input_tensor.op) + + self.assertEqual(consumer_operations, {output_tensor.op}) + + def testSeveralConsumerOperations(self): + graph = ops.Graph() + with graph.as_default(): + input_tensor = array_ops.zeros((1, 2, 3, 4)) + output_tensor_1 = nn_ops.relu6(input_tensor) + output_tensor_2 = input_tensor + output_tensor_1 + output_tensor_3 = input_tensor * output_tensor_2 + + input_to_ops_map = input_to_ops.InputToOps(graph) + consumer_operations = input_to_ops_map.ConsumerOperations(input_tensor.op) + + self.assertEqual(consumer_operations, + {output_tensor_1.op, output_tensor_2.op, + output_tensor_3.op}) + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0a38ef9fcd6f1699b0feee6d439ba69413e0899b --- /dev/null +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -0,0 +1,320 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Python support for quantization operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python.ops import add_arg_scope +from tensorflow.contrib.framework.python.ops import model_variable +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.training import moving_averages + +EPSILON = 1e-5 + + +@add_arg_scope +def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None): + """Adds a fake quantize layer with fixed quantization interval. + + Args: + inputs: a tensor containing values to be quantized. + init_min: the lower end of quantization interval. + init_max: the upper end of quantization interval. + scope: Optional scope for name_scope. + Returns: + a tensor containing quantized values. + """ + with ops.name_scope(scope, 'FixedQuantize', values=[inputs]): + return array_ops.fake_quant_with_min_max_args( + inputs, min=init_min, max=init_max) + + +@add_arg_scope +def LastValueQuantize(inputs, + per_channel=False, + init_min=-6.0, + init_max=6.0, + updates_collection=ops.GraphKeys.UPDATE_OPS, + vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, + scope=None, + reuse=None, + is_training=True, + num_bits=8, + narrow_range=False): + """Adds a layer that collects quantization ranges as last input ranges. + + LastValueQuantize creates variables called 'min' and 'max', representing the + interval used for quantization and clamping. + + Args: + inputs: a tensor containing values to be quantized. + per_channel: (Optional) a boolean specifying whether to use different + quantization ranges per output channel. + init_min: a float scalar, the initial value for variable min. + init_max: a float scalar, the initial value for variable max. + updates_collection: (Optional) collections to collect the update ops for + computation. + vars_collection: (Optional) collection where to store variables for + quantization interval ends. + scope: Optional scope for variable_scope. + reuse: whether or not the layer and its variables should be reused. To be + able to reuse the layer scope must be given. + is_training: Whether the op is applied to a training or eval graph. + num_bits: Number of bits to use for quantization, must be between 2 and 8. + narrow_range: Whether to use the narrow quantization range + [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1]. + Returns: + a tensor containing quantized values. + """ + with variable_scope.variable_scope( + scope, 'LastValueQuantize', values=[inputs], reuse=reuse): + input_shape = inputs.get_shape() + input_dim = len(input_shape) + if per_channel: + # Only support quantizing 1-, 2- and 4-dimensional tensors. + assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in ' + ' scope: %s' % (input_shape, scope)) + min_max_shape = [input_shape[-1]] + else: + min_max_shape = [] + + min_var = model_variable( + 'min', + shape=min_max_shape, + initializer=init_ops.constant_initializer(init_min), + collections=[vars_collection], + trainable=False) + max_var = model_variable( + 'max', + shape=min_max_shape, + initializer=init_ops.constant_initializer(init_max), + collections=[vars_collection], + trainable=False) + if not is_training: + return _FakeQuantWithMinMaxVars( + inputs, + min_var, + max_var, + per_channel=per_channel, + num_bits=num_bits, + narrow_range=narrow_range) + + if per_channel: + if input_dim == 2: + reduce_dims = [0] + elif input_dim == 4: + reduce_dims = [0, 1, 2] + + if per_channel: + if input_dim >= 2: + batch_min = math_ops.reduce_min( + inputs, reduction_indices=reduce_dims, name='BatchMin') + else: + batch_min = inputs + else: + batch_min = math_ops.reduce_min(inputs, name='BatchMin') + batch_min -= EPSILON + # B-eng requires that 0.0 if always in the [min; max] range. + batch_min = math_ops.minimum(batch_min, 0.0) + assign_min_op = state_ops.assign( + min_var, batch_min, name='AssignMinLast').op + ops.add_to_collection(updates_collection, assign_min_op) + + if per_channel: + if input_dim >= 2: + batch_max = math_ops.reduce_max( + inputs, reduction_indices=reduce_dims, name='BatchMax') + else: + batch_max = inputs + else: + batch_max = math_ops.reduce_max(inputs, name='BatchMax') + batch_max += EPSILON + # B-eng requires that 0.0 if always in the [min; max] range. + batch_max = math_ops.maximum(batch_max, 0.0) + assign_max_op = state_ops.assign( + max_var, batch_max, name='AssignMaxLast').op + ops.add_to_collection(updates_collection, assign_max_op) + + return _FakeQuantWithMinMaxVars( + inputs, + batch_min, + batch_max, + per_channel=per_channel, + num_bits=num_bits, + narrow_range=narrow_range) + + +@add_arg_scope +def MovingAvgQuantize(inputs, + per_channel=False, + init_min=-6.0, + init_max=6.0, + ema_decay=0.999, + updates_collection=ops.GraphKeys.UPDATE_OPS, + vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, + scope=None, + reuse=None, + is_training=True, + num_bits=8, + narrow_range=False): + """Adds a layer that collects quantization ranges as EMAs of input ranges. + + MovingAvgQuantize creates variables called 'min' and 'max', representing the + interval used for quantization and clamping. + + Args: + inputs: a tensor containing values to be quantized. + per_channel: (default False) a boolean specifying whether to use different + quantization ranges per output channel. + init_min: a float scalar, the initial value for variable min. + init_max: a float scalar, the initial value for variable max. + ema_decay: EMA decay parameter. + updates_collection: (Optional) collections to collect the update ops for + computation. + vars_collection: (Optional) collection where to store variables for + quantization interval ends. + scope: Optional scope for variable_scope. + reuse: whether or not the layer and its variables should be reused. To be + able to reuse the layer scope must be given. + is_training: Whether the op is applied to a training or eval graph. + num_bits: Number of bits to use for quantization, must be between 2 and 8. + narrow_range: Whether to use the narrow quantization range + [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1]. + Returns: + a tensor containing quantized values. + """ + with variable_scope.variable_scope( + scope, 'MovingAvgQuantize', values=[inputs], reuse=reuse): + input_shape = inputs.get_shape() + input_dim = len(input_shape) + if per_channel: + # Only support quantizing 1-, 2- and 4-dimensional tensors. + assert input_dim in [1, 2, 4], ('Expected 1D, 2D or 4D input, was: %s in ' + ' scope: %s' % (input_shape, scope)) + min_max_shape = [input_shape[-1]] + else: + min_max_shape = [] + + min_var = model_variable( + 'min', + shape=min_max_shape, + initializer=init_ops.constant_initializer(init_min), + collections=[vars_collection], + trainable=False) + max_var = model_variable( + 'max', + shape=min_max_shape, + initializer=init_ops.constant_initializer(init_max), + collections=[vars_collection], + trainable=False) + if not is_training: + return _FakeQuantWithMinMaxVars( + inputs, + min_var, + max_var, + per_channel=per_channel, + num_bits=num_bits, + narrow_range=narrow_range) + if per_channel: + if input_dim == 2: + reduce_dims = [0] + elif input_dim == 4: + reduce_dims = [0, 1, 2] + + if per_channel: + if input_dim >= 2: + batch_min = math_ops.reduce_min( + inputs, reduction_indices=reduce_dims, name='BatchMin') + else: + batch_min = inputs + else: + batch_min = math_ops.reduce_min(inputs, name='BatchMin') + # B-eng requires that 0.0 if always in the [min; max] range. + batch_min = math_ops.minimum(batch_min, 0.0) + assign_min_op = moving_averages.assign_moving_average( + min_var, batch_min, ema_decay, name='AssignMinEma').op + ops.add_to_collection(updates_collection, assign_min_op) + + if per_channel: + if input_dim >= 2: + batch_max = math_ops.reduce_max( + inputs, reduction_indices=reduce_dims, name='BatchMax') + else: + batch_max = inputs + else: + batch_max = math_ops.reduce_max(inputs, name='BatchMax') + # B-eng requires that 0.0 if always in the [min; max] range. + batch_max = math_ops.maximum(batch_max, 0.0) + assign_max_op = moving_averages.assign_moving_average( + max_var, batch_max, ema_decay, name='AssignMaxEma').op + ops.add_to_collection(updates_collection, assign_max_op) + + return _FakeQuantWithMinMaxVars( + inputs, + min_var, + max_var, + per_channel=per_channel, + num_bits=num_bits, + narrow_range=narrow_range) + + +def _FakeQuantWithMinMaxVars(inputs, min_var, max_var, per_channel, num_bits, + narrow_range): + """Adds a fake quantization operation. + + Depending on value of per_channel, this operation may do global quantization + or per channel quantization. min_var and max_var should have corresponding + shapes: [1] when per_channel == False and [d] when per_channel == True. + + Args: + inputs: a tensor containing values to be quantized. + min_var: a variable containing quantization range lower end(s). + max_var: a variable containing quantization range lupper end(s). + per_channel: a boolean specifying whether to use per-channel quantizatioh. + num_bits: Number of bits to use for quantization, must be between 2 and 8. + narrow_range: Whether to use the narrow quantization range + [1; 2^num_bits - 1] or wide range [0; 2^num_bits - 1]. + Returns: + a tensor containing quantized values. + """ + + if per_channel: + assert len(min_var.get_shape()) == 1 + assert len(max_var.get_shape()) == 1 + with ops.control_dependencies([check_ops.assert_less(min_var, max_var)]): + return array_ops.fake_quant_with_min_max_vars_per_channel( + inputs, + min_var, + max_var, + num_bits=num_bits, + narrow_range=narrow_range) + else: + assert min_var.get_shape() == [] # pylint: disable=g-explicit-bool-comparison + assert max_var.get_shape() == [] # pylint: disable=g-explicit-bool-comparison + with ops.control_dependencies([check_ops.assert_less(min_var, max_var)]): + return array_ops.fake_quant_with_min_max_vars( + inputs, + min_var, + max_var, + num_bits=num_bits, + narrow_range=narrow_range) diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..548e33663e868e71b8b44aa0634b6ebb72e07640 --- /dev/null +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -0,0 +1,403 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Logic to update a Tensorflow model graph with quantization operations.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re +from tensorflow.contrib import graph_editor +from tensorflow.contrib.quantize.python import common +from tensorflow.contrib.quantize.python import input_to_ops +from tensorflow.contrib.quantize.python import quant_ops +from tensorflow.python.framework import ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.training import training_util + +# Operation types used to select operations of interest. +_QUANTIZABLE_TYPES = {'Conv2D', 'MatMul', 'DepthwiseConv2dNative'} + +# Custom key for storing and retrieving update ops used by quantizing nodes. +_UPDATE_QUANT_OPS = 'update_quant_ops' + + +def Quantize(graph, + weight_bits=8, + weight_narrow_range=False, + activation_bits=8, + ema_decay=0.999, + quant_delay=None, + vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, + is_training=True, + quantize_folded_weights_use_ema=False): + """Updates graph with quantization operations. + + Args: + graph: Graph to modify. + weight_bits: Number of bits to use for quantizing weights. + weight_narrow_range: Whether to use a more efficient narrow range for + weights quantization. With weight_narrow_range true, the range is + [1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1]. + activation_bits: Number of bits to use for quantizing activations. + ema_decay: (Optional) Float, EMA decay parameter. EMA is used to update + quantization intervals for quantizing activations (see here about EMA: + https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average). + quant_delay: (Optional, default None) Int, count of global steps for which + to delay quantization. This helps weights stabilize at the start of + training. + vars_collection: (Optional) Collection where to store the variables for + quantization interval ends. + is_training: (Optional) Whether quantizing training graph or eval graph. + quantize_folded_weights_use_ema: (Optional, default False) Whether to + quantize weights after batchnorm-folding with exponential average + quantization. + Raises: + ValueError: When quantization fails. + """ + context = _QuantizeContext(graph, weight_bits, weight_narrow_range, + activation_bits, ema_decay, quant_delay, + vars_collection, is_training, + quantize_folded_weights_use_ema) + + graph_ops = graph.get_operations() + + # Filter out backprop and summary related operations, leave only interesting + # op types. + def _IsInterestingOpWithWeights(op): + return (op.type in _QUANTIZABLE_TYPES and + not op.name.startswith(common.SKIPPED_PREFIXES)) + + for op in (op for op in graph_ops if _IsInterestingOpWithWeights(op)): + if op.name.endswith('/depthwise'): + # Separable convolution may consist of 2 convolution nodes. If so, skip + # .../depthwise and only quantize the top one. + separable_conv = context.GetOperationByNameDontThrow( + op.name[:-len('/depthwise')]) + if separable_conv and separable_conv.type == 'Conv2D': + continue + if op.type == 'Conv2D': + # Quantize add ops that come after Conv2D + add_context_re = re.search(r'^(.*)/[^/]+/', op.name) + if add_context_re is not None: + context.add_contexts.add(add_context_re.group(1)) + if not op.name.endswith('_Fold'): + folded_op = context.GetOperationByNameDontThrow(op.name + '_Fold') + # Do nothing if found, it will be quantized when it is iterated over. + if not folded_op: + context.QuantizeOpWithWeights(op, folded=False) + else: + context.QuantizeOpWithWeights(op, folded=True) + + context.QuantizeAddContexts() + + # Once all quantization ops have been inserted in the graph, collect update + # ops for their variables and modify the TF Slim update barrier (see + # https://www.tensorflow.org/code/tensorflow/contrib/slim/python/slim/learning.py) + # to depend on them. + try: + update_barrier = graph.get_operation_by_name('update_barrier') + except KeyError: + # In evaluation graph, this barrier may not exist. + return None + update_quant_ops = graph.get_collection_ref(_UPDATE_QUANT_OPS) + graph_editor.add_control_inputs(update_barrier, update_quant_ops) + + +class _QuantizeContext(object): + """Context holds references needed for quantization.""" + + def __init__(self, + graph, + weight_bits, + weight_narrow_range, + activation_bits, + ema_decay=0.999, + quant_delay=None, + vars_collection=ops.GraphKeys.MOVING_AVERAGE_VARIABLES, + is_training=True, + quantize_folded_weights_use_ema=False): + """Initializes context to hold references needed for quantization. + + Args: + graph: Graph to modify. + weight_bits: Number of bits to use for quantizing weights. + weight_narrow_range: Whether to use a more efficient narrow range for + weights quantization. With weight_narrow_range true, the range is + [1; 2^weight_bits - 1], with it false [0; 2^weight_bits - 1]. + activation_bits: Number of bits to use for quantizing activations. + ema_decay: (Optional) Float, EMA decay parameter. + quant_delay: (Optional, default None) Int, count of global steps for which + to delay quantization. This helps weights stabilize at the start of + training. + vars_collection: (Optional) Collection where to store the variables for + quantization interval ends. + is_training: (Optional) Whether quantizing training or eval graph. + quantize_folded_weights_use_ema: (Optional, default False) Whether to + quantize weights after batchnorm-folding with exponential average + quantization. + """ + self.graph = graph + self.weight_bits = weight_bits + self.weight_narrow_range = weight_narrow_range + self.activation_bits = activation_bits + self.ema_decay = ema_decay + self.quant_delay = quant_delay + self.vars_collection = vars_collection + self.is_training = is_training + self.quantize_folded_weights_use_ema = quantize_folded_weights_use_ema + self.input_to_ops_map = input_to_ops.InputToOps(graph) + self.add_contexts = set() + + def QuantizeAddContexts(self): + """Quantizes all add ops in self.add_contexts.""" + for add_context in self.add_contexts: + add_op = self.GetOperationByNamesDontThrow([ + add_context + '/Add', add_context + '/add']) + if add_op is not None: + self._InsertQuantOp( + add_context, + add_op, + self.input_to_ops_map.ConsumerOperations(add_op), + name='add_quant', + moving_avg=True, + bits=self.activation_bits, + narrow_range=False) + + def QuantizeOpWithWeights(self, op, folded): + """Quantizes around the specific operation with or without batch norm. + + Args: + op: Operation to quantize. + folded: Operation has been folded and needs special handling if True. + Raises: + ValueError: When quantization fails. + """ + # Op name component before the last slash will be used as context. + context = re.search(r'^(.*)/([^/]+)', op.name).group(1) + + # Quantize weights. + if folded: + producer_op = self.graph.get_operation_by_name(context + '/mul_fold') + else: + try: + input_idx = next(i for i, v in enumerate(op.inputs) + if '/weights/' in v.name or + '/depthwise_weights' in v.name) + except StopIteration: + raise ValueError('No inputs to quantize for op: %s' % op) + producer_op = op.inputs[input_idx].op + + # If batch norm is used, the folded weights depend on the batch std, hence + # it is sensible to use EMA during training to smooth out the noise. This is + # controlled by the flag quantize_folded_weights_use_ema. Its default is + # False for backward compatibility. + # If there is no batch norm, weights do not depend on the batch and using + # the latest value of min and max is more efficient. + weight_use_ema = folded and self.quantize_folded_weights_use_ema + self._InsertQuantOp( + context, + producer_op, [op], + name='weights_quant', + moving_avg=weight_use_ema, + delay_requested=weight_use_ema, + bits=self.weight_bits, + narrow_range=self.weight_narrow_range) + + # Important: do not quantize biases here. During inference they are + # quantized to 32 bits, which is much finer than 8 bit quantization and + # depends on weight and input activation ranges. + + # Find activation and (optionally) Add operations to quantize. + activation_op, add_op, add_context = self._GetReluAndAddOperations(context, + op) + if add_op: + original_context = context + context = add_context + + # Quantize activation outputs. + consumer_ops = self.input_to_ops_map.ConsumerOperations(activation_op) + self._InsertQuantOp( + context, + activation_op, + consumer_ops, + name='act_quant', + moving_avg=True, + init_min=0.0, + bits=self.activation_bits, + narrow_range=False) + + # When a bypass connection was found, also quantize Add op input. + if add_op: + def _QuantizeAddInput(add_input): + if folded: + return add_input.op.name.endswith('/add_fold') + else: + return add_input.op.name.startswith(original_context + '/') + + for add_input in add_op.inputs: + if _QuantizeAddInput(add_input): + self._InsertQuantOp( + original_context, + add_input.op, [add_op], + name='conv_quant', + moving_avg=True, + bits=self.activation_bits, + narrow_range=False) + + def _GetReluAndAddOperations(self, context, op): + """Looks up a Relu* and Add operations in given context. + + Args: + context: Context where to look for operations. + op: Operation to quantize. + + Returns: + A triplet (Operation, Operation, string), the first element is an end + point operation, the second is Add operation (optional), the third element + is string context where the Add operation was found (optional). + + Raises: + ValueError: When operations cannot be found. + """ + activation_op = common.GetEndpointActivationOp(self.graph, context) + if activation_op: + return activation_op, None, None + + if '/' in context: + # If no activation op is there, look for them one level up. + add_context = re.search(r'^(.*)/([^/]+)', context).group(1) + activation_op = common.GetEndpointActivationOp(self.graph, add_context) + if not activation_op: + # Still no Relu, can happen on the top layer, just find the next node up, + # make sure it is BiasAdd. + consumers = [c for outp in op.outputs for c in outp.consumers()] + if len(consumers) != 1 or consumers[0].type != 'BiasAdd': + raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type)) + return consumers[0], None, None + if add_context: + add_op = self.GetOperationByNamesDontThrow([ + add_context + '/Add', add_context + '/add']) + return activation_op, add_op, add_context + else: + raise ValueError('Failed to quantize op: %s, %s' % (op.name, op.type)) + + def GetOperationByNameDontThrow(self, name): + """Returns an Operation with the given name. + + Args: + name: Name of Operation to return. + + Returns: + The Operation with the given name. None if the name does not correspond to + any operation in the graph. + """ + try: + return self.graph.get_operation_by_name(name) + except KeyError: + return None + + def GetOperationByNamesDontThrow(self, names): + """Returns an Operation with one of the given names. + + Args: + names: Names of Operation to return. + + Returns: + The Operation with one of the given names. None if none of the names + corresponds to any operation in the graph. + """ + for name in names: + op = self.GetOperationByNameDontThrow(name) + if op is not None: + return op + return None + + def _InsertQuantOp( + self, + context, + producer, + consumers, + name, + moving_avg=True, + init_min=-6.0, + init_max=6.0, + delay_requested=True, + bits=8, + narrow_range=False,): + """Inserts a quant op between a producer op and (multiple) consumer ops. + + Args: + context: Context where producer and consumer operations are nested. + producer: Producer operation of the pairs where quantization will be + inserted. + consumers: Consumer operations of the pairs. + name: Name for the new quantization op within the context. + moving_avg: Specifies whether to use exponential moving average or just + the last value seen. + init_min: Starting minimum value for the new quantization op. + init_max: Starting maximum value for the new quantization op. + delay_requested: If true, implement quantization delay where needed. + False value explicitly disables delay quantization everywhere. + bits: Number of bits to use for quantization, must be between 2 and 8. + narrow_range: Whether to use the narrow quantization range + [1; 2^bits - 1] or wide range [0; 2^bits - 1]. + Raises: + ValueError: When producer operation is not directly connected to the + consumer operation. + """ + scope = context + '/' + name + inputs = producer.outputs[0] + if moving_avg: + quant = (quant_ops.MovingAvgQuantize( + inputs, + init_min=init_min, + init_max=init_max, + ema_decay=self.ema_decay, + is_training=self.is_training, + num_bits=bits, + narrow_range=narrow_range, + updates_collection=_UPDATE_QUANT_OPS, + vars_collection=self.vars_collection, + scope=scope)) + else: + quant = (quant_ops.LastValueQuantize( + inputs, + init_min=init_min, + init_max=init_max, + is_training=self.is_training, + num_bits=bits, + narrow_range=narrow_range, + updates_collection=_UPDATE_QUANT_OPS, + vars_collection=self.vars_collection, + scope=scope)) + + if delay_requested and self.quant_delay and self.quant_delay > 0: + activate_quant = math_ops.greater_equal( + training_util.get_global_step(), + self.quant_delay, + name=scope + '/activate_quant') + quant = control_flow_ops.cond( + activate_quant, + lambda: quant, + lambda: inputs, + name=scope + '/delayed_quant') + + nodes_modified_count = graph_editor.reroute_ts( + [quant], [inputs], can_modify=consumers) + if nodes_modified_count != len(consumers): + raise ValueError('Some inputs not quantized for ops: [%s]' % + ', '.join([consumer.name for consumer in consumers])) diff --git a/tensorflow/contrib/quantize/python/quantize_graph.py b/tensorflow/contrib/quantize/python/quantize_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..d647bb94e849c713c2aca93c53f372bae5857c43 --- /dev/null +++ b/tensorflow/contrib/quantize/python/quantize_graph.py @@ -0,0 +1,130 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""API to simulate quantization on a python graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.quantize.python import copy_graph +from tensorflow.contrib.quantize.python import fold_batch_norms +from tensorflow.contrib.quantize.python import quantize +from tensorflow.python.framework import ops +from tensorflow.python.ops import variables + + +def _create_graph(input_graph, + is_training, + elements=None, + device_name_or_function=None): + """Returns a transformed training input_graph for simulated quantization. + + The forward pass has fake quantization ops inserted to simulate the error + introduced by quantization. + + Args: + input_graph: The tf.Graph to be transformed. + is_training: Whether quantizing training or eval graph. + elements: (Optional) List of Tensors and Operations in input_graph whose + corresponding elements in the new graph will be returned. + device_name_or_function: (Optional) The device name or function to use. + + Returns: + g is new tf.Graph that is rewritten for simulated quantization. + l is a list of Tensors/Operations in g corresponding to the provided input + elements, if elements is not None. + + Raises: + ValueError: If elements contains an element that isn't a tf.Tensor or + tf.Operation. + """ + # TODO(suharshs): Describe the process in more detail in the doc string. + g = copy_graph.CopyGraph(input_graph) + with g.as_default(): + with ops.device(device_name_or_function): + fold_batch_norms.FoldBatchNorms(g) + quantize.Quantize(g, is_training=is_training) + if elements is None: + return g + + return_elements = [] + for element in elements: + if isinstance(element, (ops.Tensor, variables.Variable)): + return_elements.append(g.get_tensor_by_name(element.name)) + elif isinstance(element, ops.Operation): + return_elements.append(g.get_operation_by_name(element.name)) + else: + raise ValueError( + 'elements must consist of Tensor or Operation objects, got: ', + str(element)) + return g, return_elements + + +def create_training_graph(input_graph, + elements=None, + device_name_or_function=None): + """Returns a transformed training input_graph for simulated quantization. + + The forward pass has fake quantization ops inserted to simulate the error + introduced by quantization. + + Args: + input_graph: The tf.Graph to be transformed. + elements: (Optional) List of Tensors and Operations in input_graph whose + corresponding elements in the new graph will be returned. + device_name_or_function: (Optional) The device name or function to use. + + Returns: + g is new tf.Graph that is rewritten for simulated quantization. + l is a list of Tensors/Operations in g corresponding to the provided input + elements, if elements is not None. + + Raises: + ValueError: If elements contains an element that isn't a tf.Tensor or + tf.Operation. + """ + return _create_graph( + input_graph=input_graph, + is_training=True, + elements=elements, + device_name_or_function=device_name_or_function) + + +def create_eval_graph(input_graph, elements=None, device_name_or_function=None): + """Returns a transformed eval input_graph for simulated quantization. + + The forward pass has fake quantization ops inserted to simulate the error + introduced by quantization. + + Args: + input_graph: The tf.Graph to be transformed. + elements: (Optional) List of Tensors and Operations in input_graph whose + corresponding elements in the new graph will be returned. + device_name_or_function: (Optional) The device name or function to use. + + Returns: + g is new tf.Graph that is rewritten for simulated quantization. + l is a list of Tensors/Operations in g corresponding to the provided input + elements, if elements is not None. + + Raises: + ValueError: If elements contains an element that isn't a tf.Tensor or + tf.Operation. + """ + return _create_graph( + input_graph=input_graph, + is_training=False, + elements=elements, + device_name_or_function=device_name_or_function) diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3407ace3914fe2de2506a2952ea5d1bf19028bb9 --- /dev/null +++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py @@ -0,0 +1,141 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for the quantize_graph graph rewriting API.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.quantize.python import quantize_graph +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import googletest + + +class QuantizeGraphTest(test_util.TensorFlowTestCase): + + # We have a lot of other tests that test the details of the rewrite, here we + # just the specific features of the quantize_graph API. + def testReturnedElementsTraining(self): + self._TestReturnElements(True) + + def testReturnedElementsEval(self): + self._TestReturnElements(False) + + def _TestReturnElements(self, is_training): + graph = ops.Graph() + with graph.as_default(): + a = constant_op.constant(1.0) + b = variables.Variable(2.0) + c = a + b + elements = [a, b, c.op] + if is_training: + q_graph, returned_elements = quantize_graph.create_training_graph( + graph, elements=elements) + else: + q_graph, returned_elements = quantize_graph.create_eval_graph( + graph, elements=elements) + # Make sure q_graph is different from graph. + self.assertTrue(graph != q_graph) + # Check that the returned elements are part of the new graph. + for returned_element in returned_elements: + self.assertEqual(q_graph, returned_element.graph) + # Check that the elements match with the one from the input graph. + for element, returned_element in zip(elements, returned_elements): + self.assertEqual(element.name, returned_element.name) + + def testNoReturnElementsTraining(self): + self._TestNoReturnElements(True) + + def testNoReturnElementsEval(self): + self._TestNoReturnElements(False) + + def _TestNoReturnElements(self, is_training): + graph = ops.Graph() + with graph.as_default(): + a = constant_op.constant(1.0) + b = variables.Variable(2.0) + _ = a + b + if is_training: + q_graph = quantize_graph.create_training_graph(graph) + else: + q_graph = quantize_graph.create_eval_graph(graph) + # Check that quantize_graph didn't return a tuple when elements isn't + # provided. + self.assertTrue(isinstance(q_graph, ops.Graph)) + # Make sure q_graph is different from graph. + self.assertTrue(graph != q_graph) + + def testDeviceNameTraining(self): + self._TestDeviceName(True) + + def testDeviceNameEval(self): + self._TestDeviceName(False) + + def _TestDeviceName(self, is_training): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + conv = layers.conv2d( + inputs, + 32, [5, 5], + stride=2, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + scope='test') + _ = nn_ops.relu6(conv) + + device_name = '/job:oink/task:0/device:CPU:0' + if is_training: + q_graph = quantize_graph.create_training_graph( + graph, device_name_or_function=device_name) + else: + q_graph = quantize_graph.create_eval_graph( + graph, device_name_or_function=device_name) + + orig_variable_names = set( + [v.name for v in graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)]) + q_variables = q_graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES) + # Ensure that variables were added. + self.assertTrue(len(orig_variable_names) < len(q_variables)) + # All added variables should have the specified device name. + for var in q_variables: + if var.name not in orig_variable_names: + self.assertEqual(var.device, device_name) + + def _WeightInit(self, stddev): + """Returns truncated normal variable initializer. + + Function is defined purely to shorten the name so that it stops wrapping. + + Args: + stddev: Standard deviation of normal variable. + + Returns: + An initialized that initialzes with a truncated normal variable. + """ + return init_ops.truncated_normal_initializer(stddev=stddev) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/quantize/python/quantize_parameterized_test.py b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3e62f95bd63db3134ba0b96c46b4a92aa73ebef9 --- /dev/null +++ b/tensorflow/contrib/quantize/python/quantize_parameterized_test.py @@ -0,0 +1,717 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Parameterized unit tests for quantizing a Tensorflow graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.quantize.python import fold_batch_norms +from tensorflow.contrib.quantize.python import quantize +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest +from tensorflow.python.training import training + +batch_norm = layers.batch_norm +conv2d = layers.conv2d +fully_connected = layers.fully_connected +separable_conv2d = layers.separable_conv2d + + +class QuantizeTest(test_util.TensorFlowTestCase): + + def _RunWithoutBatchNormTestOverParameters(self, test_fn): + # TODO(suharshs): Use parameterized test once OSS TF supports it. + parameters_list = [ + # (activation, activation_op_name, with_bypass, delay) + (nn_ops.relu6, 'Relu6', False, None), + (nn_ops.relu, 'Relu', False, None), + (array_ops.identity, 'Identity', False, None), + (nn_ops.relu6, 'Relu6', False, 5000), + (nn_ops.relu, 'Relu', False, 5000), + (array_ops.identity, 'Identity', False, 5000), + (nn_ops.relu6, 'Relu6', True, None), + (nn_ops.relu, 'Relu', True, None), + (array_ops.identity, 'Identity', True, None), + (nn_ops.relu6, 'Relu6', True, 5000), + (nn_ops.relu, 'Relu', True, 5000), + (array_ops.identity, 'Identity', True, 5000), + ] + for params in parameters_list: + test_fn(params[0], params[1], params[2], params[3]) + + def _TestQuantize_Conv2dWithoutBatchNorm(self, activation, activation_op_name, + with_bypass, delay): + """Tests quantization: inputs -> Conv2d no batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + """ + graph = ops.Graph() + with graph.as_default(): + training.create_global_step(graph) + + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + stride = 1 if with_bypass else 2 + out_depth = 3 if with_bypass else 32 + activation_fn = None if with_bypass else activation + scope = 'test/test2' if with_bypass else 'test' + node = conv2d(inputs, out_depth, [5, 5], stride=stride, padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + quantize.Quantize(graph, quant_delay=delay) + quantization_node_name = 'FakeQuantWithMinMaxVars' + weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + + quantization_node_name) + self.assertEqual(weights_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/weights_quant/Minimum', scope + '/weights_quant/Maximum', + scope + '/weights/read' + ] + self._AssertInputOpsAre(weights_quant, expected_inputs) + output_op_name = scope + '/Conv2D' + self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) + + if with_bypass: + conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + + quantization_node_name) + self.assertEqual(conv_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', + scope + '/BiasAdd' + ] + self._AssertInputOpsAre(conv_quant, expected_inputs) + output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' + if delay else 'test/Add') + self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) + + act_quant = graph.get_operation_by_name('test/act_quant/' + + quantization_node_name) + self.assertEqual(act_quant.type, quantization_node_name) + + expected_inputs = [ + 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/' + activation_op_name + ] + self._AssertInputOpsAre(act_quant, expected_inputs) + output_op_name = ('test/act_quant/delayed_quant/Switch_1' + if delay else 'control_dependency') + self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + + def testQuantize_Conv2dWithoutBatchNorm(self): + self._RunWithoutBatchNormTestOverParameters( + self._TestQuantize_Conv2dWithoutBatchNorm) + + def _TestQuantize_FCWithoutBatchNorm(self, activation, activation_op_name, + with_bypass, delay): + """Tests quantization: inputs -> FC no batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + """ + graph = ops.Graph() + with graph.as_default(): + training.create_global_step(graph) + + batch_size, depth = 5, 256 + inputs = array_ops.zeros((batch_size, depth)) + out_depth = 256 if with_bypass else 128 + activation_fn = None if with_bypass else activation + scope = 'test/test2' if with_bypass else 'test' + node = fully_connected(inputs, out_depth, + weights_initializer=self._WeightInit(0.03), + activation_fn=activation_fn, scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + quantize.Quantize(graph, quant_delay=delay) + + quantization_node_name = 'FakeQuantWithMinMaxVars' + weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + + quantization_node_name) + self.assertEqual(weights_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/weights_quant/Minimum', scope + '/weights_quant/Maximum', + scope + '/weights/read' + ] + self._AssertInputOpsAre(weights_quant, expected_inputs) + output_op_name = scope + '/MatMul' + self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) + + if with_bypass: + conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + + quantization_node_name) + self.assertEqual(conv_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', + scope + '/BiasAdd' + ] + self._AssertInputOpsAre(conv_quant, expected_inputs) + output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' + if delay else 'test/Add') + self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) + + act_quant = graph.get_operation_by_name('test/act_quant/' + + quantization_node_name) + self.assertEqual(act_quant.type, quantization_node_name) + expected_inputs = [ + 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/' + activation_op_name + ] + self._AssertInputOpsAre(act_quant, expected_inputs) + output_op_name = ('test/act_quant/delayed_quant/Switch_1' + if delay else 'control_dependency') + self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + + def testQuantize_FCWithoutBatchNorm(self): + self._RunWithoutBatchNormTestOverParameters( + self._TestQuantize_FCWithoutBatchNorm) + + def _TestQuantize_DepthwiseConv2dWithoutBatchNorm( + self, activation, activation_op_name, with_bypass, delay): + """Tests quantization: inputs -> DWConv2d no batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + """ + graph = ops.Graph() + with graph.as_default(): + training.create_global_step(graph) + + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + stride = 1 if with_bypass else 2 + activation_fn = None if with_bypass else activation + scope = 'test/test2' if with_bypass else 'test' + node = separable_conv2d(inputs, None, [5, 5], stride=stride, + depth_multiplier=1.0, padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=activation_fn, scope=scope) + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + node = activation(node, name='test/' + activation_op_name) + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + quantize.Quantize(graph, quant_delay=delay) + + quantization_node_name = 'FakeQuantWithMinMaxVars' + weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + + quantization_node_name) + self.assertEqual(weights_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/weights_quant/Minimum', scope + '/weights_quant/Maximum', + scope + '/depthwise_weights/read' + ] + self._AssertInputOpsAre(weights_quant, expected_inputs) + output_op_name = scope + '/depthwise' + self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) + + if with_bypass: + conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + + quantization_node_name) + self.assertEqual(conv_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', + scope + '/BiasAdd' + ] + self._AssertInputOpsAre(conv_quant, expected_inputs) + output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' + if delay else 'test/Add') + self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) + + act_quant = graph.get_operation_by_name('test/act_quant/' + + quantization_node_name) + self.assertEqual(act_quant.type, quantization_node_name) + expected_inputs = [ + 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/' + activation_op_name + ] + self._AssertInputOpsAre(act_quant, expected_inputs) + output_op_name = ('test/act_quant/delayed_quant/Switch_1' + if delay else 'control_dependency') + self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + + def testQuantize_DepthwiseConv2dWithoutBatchNorm(self): + self._RunWithoutBatchNormTestOverParameters( + self._TestQuantize_DepthwiseConv2dWithoutBatchNorm) + + def _RunBatchNormTestOverParameters(self, test_fn): + # TODO(suharshs): Use parameterized test once OSS TF supports it. + parameters_list = [ + # (activation, activation_op_name, with_bypass, delay, fused_batch_norm) + (nn_ops.relu6, 'Relu6', False, None, False), + (nn_ops.relu, 'Relu', False, None, False), + (array_ops.identity, 'Identity', False, None, False), + (nn_ops.relu6, 'Relu6', False, 5000, False), + (nn_ops.relu, 'Relu', False, 5000, False), + (array_ops.identity, 'Identity', False, 5000, False), + (nn_ops.relu6, 'Relu6', True, None, False), + (nn_ops.relu, 'Relu', True, None, False), + (array_ops.identity, 'Identity', True, None, False), + (nn_ops.relu6, 'Relu6', True, 5000, False), + (nn_ops.relu, 'Relu', True, 5000, False), + (array_ops.identity, 'Identity', True, 5000, False), + (nn_ops.relu6, 'Relu6', False, None, True), + (nn_ops.relu, 'Relu', False, None, True), + (array_ops.identity, 'Identity', False, None, True), + (nn_ops.relu6, 'Relu6', False, 5000, True), + (nn_ops.relu, 'Relu', False, 5000, True), + (array_ops.identity, 'Identity', False, 5000, True), + (nn_ops.relu6, 'Relu6', True, None, True), + (nn_ops.relu, 'Relu', True, None, True), + (array_ops.identity, 'Identity', True, None, True), + (nn_ops.relu6, 'Relu6', True, 5000, True), + (nn_ops.relu, 'Relu', True, 5000, True), + (array_ops.identity, 'Identity', True, 5000, True) + ] + for params in parameters_list: + test_fn(params[0], params[1], params[2], params[3], params[4]) + + def _TestQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, + with_bypass, delay, fused_batch_norm): + """Tests quantization: inputs -> Conv2d with batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. + """ + self._testQuantize_Conv2dWithBatchNorm( + activation, + activation_op_name, + with_bypass, + delay, + fused_batch_norm, + use_ema=True) + self._testQuantize_Conv2dWithBatchNorm( + activation, + activation_op_name, + with_bypass, + delay, + fused_batch_norm, + use_ema=False) + + def testQuantize_Conv2dWithBatchNorm(self): + self._RunBatchNormTestOverParameters(self._TestQuantize_Conv2dWithBatchNorm) + + def _testQuantize_Conv2dWithBatchNorm(self, activation, activation_op_name, + with_bypass, delay, fused_batch_norm, + use_ema): + """Tests quantization: inputs -> Conv2d with batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. + use_ema: Bool, when true uses EMA quantization for BN folded weights. + """ + graph = ops.Graph() + with graph.as_default(): + training.create_global_step(graph) + + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + stride = 1 if with_bypass else 2 + out_depth = 3 if with_bypass else 32 + scope = 'test/test2' if with_bypass else 'test' + node = conv2d( + inputs, + out_depth, [5, 5], + stride=stride, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams(fused_batch_norm), + scope=scope) + + # Manually add a bypass (optionaly) and an activation. + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + + node = activation(node, name='test/' + activation_op_name) + + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + fold_batch_norms.FoldBatchNorms(graph) + + quantize.Quantize( + graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) + + quantization_node_name = 'FakeQuantWithMinMaxVars' + weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + + quantization_node_name) + self.assertEqual(weights_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'), + scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'), + scope + '/mul_fold' + ] + self._AssertInputOpsAre(weights_quant, expected_inputs) + output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' + if (delay and use_ema) else '/Conv2D_Fold') + self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) + + if with_bypass: + conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + + quantization_node_name) + self.assertEqual(conv_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', + scope + '/add_fold' + ] + self._AssertInputOpsAre(conv_quant, expected_inputs) + output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' + if delay else 'test/Add') + self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) + + act_quant = graph.get_operation_by_name('test/act_quant/' + + quantization_node_name) + self.assertEqual(act_quant.type, quantization_node_name) + expected_inputs = [ + 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/' + activation_op_name + ] + self._AssertInputOpsAre(act_quant, expected_inputs) + output_op_name = ('test/act_quant/delayed_quant/Switch_1' + if delay else 'control_dependency') + self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + + def _TestQuantize_FCWithBatchNorm(self, activation, activation_op_name, + with_bypass, delay, fused_batch_norm): + """Tests quantization: inputs -> FC with batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. + """ + self._testQuantize_FCWithBatchNorm( + activation, + activation_op_name, + with_bypass, + delay, + fused_batch_norm, + use_ema=True) + self._testQuantize_FCWithBatchNorm( + activation, + activation_op_name, + with_bypass, + delay, + fused_batch_norm, + use_ema=False) + + def testQuantize_FCWithBatchNorm(self): + self._RunBatchNormTestOverParameters(self._TestQuantize_FCWithBatchNorm) + + def _testQuantize_FCWithBatchNorm(self, activation, activation_op_name, + with_bypass, delay, fused_batch_norm, + use_ema): + """Tests quantization: inputs -> FC with batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. + use_ema: Bool, when true uses EMA quantization for BN folded weights. + """ + graph = ops.Graph() + with graph.as_default(): + training.create_global_step(graph) + + batch_size, depth = 5, 256 + inputs = array_ops.zeros((batch_size, depth)) + out_depth = 256 if with_bypass else 128 + scope = 'test/test2' if with_bypass else 'test' + node = fully_connected( + inputs, + out_depth, + weights_initializer=self._WeightInit(0.03), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams(fused_batch_norm), + scope=scope) + + # Manually add a bypass (optionaly) and an activation. + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + + node = activation(node, name='test/' + activation_op_name) + + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + fold_batch_norms.FoldBatchNorms(graph) + + quantize.Quantize( + graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) + + quantization_node_name = 'FakeQuantWithMinMaxVars' + weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + + quantization_node_name) + self.assertEqual(weights_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'), + scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'), + scope + '/mul_fold' + ] + self._AssertInputOpsAre(weights_quant, expected_inputs) + output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' + if delay and use_ema else '/MatMul_Fold') + self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) + + if with_bypass: + conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + + quantization_node_name) + self.assertEqual(conv_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', + scope + '/add_fold' + ] + self._AssertInputOpsAre(conv_quant, expected_inputs) + output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' + if delay else 'test/Add') + self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) + + act_quant = graph.get_operation_by_name('test/act_quant/' + + quantization_node_name) + self.assertEqual(act_quant.type, quantization_node_name) + expected_inputs = [ + 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/' + activation_op_name + ] + self._AssertInputOpsAre(act_quant, expected_inputs) + output_op_name = ('test/act_quant/delayed_quant/Switch_1' + if delay else 'control_dependency') + self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + + def _TestQuantize_DepthwiseConv2dWithBatchNorm( + self, activation, activation_op_name, with_bypass, delay, + fused_batch_norm): + """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. + """ + self._testQuantize_DepthwiseConv2dWithBatchNorm( + activation, + activation_op_name, + with_bypass, + delay, + fused_batch_norm, + use_ema=True) + self._testQuantize_DepthwiseConv2dWithBatchNorm( + activation, + activation_op_name, + with_bypass, + delay, + fused_batch_norm, + use_ema=False) + + def testQuantize_DepthwiseConv2dWithBatchNorm(self): + self._RunBatchNormTestOverParameters( + self._TestQuantize_DepthwiseConv2dWithBatchNorm) + + def _testQuantize_DepthwiseConv2dWithBatchNorm( + self, activation, activation_op_name, with_bypass, delay, + fused_batch_norm, use_ema): + """Tests quantization: inputs -> DWConv2d with batch norm -> Activation. + + Args: + activation: Callable that returns an Operation, a factory method for the + Activation. + activation_op_name: String, name of the Activation operation. + with_bypass: Bool, when true there is an extra connection added from + inputs to just before Activation. + delay: Int (optional), delay in number of steps until quantization starts. + fused_batch_norm: Bool, when true use FusedBatchNorm. + use_ema: Bool, when true uses EMA quantization for BN folded weights. + """ + graph = ops.Graph() + with graph.as_default(): + training.create_global_step(graph) + + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + stride = 1 if with_bypass else 2 + scope = 'test/test2' if with_bypass else 'test' + node = separable_conv2d( + inputs, + None, [5, 5], + stride=stride, + depth_multiplier=1.0, + padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, + normalizer_fn=batch_norm, + normalizer_params=self._BatchNormParams(fused_batch_norm), + scope=scope) + + # Manually add a bypass (optionaly) and an activation. + if with_bypass: + node = math_ops.add(inputs, node, name='test/Add') + + node = activation(node, name='test/' + activation_op_name) + + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + fold_batch_norms.FoldBatchNorms(graph) + + quantize.Quantize( + graph, quant_delay=delay, quantize_folded_weights_use_ema=use_ema) + quantization_node_name = 'FakeQuantWithMinMaxVars' + weights_quant = graph.get_operation_by_name(scope + '/weights_quant/' + + quantization_node_name) + self.assertEqual(weights_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/weights_quant/' + ('min/read' if use_ema else 'Minimum'), + scope + '/weights_quant/' + ('max/read' if use_ema else 'Maximum'), + scope + '/mul_fold' + ] + self._AssertInputOpsAre(weights_quant, expected_inputs) + output_op_name = scope + ('/weights_quant/delayed_quant/Switch_1' + if delay and use_ema else '/depthwise_Fold') + self._AssertOutputGoesToOps(weights_quant, graph, [output_op_name]) + + if with_bypass: + conv_quant = graph.get_operation_by_name(scope + '/conv_quant/' + + quantization_node_name) + self.assertEqual(conv_quant.type, quantization_node_name) + expected_inputs = [ + scope + '/conv_quant/min/read', scope + '/conv_quant/max/read', + scope + '/add_fold' + ] + self._AssertInputOpsAre(conv_quant, expected_inputs) + output_op_name = (scope + '/conv_quant/delayed_quant/Switch_1' + if delay else 'test/Add') + self._AssertOutputGoesToOps(conv_quant, graph, [output_op_name]) + + act_quant = graph.get_operation_by_name('test/act_quant/' + + quantization_node_name) + self.assertEqual(act_quant.type, quantization_node_name) + expected_inputs = [ + 'test/act_quant/min/read', 'test/act_quant/max/read', + 'test/' + activation_op_name + ] + self._AssertInputOpsAre(act_quant, expected_inputs) + output_op_name = ('test/act_quant/delayed_quant/Switch_1' + if delay else 'control_dependency') + self._AssertOutputGoesToOps(act_quant, graph, [output_op_name]) + + def _BatchNormParams(self, fused=False): + return {'center': True, 'scale': True, 'decay': 1.0 - 0.003, 'fused': fused} + + def _WeightInit(self, stddev): + """Returns truncated normal variable initializer. + + Function is defined purely to shorten the name so that it stops wrapping. + + Args: + stddev: Standard deviation of normal variable. + + Returns: + An initialized that initialzes with a truncated normal variable. + """ + return init_ops.truncated_normal_initializer(stddev=stddev) + + def _AssertInputOpsAre(self, op, in_op_names): + """Asserts that all inputs to op come from in_op_names (disregarding order). + + Args: + op: Operation to check inputs for. + in_op_names: List of strings, operations where all op's inputs should + come from. + """ + expected_inputs = [in_op_name + ':0' for in_op_name in in_op_names] + self.assertItemsEqual([t.name for t in op.inputs], expected_inputs) + + def _AssertOutputGoesToOps(self, op, graph, out_op_names): + """Asserts that outputs from op go to out_op_names (and perhaps others). + + Args: + op: Operation to check outputs for. + graph: Graph where output operations are located. + out_op_names: List of strings, operations where op's outputs should go. + """ + for out_op_name in out_op_names: + out_op = graph.get_operation_by_name(out_op_name) + self.assertIn(op.outputs[0].name, [str(t.name) for t in out_op.inputs]) + + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py new file mode 100644 index 0000000000000000000000000000000000000000..eb141a21bd8eb21b5b7e56a393d6c8016b5b1e94 --- /dev/null +++ b/tensorflow/contrib/quantize/python/quantize_test.py @@ -0,0 +1,94 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Unit tests for quantizing a Tensorflow graph.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.quantize.python import quantize +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops +from tensorflow.python.platform import googletest + +conv2d = layers.conv2d + + +class QuantizeTest(test_util.TensorFlowTestCase): + + def testInsertQuantOpFailsWhenOpsNotConnected(self): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + inputs = array_ops.zeros((batch_size, height, width, depth)) + conv = conv2d(inputs, 32, [5, 5], stride=2, padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, scope='test') + relu = nn_ops.relu6(inputs) + + context = quantize._QuantizeContext(graph=graph, weight_bits=8, + weight_narrow_range=True, + activation_bits=8) + # Inserting a quantization op between two unconnected ops should fail with + # ValueError. + with self.assertRaises(ValueError) as err: + context._InsertQuantOp('test', conv.op, [relu.op], 'FailingQuantOp') + self.assertEqual( + str(err.exception), 'Some inputs not quantized for ops: [Relu6]') + + def testInsertQuantOpForAddAfterConv2d(self): + graph = ops.Graph() + with graph.as_default(): + batch_size, height, width, depth = 5, 128, 128, 3 + input1 = array_ops.zeros((batch_size, height, width, depth)) + input2 = array_ops.zeros((batch_size, height / 2, width / 2, 32)) + conv = conv2d(input1, 32, [5, 5], stride=2, padding='SAME', + weights_initializer=self._WeightInit(0.09), + activation_fn=None, scope='test/test') + node = math_ops.add(conv, input2, name='test/add') + node = array_ops.identity(node, name='test/identity') + update_barrier = control_flow_ops.no_op(name='update_barrier') + with ops.control_dependencies([update_barrier]): + array_ops.identity(node, name='control_dependency') + + quantize.Quantize(graph=graph, weight_bits=8, weight_narrow_range=True, + activation_bits=8) + + quantization_node_name = 'FakeQuantWithMinMaxVars' + add_quant = graph.get_operation_by_name('test/add_quant/' + + quantization_node_name) + self.assertEqual(add_quant.type, quantization_node_name) + + def _WeightInit(self, stddev): + """Returns truncated normal variable initializer. + + Function is defined purely to shorten the name so that it stops wrapping. + + Args: + stddev: Standard deviation of normal variable. + + Returns: + An initialized that initialzes with a truncated normal variable. + """ + return init_ops.truncated_normal_initializer(stddev=stddev) + +if __name__ == '__main__': + googletest.main() diff --git a/tensorflow/contrib/receptive_field/BUILD b/tensorflow/contrib/receptive_field/BUILD index ed2f3af08cbbd8ae5da2a87f4a7dd9854493c346..d16b2908a0285e04ef5d3ede2050bf24c508228d 100644 --- a/tensorflow/contrib/receptive_field/BUILD +++ b/tensorflow/contrib/receptive_field/BUILD @@ -39,7 +39,9 @@ py_library( deps = [ ":graph_compute_order_py", "//tensorflow/contrib/util:util_py", + "//tensorflow/python:framework_ops", "//tensorflow/python:platform", + "//third_party/py/numpy", ], ) @@ -49,12 +51,13 @@ py_test( srcs_version = "PY2AND3", deps = [ ":receptive_field_py", - "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/slim", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//tensorflow/python:nn", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field.py b/tensorflow/contrib/receptive_field/python/util/receptive_field.py index db190a1a41668bff3f6db1c674192980db068838..8b34465d21d14508c24056b588f2533d8fea6a1d 100644 --- a/tensorflow/contrib/receptive_field/python/util/receptive_field.py +++ b/tensorflow/contrib/receptive_field/python/util/receptive_field.py @@ -27,13 +27,15 @@ import math from tensorflow.contrib.receptive_field.python.util import graph_compute_order from tensorflow.contrib.util import make_ndarray from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.framework import ops as framework_ops +import numpy as np # White-listed layer operations, which do not affect the receptive field # computation. _UNCHANGED_RF_LAYER_OPS = [ - "Softplus", "Relu", "BiasAdd", "Mul", "Add", "Const", "Identity", - "VariableV2", "Sub", "Rsqrt", "ConcatV2" -] + 'Add', 'BiasAdd', 'Ceil', 'ConcatV2', 'Const', 'Floor', 'Identity', 'Log', + 'Mul', 'Pow', 'RealDiv', 'Relu', 'Round', 'Rsqrt', 'Softplus', 'Sub', + 'VariableV2'] # Different ways in which padding modes may be spelled. _VALID_PADDING = ["VALID", b"VALID"] @@ -238,7 +240,8 @@ def _get_layer_params(node, name_to_order_node): padding_x = 0 padding_y = 0 else: - raise ValueError("Unknown layer op: %s" % node.op) + raise ValueError("Unknown layer for operation '%s': %s" % + (node.name, node.op)) return kernel_size_x, kernel_size_y, stride_x, stride_y, padding_x, padding_y @@ -304,13 +307,103 @@ def _get_effective_padding_node_input(stride, padding, return stride * effective_padding_output + padding -def compute_receptive_field_from_graph_def(graph_def, input_node, output_node): - """Computes receptive field (RF) parameters from a GraphDef object. +class ReceptiveField: + """ + Receptive field of a convolutional neural network. + + Args: + size: Receptive field size. + stride: Effective stride. + padding: Effective padding. + """ + def __init__(self, size, stride, padding): + self.size = np.asarray(size) + self.stride = np.asarray(stride) + self.padding = np.asarray(padding) + + def compute_input_center_coordinates(self, y, axis=None): + """ + Computes the center of the receptive field that generated a feature. + + Args: + y: An array of feature coordinates with shape `(..., d)`, where `d` is the + number of dimensions of the coordinates. + axis: The dimensions for which to compute the input center coordinates. + If `None` (the default), compute the input center coordinates for all + dimensions. + + Returns: + x: Center of the receptive field that generated the features, at the input + of the network. + + Raises: + ValueError: If the number of dimensions of the feature coordinates does + not match the number of elements in `axis`. + """ + # Use all dimensions. + if axis is None: + axis = range(self.size.size) + # Ensure axis is a list because tuples have different indexing behavior. + axis = list(axis) + y = np.asarray(y) + if y.shape[-1] != len(axis): + raise ValueError("Dimensionality of the feature coordinates `y` (%d) " + "does not match dimensionality of `axis` (%d)" % + (y.shape[-1], len(axis))) + return - self.padding[axis] + y * self.stride[axis] + \ + (self.size[axis] - 1) / 2 + + def compute_feature_coordinates(self, x, axis=None): + """ + Computes the position of a feature given the center of a receptive field. + + Args: + x: An array of input center coordinates with shape `(..., d)`, where `d` + is the number of dimensions of the coordinates. + axis: The dimensions for which to compute the feature coordinates. + If `None` (the default), compute the feature coordinates for all + dimensions. + + Returns: + y: Coordinates of the features. + + Raises: + ValueError: If the number of dimensions of the input center coordinates + does not match the number of elements in `axis`. + """ + # Use all dimensions. + if axis is None: + axis = range(self.size.size) + # Ensure axis is a list because tuples have different indexing behavior. + axis = list(axis) + x = np.asarray(x) + if x.shape[-1] != len(axis): + raise ValueError("Dimensionality of the input center coordinates `x` " + "(%d) does not match dimensionality of `axis` (%d)" % + (x.shape[-1], len(axis))) + return (x + self.padding[axis] + (1 - self.size[axis]) / 2) / \ + self.stride[axis] + + def __iter__(self): + return iter(np.concatenate([self.size, self.stride, self.padding])) + + +def compute_receptive_field_from_graph_def(graph_def, input_node, output_node, + stop_propagation=None): + """Computes receptive field (RF) parameters from a Graph or GraphDef object. + + The algorithm stops the calculation of the receptive field whenever it + encounters an operation in the list `stop_propagation`. Stopping the + calculation early can be useful to calculate the receptive field of a + subgraph such as a single branch of the + [inception network](https://arxiv.org/abs/1512.00567). Args: - graph_def: GraphDef object. - input_node: Name of the input node from graph. - output_node: Name of the output node from graph. + graph_def: Graph or GraphDef object. + input_node: Name of the input node or Tensor object from graph. + output_node: Name of the output node or Tensor object from graph. + stop_propagation: List of operation or scope names for which to stop the + propagation of the receptive field. Returns: rf_size_x: Receptive field size of network in the horizontal direction, with @@ -331,6 +424,18 @@ def compute_receptive_field_from_graph_def(graph_def, input_node, output_node): cannot be found. For network criterion alignment, see photos/vision/features/delf/g3doc/rf_computation.md """ + # Convert a graph to graph_def if necessary. + if isinstance(graph_def, framework_ops.Graph): + graph_def = graph_def.as_graph_def() + + # Convert tensors to names. + if isinstance(input_node, framework_ops.Tensor): + input_node = input_node.op.name + if isinstance(output_node, framework_ops.Tensor): + output_node = output_node.op.name + + stop_propagation = stop_propagation or [] + # Computes order of computation for a given graph. name_to_order_node = graph_compute_order.get_compute_order( graph_def=graph_def) @@ -422,6 +527,10 @@ def compute_receptive_field_from_graph_def(graph_def, input_node, output_node): # Loop over this node's inputs and potentially propagate information down. for inp_name in node.input: + # Stop the propagation of the receptive field. + if any(inp_name.startswith(stop) for stop in stop_propagation): + logging.vlog(3, "Skipping explicitly ignored node %s.", node.name) + continue logging.vlog(4, "inp_name = %s", inp_name) inp_node = name_to_order_node[inp_name].node logging.vlog(4, "inp_node = \n%s", inp_node) @@ -480,6 +589,7 @@ def compute_receptive_field_from_graph_def(graph_def, input_node, output_node): raise ValueError("Output node was not found") if input_node not in rf_sizes_x: raise ValueError("Input node was not found") - return (rf_sizes_x[input_node], rf_sizes_y[input_node], - effective_strides_x[input_node], effective_strides_y[input_node], - effective_paddings_x[input_node], effective_paddings_y[input_node]) + return ReceptiveField( + (rf_sizes_x[input_node], rf_sizes_y[input_node]), + (effective_strides_x[input_node], effective_strides_y[input_node]), + (effective_paddings_x[input_node], effective_paddings_y[input_node])) diff --git a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py index 2771389250b1518f33ebadf3f1cfd23e653dab93..8d7d5440f630a3a78749e04a5eb058d637c258fc 100644 --- a/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py +++ b/tensorflow/contrib/receptive_field/python/util/receptive_field_test.py @@ -25,6 +25,7 @@ from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import nn from tensorflow.python.platform import test +import numpy as np def create_test_network_1(): @@ -150,6 +151,31 @@ def create_test_network_5(): return g +def create_test_network_6(): + """Aligned network with dropout for test. + + The graph is similar to create_test_network_1(), except that the right branch + has dropout normalization. + + Returns: + g: Tensorflow graph object (Graph proto). + """ + g = ops.Graph() + with g.as_default(): + # An 8x8 test image. + x = array_ops.placeholder(dtypes.float32, (1, 8, 8, 1), name='input_image') + # Left branch. + l1 = slim.conv2d(x, 1, [1, 1], stride=4, scope='L1', padding='VALID') + # Right branch. + l2_pad = array_ops.pad(x, [[0, 0], [1, 0], [1, 0], [0, 0]]) + l2 = slim.conv2d(l2_pad, 1, [3, 3], stride=2, scope='L2', padding='VALID') + l3 = slim.conv2d(l2, 1, [1, 1], stride=2, scope='L3', padding='VALID') + dropout = slim.dropout(l3) + # Addition. + nn.relu(l1 + dropout, name='output') + return g + + class RfUtilsTest(test.TestCase): def testComputeRFFromGraphDefAligned(self): @@ -220,6 +246,36 @@ class RfUtilsTest(test.TestCase): self.assertEqual(effective_padding_x, 0) self.assertEqual(effective_padding_y, 0) + def testComputeRFFromGraphDefStopPropagation(self): + graph_def = create_test_network_6().as_graph_def() + input_node = 'input_image' + output_node = 'output' + # Compute the receptive field but stop the propagation for the random + # uniform variable of the dropout. + (receptive_field_x, receptive_field_y, effective_stride_x, + effective_stride_y, effective_padding_x, effective_padding_y) = ( + receptive_field.compute_receptive_field_from_graph_def( + graph_def, input_node, output_node, + ['Dropout/dropout/random_uniform'])) + self.assertEqual(receptive_field_x, 3) + self.assertEqual(receptive_field_y, 3) + self.assertEqual(effective_stride_x, 4) + self.assertEqual(effective_stride_y, 4) + self.assertEqual(effective_padding_x, 1) + self.assertEqual(effective_padding_y, 1) + + def testComputeCoordinatesRoundtrip(self): + graph_def = create_test_network_1() + input_node = 'input_image' + output_node = 'output' + rf = receptive_field.compute_receptive_field_from_graph_def( + graph_def, input_node, output_node) + + x = np.random.randint(0, 100, (50, 2)) + y = rf.compute_feature_coordinates(x) + x2 = rf.compute_input_center_coordinates(y) + + self.assertAllEqual(x, x2) if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/reduce_slice_ops/BUILD b/tensorflow/contrib/reduce_slice_ops/BUILD index fded03090ea48ecea464d64ac87206700b6476c9..b31f4488f5882a0bc4e419668dba5da72d69b7fe 100644 --- a/tensorflow/contrib/reduce_slice_ops/BUILD +++ b/tensorflow/contrib/reduce_slice_ops/BUILD @@ -71,6 +71,7 @@ tf_custom_op_py_library( ":reduce_slice_ops", "//tensorflow/contrib/util:util_py", "//tensorflow/python:framework", + "//tensorflow/python:platform", ], ) diff --git a/tensorflow/contrib/resampler/BUILD b/tensorflow/contrib/resampler/BUILD index 1b9efd1ecd7d4807fe04b52f2f4148e95fce9a8c..f0ecc8b85a5db93075d3cf0b55e7df95732bcf94 100644 --- a/tensorflow/contrib/resampler/BUILD +++ b/tensorflow/contrib/resampler/BUILD @@ -26,9 +26,15 @@ tf_custom_op_py_library( deps = [ ":resampler_ops", "//tensorflow/contrib/util:util_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:constant_op", + "//tensorflow/python:dtypes", + "//tensorflow/python:errors", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", "//tensorflow/python:util", + "//third_party/py/numpy", ], ) diff --git a/tensorflow/contrib/resampler/kernels/resampler_ops.cc b/tensorflow/contrib/resampler/kernels/resampler_ops.cc index afc8bcd4462bbfd7c7f87480a795088ada35365f..7d9ef14cefc578e9401d95db9a625428cc0e2605 100644 --- a/tensorflow/contrib/resampler/kernels/resampler_ops.cc +++ b/tensorflow/contrib/resampler/kernels/resampler_ops.cc @@ -122,7 +122,7 @@ struct Resampler2DFunctor{ }; // Rough estimate of work for each batch entry. // From third_party/tensorflow/core/util/work_sharder.cc we gather that an - // estimate of the cost of each work unit is needed to correclty shard the + // estimate of the cost of each work unit is needed to correctly shard the // workload. Shard assumes each cost unit is 1ns, minimum cost per shard // being 10us. const int64 cost = static_cast(num_sampling_points) * diff --git a/tensorflow/contrib/rnn/BUILD b/tensorflow/contrib/rnn/BUILD index 3e6c09662fe8b54ff4c07175cbba99b87e27969c..b70a5bbcd107b4c21e09c6d01a2e461fa9edd250 100644 --- a/tensorflow/contrib/rnn/BUILD +++ b/tensorflow/contrib/rnn/BUILD @@ -24,6 +24,22 @@ load( "tf_kernel_tests_linkstatic", ) +cc_library( + name = "all_ops", + deps = [ + ":gru_ops_op_lib", + ":lstm_ops_op_lib", + ], +) + +cc_library( + name = "all_kernels", + deps = [ + ":gru_ops_kernels", + ":lstm_ops_kernels", + ], +) + tf_custom_op_py_library( name = "rnn_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]) + [ @@ -34,34 +50,36 @@ tf_custom_op_py_library( ":python/ops/_lstm_ops.so", ], kernels = [ - ":gru_ops_kernels", - ":lstm_ops_kernels", - ":gru_ops_op_lib", - ":lstm_ops_op_lib", + ":all_ops", + ":all_kernels", ], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ + ":benchmarking", ":gru_ops", ":lstm_ops", "//tensorflow/contrib/compiler:compiler_py", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/util:util_py", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:clip_ops", "//tensorflow/python:embedding_ops", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:init_ops", - "//tensorflow/python:layers", "//tensorflow/python:math_ops", "//tensorflow/python:nn_ops", - "//tensorflow/python:partitioned_variables", "//tensorflow/python:platform", + "//tensorflow/python:pywrap_tensorflow", + "//tensorflow/python:random_ops", "//tensorflow/python:rnn", "//tensorflow/python:rnn_cell", + "//tensorflow/python:session", + "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//tensorflow/python:variables", ], ) @@ -125,6 +143,9 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + tags = [ + "optonly", + ], ) cuda_py_tests( @@ -138,6 +159,7 @@ cuda_py_tests( "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", "//tensorflow/python:gradients", "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", @@ -147,6 +169,7 @@ cuda_py_tests( "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", + "//tensorflow/python/eager:context", ], shard_count = 10, ) @@ -259,6 +282,7 @@ cuda_py_tests( "//tensorflow/python:variable_scope", "//tensorflow/python:variables", ], + tags = ["no_oss"], ) tf_cc_test( @@ -361,7 +385,6 @@ py_binary( srcs_version = "PY2AND3", deps = [ "//tensorflow/core:protos_all_py", - "//tensorflow/python:client", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", "//tensorflow/python:pywrap_tensorflow", @@ -386,3 +409,10 @@ py_test( "//tensorflow/python:variables", ], ) + +py_library( + name = "benchmarking", + srcs = ["python/kernel_tests/benchmarking.py"], + srcs_version = "PY2AND3", + deps = ["//tensorflow/python:framework_ops"], +) diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.cc b/tensorflow/contrib/rnn/kernels/lstm_ops.cc index f74d6cec7625fda0febba4bdadc2e7e15f90edd6..941a457fd3ada312b981fb23c769ff9ecea9ff13 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.cc @@ -39,6 +39,195 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +namespace functor { + +template +void LSTMBlockCellFpropWithEigen( + const LSTMBlockCell& cell, OpKernelContext* ctx, const CPUDevice& d, + const T forget_bias, const T cell_clip, bool use_peephole, + typename TTypes::ConstMatrix x, typename TTypes::ConstMatrix cs_prev, + typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, + typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, + typename TTypes::ConstVec wco, typename TTypes::ConstVec b, + typename TTypes::Matrix xh, typename TTypes::Matrix i, + typename TTypes::Matrix cs, typename TTypes::Matrix f, + typename TTypes::Matrix o, typename TTypes::Matrix ci, + typename TTypes::Matrix co, typename TTypes::Matrix icfo, + typename TTypes::Matrix h) { + // Concat xh = [x, h]. + xh.slice(cell.xh_x_offsets(), cell.xh_x_extents()).device(d) = x; + xh.slice(cell.xh_h_offsets(), cell.xh_h_extents()).device(d) = h_prev; + + // states1 = xh * w + b + typename TTypes::ConstMatrix const_xh(xh.data(), xh.dimensions()); + TensorBlasGemm::compute( + ctx, d, false, false, T(1), const_xh, w, T(0), icfo); + Eigen::array b_shape({1, b.dimensions()[0]}); + Eigen::array broadcast_shape({cell.batch_size(), 1}); + icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape); + + Eigen::array p_shape({1, cell.cell_size()}); + Eigen::array p_broadcast_shape({cell.batch_size(), 1}); + + // Input gate. + if (use_peephole) { + auto i_peep = cs_prev * wci.reshape(p_shape).broadcast(p_broadcast_shape); + i.device(d) = + (icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()) + i_peep) + .sigmoid(); + } else { + i.device(d) = + icfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).sigmoid(); + } + + // Cell input. + ci.device(d) = icfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).tanh(); + + // Forget gate (w/ bias). + if (use_peephole) { + auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape); + f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) + + f.constant(forget_bias) + f_peep) + .sigmoid(); + } else { + f.device(d) = (icfo.slice(cell.icfo_f_offsets(), cell.cell_extents()) + + f.constant(forget_bias)) + .sigmoid(); + } + + // cs = ci .* i + f .* cs_prev + cs.device(d) = i * ci + f * cs_prev; + + if (cell_clip > 0.0f) { + cs.device(d) = + cs.binaryExpr(cs.constant(cell_clip), Eigen::scalar_clip_op()); + } + + // co = tanh(cs) + co.device(d) = cs.tanh(); + + // Output gate. + if (use_peephole) { + auto o_peep = cs * wco.reshape(p_shape).broadcast(p_broadcast_shape); + o.device(d) = + (icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()) + o_peep) + .sigmoid(); + } else { + o.device(d) = + icfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).sigmoid(); + } + + // h = o .* co + h.device(d) = o * co; +} + +template +void LSTMBlockCellBpropWithEigen( + const LSTMBlockCell& cell, OpKernelContext* ctx, const Device& d, + bool use_peephole, typename TTypes::ConstMatrix x, + typename TTypes::ConstMatrix cs_prev, + typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, + typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, + typename TTypes::ConstVec wco, typename TTypes::ConstVec b, + typename TTypes::ConstMatrix i, typename TTypes::ConstMatrix cs, + typename TTypes::ConstMatrix f, typename TTypes::ConstMatrix o, + typename TTypes::ConstMatrix ci, typename TTypes::ConstMatrix co, + typename TTypes::ConstMatrix cs_grad, + typename TTypes::ConstMatrix h_grad, typename TTypes::Matrix do_, + typename TTypes::Matrix dcs, typename TTypes::Matrix dci, + typename TTypes::Matrix df, typename TTypes::Matrix di, + typename TTypes::Matrix dicfo, typename TTypes::Matrix cs_prev_grad, + typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, + typename TTypes::Vec wco_grad) { + // do[t] = sigm'(o[t]) .* dh[t] .* co[t] + do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co; + + // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1] + dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad; + + Eigen::array p_shape({1, cell.cell_size()}); + Eigen::array p_broadcast_shape({cell.batch_size(), 1}); + if (use_peephole) { + dcs.device(d) = + dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape); + } + + // dci[t] = tanh'(ci[t]) dcs[t] i[t] + dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i; + + // df[t] = sigm'(f[t]) dcs[t] cs[t - 1] + df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev; + + // di[t] = sigm'(i[t]) dcs[t] ci[t] + di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci; + + dicfo.slice(cell.icfo_i_offsets(), cell.cell_extents()).device(d) = di; + dicfo.slice(cell.icfo_c_offsets(), cell.cell_extents()).device(d) = dci; + dicfo.slice(cell.icfo_f_offsets(), cell.cell_extents()).device(d) = df; + dicfo.slice(cell.icfo_o_offsets(), cell.cell_extents()).device(d) = do_; + + cs_prev_grad.device(d) = dcs * f; + if (use_peephole) { + cs_prev_grad.device(d) = + cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) + + df * wcf.reshape(p_shape).broadcast(p_broadcast_shape); + wci_grad.device(d) = (di * cs_prev).sum(Eigen::array({0})); + wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array({0})); + wco_grad.device(d) = (do_ * cs).sum(Eigen::array({0})); + } +} + +#define DEFINE_CPU_SPECS(T) \ + template <> \ + void LSTMBlockCellFprop::operator()( \ + OpKernelContext* ctx, const CPUDevice& d, const T forget_bias, \ + const T cell_clip, bool use_peephole, typename TTypes::ConstMatrix x, \ + typename TTypes::ConstMatrix cs_prev, \ + typename TTypes::ConstMatrix h_prev, \ + typename TTypes::ConstMatrix w, typename TTypes::ConstVec wci, \ + typename TTypes::ConstVec wcf, typename TTypes::ConstVec wco, \ + typename TTypes::ConstVec b, typename TTypes::Matrix xh, \ + typename TTypes::Matrix i, typename TTypes::Matrix cs, \ + typename TTypes::Matrix f, typename TTypes::Matrix o, \ + typename TTypes::Matrix ci, typename TTypes::Matrix co, \ + typename TTypes::Matrix icfo, typename TTypes::Matrix h) { \ + LSTMBlockCellFpropWithEigen( \ + *this, ctx, d, forget_bias, cell_clip, use_peephole, x, cs_prev, \ + h_prev, w, wci, wcf, wco, b, xh, i, cs, f, o, ci, co, icfo, h); \ + } \ + template <> \ + void LSTMBlockCellBprop::operator()( \ + OpKernelContext* ctx, const CPUDevice& d, bool use_peephole, \ + typename TTypes::ConstMatrix x, \ + typename TTypes::ConstMatrix cs_prev, \ + typename TTypes::ConstMatrix h_prev, \ + typename TTypes::ConstMatrix w, typename TTypes::ConstVec wci, \ + typename TTypes::ConstVec wcf, typename TTypes::ConstVec wco, \ + typename TTypes::ConstVec b, typename TTypes::ConstMatrix i, \ + typename TTypes::ConstMatrix cs, typename TTypes::ConstMatrix f, \ + typename TTypes::ConstMatrix o, typename TTypes::ConstMatrix ci, \ + typename TTypes::ConstMatrix co, \ + typename TTypes::ConstMatrix cs_grad, \ + typename TTypes::ConstMatrix h_grad, typename TTypes::Matrix do_, \ + typename TTypes::Matrix dcs, typename TTypes::Matrix dci, \ + typename TTypes::Matrix df, typename TTypes::Matrix di, \ + typename TTypes::Matrix dicfo, \ + typename TTypes::Matrix cs_prev_grad, \ + typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, \ + typename TTypes::Vec wco_grad) { \ + LSTMBlockCellBpropWithEigen( \ + *this, ctx, d, use_peephole, x, cs_prev, h_prev, w, wci, wcf, wco, b, \ + i, cs, f, o, ci, co, cs_grad, h_grad, do_, dcs, dci, df, di, dicfo, \ + cs_prev_grad, wci_grad, wcf_grad, wco_grad); \ + } \ + template struct LSTMBlockCellFprop; \ + template struct LSTMBlockCellBprop; + +DEFINE_CPU_SPECS(float); +#undef DEFINE_CPU_SPECS + +} // namespace functor + template class LSTMBlockCellOp : public OpKernel { public: @@ -495,7 +684,8 @@ namespace functor { typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, \ typename TTypes::Vec wco_grad); \ \ - extern template struct LSTMBlockCellBprop; + extern template struct LSTMBlockCellBprop; DECLARE_GPU_SPEC(float); // DECLARE_GPU_SPEC(double); diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops.h b/tensorflow/contrib/rnn/kernels/lstm_ops.h index 6317f32ac3b72d9fadf3c410de0f1df6539bc501..1906581b16b2e76243320bc67c8ac831323fb8e7 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops.h +++ b/tensorflow/contrib/rnn/kernels/lstm_ops.h @@ -99,6 +99,12 @@ struct LSTMBlockCell { input_size_(input_size), cell_size_(cell_size) {} + int batch_size() const { return batch_size_; } + + int input_size() const { return input_size_; } + + int cell_size() const { return cell_size_; } + inline Eigen::array icfo_i_offsets() const { return {0, 0}; } @@ -141,6 +147,8 @@ struct LSTMBlockCell { const int cell_size_; }; +// See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for +// GPUDevice implementation. template struct LSTMBlockCellFprop : public LSTMBlockCell { LSTMBlockCellFprop(const int batch_size, const int input_size, @@ -158,71 +166,11 @@ struct LSTMBlockCellFprop : public LSTMBlockCell { typename TTypes::Matrix cs, typename TTypes::Matrix f, typename TTypes::Matrix o, typename TTypes::Matrix ci, typename TTypes::Matrix co, typename TTypes::Matrix icfo, - typename TTypes::Matrix h) { - // Concat xh = [x, h]. - xh.slice(xh_x_offsets(), xh_x_extents()).device(d) = x; - xh.slice(xh_h_offsets(), xh_h_extents()).device(d) = h_prev; - - // states1 = xh * w + b - typename TTypes::ConstMatrix const_xh(xh.data(), xh.dimensions()); - TensorBlasGemm::compute(ctx, d, false, false, T(1), - const_xh, w, T(0), icfo); - Eigen::array b_shape({1, b.dimensions()[0]}); - Eigen::array broadcast_shape({batch_size_, 1}); - icfo.device(d) += b.reshape(b_shape).broadcast(broadcast_shape); - - Eigen::array p_shape({1, cell_size_}); - Eigen::array p_broadcast_shape({batch_size_, 1}); - - // Input gate. - if (use_peephole) { - auto i_peep = cs_prev * wci.reshape(p_shape).broadcast(p_broadcast_shape); - i.device(d) = - (icfo.slice(icfo_i_offsets(), cell_extents()) + i_peep).sigmoid(); - } else { - i.device(d) = icfo.slice(icfo_i_offsets(), cell_extents()).sigmoid(); - } - - // Cell input. - ci.device(d) = icfo.slice(icfo_c_offsets(), cell_extents()).tanh(); - - // Forget gate (w/ bias). - if (use_peephole) { - auto f_peep = cs_prev * wcf.reshape(p_shape).broadcast(p_broadcast_shape); - f.device(d) = (icfo.slice(icfo_f_offsets(), cell_extents()) + - f.constant(forget_bias) + f_peep) - .sigmoid(); - } else { - f.device(d) = (icfo.slice(icfo_f_offsets(), cell_extents()) + - f.constant(forget_bias)) - .sigmoid(); - } - - // cs = ci .* i + f .* cs_prev - cs.device(d) = i * ci + f * cs_prev; - - if (cell_clip > 0.0f) { - cs.device(d) = - cs.binaryExpr(cs.constant(cell_clip), Eigen::scalar_clip_op()); - } - - // co = tanh(cs) - co.device(d) = cs.tanh(); - - // Output gate. - if (use_peephole) { - auto o_peep = cs * wco.reshape(p_shape).broadcast(p_broadcast_shape); - o.device(d) = - (icfo.slice(icfo_o_offsets(), cell_extents()) + o_peep).sigmoid(); - } else { - o.device(d) = icfo.slice(icfo_o_offsets(), cell_extents()).sigmoid(); - } - - // h = o .* co - h.device(d) = o * co; - } + typename TTypes::Matrix h); }; +// See lstm_ops.cc for CPUDevice implementation and lstm_ops_gpu.cu.cc for +// GPUDevice implementation. template struct LSTMBlockCellBprop : public LSTMBlockCell { LSTMBlockCellBprop(const int batch_size, const int input_size, @@ -245,45 +193,7 @@ struct LSTMBlockCellBprop : public LSTMBlockCell { typename TTypes::Matrix df, typename TTypes::Matrix di, typename TTypes::Matrix dicfo, typename TTypes::Matrix cs_prev_grad, typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, - typename TTypes::Vec wco_grad) { - // do[t] = sigm'(o[t]) .* dh[t] .* co[t] - do_.device(d) = o * (o.constant(T(1)) - o) * h_grad * co; - - // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1] - dcs.device(d) = (co.constant(T(1)) - co * co) * h_grad * o + cs_grad; - - Eigen::array p_shape({1, cell_size_}); - Eigen::array p_broadcast_shape({batch_size_, 1}); - if (use_peephole) { - dcs.device(d) = - dcs + do_ * wco.reshape(p_shape).broadcast(p_broadcast_shape); - } - - // dci[t] = tanh'(ci[t]) dcs[t] i[t] - dci.device(d) = (ci.constant(T(1)) - ci * ci) * dcs * i; - - // df[t] = sigm'(f[t]) dcs[t] cs[t - 1] - df.device(d) = f * (f.constant(T(1)) - f) * dcs * cs_prev; - - // di[t] = sigm'(i[t]) dcs[t] ci[t] - di.device(d) = i * (i.constant(T(1)) - i) * dcs * ci; - - dicfo.slice(icfo_i_offsets(), cell_extents()).device(d) = di; - dicfo.slice(icfo_c_offsets(), cell_extents()).device(d) = dci; - dicfo.slice(icfo_f_offsets(), cell_extents()).device(d) = df; - dicfo.slice(icfo_o_offsets(), cell_extents()).device(d) = do_; - - cs_prev_grad.device(d) = dcs * f; - if (use_peephole) { - cs_prev_grad.device(d) = - cs_prev_grad + - di * wci.reshape(p_shape).broadcast(p_broadcast_shape) + - df * wcf.reshape(p_shape).broadcast(p_broadcast_shape); - wci_grad.device(d) = (di * cs_prev).sum(Eigen::array({0})); - wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array({0})); - wco_grad.device(d) = (do_ * cs).sum(Eigen::array({0})); - } - } + typename TTypes::Vec wco_grad); }; template diff --git a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc index b33ca5fc8d03684812319596fb66cec4d7eec744..6d3758fef15e7130b740a377d8bcd41d31203299 100644 --- a/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc +++ b/tensorflow/contrib/rnn/kernels/lstm_ops_gpu.cu.cc @@ -19,21 +19,388 @@ limitations under the License. #include "tensorflow/contrib/rnn/kernels/lstm_ops.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" +#include "tensorflow/core/kernels/eigen_activations.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/util/cuda_kernel_helper.h" + namespace tensorflow { namespace functor { typedef Eigen::GpuDevice GPUDevice; -#define DEFINE_GPU_SPECS(T) \ - template struct TensorZero; \ - template struct TensorUnalignedZero; \ - template struct TensorCopy; \ - template struct TensorCopyUnaligned; \ - template struct TensorCopyToUnaligned; \ - template struct TensorAdd; \ - template struct LSTMBlockCellFprop; \ - template struct LSTMBlockCellBprop; \ - template struct BlockLSTMBprop; +namespace { + +// Adds bias, applies non-linearities and gates. +// +// Launch with a 2D setup such that there is one thread per (example, +// activation) with 'x' governing example index and 'y' governing activation. +// +// Launch with blocks of (batch x 32) +// +// TODO(b/67600500): Try making 'use_peephole' a template parameter. +template +__global__ void lstm_gates(const T* icfo, const T* b, const T* cs_prev, + const T* wci, const T* wcf, const T* wco, T* o, T* h, + T* ci, T* cs, T* co, T* i, T* f, const T forget_bias, + const T cell_clip, const int batch_size, + const int cell_size) { + const int batch_id = blockIdx.x * blockDim.x + threadIdx.x; + const int act_id = blockIdx.y * blockDim.y + threadIdx.y; + + if (batch_id >= batch_size || act_id >= cell_size) return; + + // The following code assumes the input arrays are of the following + // shapes and interpretations. + // + // 1) 'icfo' is a matrix such that, + // + // cell_size cell_size cell_size cell_size + // +----------+----------+----------+----------+ + // | | | | | + // | i | c | f | o | batch_size + // | | | | | + // +----------+----------+----------+----------+ + // + // 'gid' is the index assigned to this thread for 'icfo' in the 'i' submatrix. + // + // 2) 'b' is a vector such that, + // + // cell_size cell_size cell_size cell_size + // +----------+----------+----------+----------+ + // | i | c | f | o | 1 + // +----------+----------+----------+----------+ + // + // 'act_id' is the index assigned to this thread for 'b' in the 'i' subvector. + // + // 3) 'wc{i,f,o}' are vectors such that, + // + // cell_size + // +----------+ + // | i | 1 + // +----------+ + // + // 'act_id' is the index to this thread. + // + // 4) All other matrices have the form, + // + // cell_size + // +----------+ + // | | + // | i | batch_size + // | | + // +----------+ + // + // 'cid' is the index assigned to this thread. + // + const int gid = batch_id * cell_size * 4 + act_id; + const int cid = batch_id * cell_size + act_id; + Eigen::internal::scalar_sigmoid_op sigmoid_op; + Eigen::internal::scalar_tanh_op tanh_op; + Eigen::scalar_clip_op clip_op; + + T i_local; + if (use_peephole) { + i_local = sigmoid_op(icfo[0 * cell_size + gid] + b[0 * cell_size + act_id] + + cs_prev[cid] * wci[act_id]); + } else { + i_local = sigmoid_op(icfo[0 * cell_size + gid] + b[0 * cell_size + act_id]); + } + i[cid] = i_local; + + const T ci_local = + tanh_op(icfo[1 * cell_size + gid] + b[1 * cell_size + act_id]); + ci[cid] = ci_local; + + T f_local; + if (use_peephole) { + f_local = sigmoid_op(icfo[2 * cell_size + gid] + b[2 * cell_size + act_id] + + forget_bias + cs_prev[cid] * wcf[act_id]); + } else { + f_local = sigmoid_op(icfo[2 * cell_size + gid] + b[2 * cell_size + act_id] + + forget_bias); + } + f[cid] = f_local; + + T cs_local = i_local * ci_local + f_local * cs_prev[cid]; + if (cell_clip > 0.0) { + cs_local = clip_op(cs_local, cell_clip); + } + cs[cid] = cs_local; + + const T co_local = tanh_op(cs_local); + co[cid] = co_local; + + T o_local; + if (use_peephole) { + o_local = sigmoid_op(icfo[3 * cell_size + gid] + b[3 * cell_size + act_id] + + cs_local * wco[act_id]); + } else { + o_local = sigmoid_op(icfo[3 * cell_size + gid] + b[3 * cell_size + act_id]); + } + o[cid] = o_local; + + h[cid] = o_local * co_local; +} + +// Concatenate 'x' and 'h' and copy their contents into 'xh'. +template +__global__ void concat_xh(T* xh, const T* x, const T* h_prev, + const int batch_size, const int cell_size, + const int input_size) { + // Assumes 'x', 'h', and 'xh' are of the following shape, + // + // input_size cell_size + // +----------+----------+ + // | | | + // | x | h | batch_size + // | | | + // +----------+----------+ + // + const int gid = blockDim.x * blockIdx.x + threadIdx.x; + const int width = input_size + cell_size; + + if (gid >= width * batch_size) return; + + const int output_row = gid / width; + const int output_col = gid % width; + + if (output_col < input_size) { // x + xh[gid] = x[output_row * input_size + output_col]; + } else { // h + xh[gid] = h_prev[output_row * cell_size + output_col - input_size]; + } +} + +template +void LSTMBlockCellFpropWithCUDA( + OpKernelContext* ctx, const GPUDevice& d, const T forget_bias, + const T cell_clip, bool use_peephole, typename TTypes::ConstMatrix x, + typename TTypes::ConstMatrix cs_prev, + typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, + typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, + typename TTypes::ConstVec wco, typename TTypes::ConstVec b, + typename TTypes::Matrix xh, typename TTypes::Matrix i, + typename TTypes::Matrix cs, typename TTypes::Matrix f, + typename TTypes::Matrix o, typename TTypes::Matrix ci, + typename TTypes::Matrix co, typename TTypes::Matrix icfo, + typename TTypes::Matrix h, int batch_size, int cell_size, + int input_size) { + const cudaStream_t& cu_stream = GetCudaStream(ctx); + + // Concatenate xh = [x, h]. + // + // Each block is assigned 128 threads. Good values are in [128, 1024] and are + // divisible by 32 (the size of a warp). The number of blocks is such that + // there are enough to process all the data. + const int block_dim = 128; + const int grid_dim = + Eigen::divup(batch_size * (cell_size + input_size), block_dim); + concat_xh<<>>( + xh.data(), x.data(), h_prev.data(), batch_size, cell_size, input_size); + + // states1 = xh * w + typename TTypes::ConstMatrix const_xh(xh.data(), xh.dimensions()); + TensorBlasGemm::compute( + ctx, d, false, false, T(1), const_xh, w, T(0), icfo); + + // Add bias, apply non-linearities and gating. + // + // Use 2D blocks. The number of threads per block is equal to x * y, where x = + // min(batch_size, 8) and y = 32. See above for guidance on number of + // threads. + dim3 block_dim_2d(std::min(batch_size, 8), 32); + dim3 grid_dim_2d(Eigen::divup(batch_size, static_cast(block_dim_2d.x)), + Eigen::divup(cell_size, static_cast(block_dim_2d.y))); + + if (use_peephole) { + lstm_gates<<>>( + icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(), + wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(), + i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size); + } else { + lstm_gates<<>>( + icfo.data(), b.data(), cs_prev.data(), wci.data(), wcf.data(), + wco.data(), o.data(), h.data(), ci.data(), cs.data(), co.data(), + i.data(), f.data(), forget_bias, cell_clip, batch_size, cell_size); + } +} + +template +__global__ void lstm_gates_bprop( + const T* cs_prev, // [batch_size, cell_size] + const T* h_prev, // [batch_size, cell_size] + const T* w, // [input_size + cell_size, 4 * cell_size] + const T* wci, // [cell_size] + const T* wcf, // [cell_size] + const T* wco, // [cell_size] + const T* b, // [4 * cell_size] + const T* i, // [batch_size, cell_size] + const T* cs, // [batch_size, cell_size] + const T* f, // [batch_size, cell_size] + const T* o, // [batch_size, cell_size] + const T* ci, // [batch_size, cell_size] + const T* co, // [batch_size, cell_size] + const T* cs_grad, // [batch_size, cell_size] + const T* h_grad, // [batch_size, cell_size] + T* do_, // [batch_size, cell_size] + T* dcs, // [batch_size, cell_size] + T* dci, // [batch_size, cell_size] + T* df, // [batch_size, cell_size] + T* di, // [batch_size, cell_size] + T* dicfo, // [input_size + cell_size, 4 * cell_size] + T* cs_prev_grad, // [batch_size, cell_size] + const int batch_size, const int cell_size, const bool use_peephole) { + const int batch_id = blockIdx.x * blockDim.x + threadIdx.x; + const int act_id = blockIdx.y * blockDim.y + threadIdx.y; + + if (batch_id >= batch_size || act_id >= cell_size) return; + + const int gid = batch_id * cell_size * 4 + act_id; + const int cid = batch_id * cell_size + act_id; + + const T one = static_cast(1.0f); + + // do[t] = sigm'(o[t]) .* dh[t] .* co[t] + const T o_local = o[cid]; + const T h_grad_local = h_grad[cid]; + const T co_local = co[cid]; + const T ci_local = ci[cid]; + const T do_local = o_local * (one - o_local) * h_grad_local * co_local; + const T i_local = i[cid]; + const T f_local = f[cid]; + + do_[cid] = do_local; + + // dcs[t] += tanh'(cs[t]) .* dh[t] .* o[t] + dcs[t + 1] .* f[t + 1] + T dcs_local = + (one - co_local * co_local) * h_grad_local * o_local + cs_grad[cid]; + if (use_peephole) { + dcs_local += do_local * wco[act_id]; + } + dcs[cid] = dcs_local; + + // dci[t] = tanh'(ci[t]) dcs[t] i[t] + const T dci_local = (one - ci_local * ci_local) * dcs_local * i_local; + dci[cid] = dci_local; + + // df[t] = sigm'(f[t]) dcs[t] cs[t - 1] + const T df_local = f_local * (one - f_local) * dcs_local * cs_prev[cid]; + df[cid] = df_local; + + // di[t] = sigm'(i[t]) dcs[t] ci[t] + const T di_local = i_local * (one - i_local) * dcs_local * ci_local; + di[cid] = di_local; + + dicfo[gid + 0 * cell_size] = di_local; + dicfo[gid + 1 * cell_size] = dci_local; + dicfo[gid + 2 * cell_size] = df_local; + dicfo[gid + 3 * cell_size] = do_local; + + cs_prev_grad[cid] = dcs_local * f_local; + if (use_peephole) { + cs_prev_grad[cid] += di_local * wci[act_id] + df_local * wcf[act_id]; + } +} + +template +void LSTMBlockCellBpropWithCUDA( + OpKernelContext* ctx, const GPUDevice& d, typename TTypes::ConstMatrix x, + typename TTypes::ConstMatrix cs_prev, + typename TTypes::ConstMatrix h_prev, typename TTypes::ConstMatrix w, + typename TTypes::ConstVec wci, typename TTypes::ConstVec wcf, + typename TTypes::ConstVec wco, typename TTypes::ConstVec b, + typename TTypes::ConstMatrix i, typename TTypes::ConstMatrix cs, + typename TTypes::ConstMatrix f, typename TTypes::ConstMatrix o, + typename TTypes::ConstMatrix ci, typename TTypes::ConstMatrix co, + typename TTypes::ConstMatrix cs_grad, + typename TTypes::ConstMatrix h_grad, typename TTypes::Matrix do_, + typename TTypes::Matrix dcs, typename TTypes::Matrix dci, + typename TTypes::Matrix df, typename TTypes::Matrix di, + typename TTypes::Matrix dicfo, typename TTypes::Matrix cs_prev_grad, + typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, + typename TTypes::Vec wco_grad, const int batch_size, const int cell_size, + const bool use_peephole) { + const cudaStream_t& cu_stream = GetCudaStream(ctx); + + dim3 block_dim_2d(std::min(batch_size, 8), 32); + dim3 grid_dim_2d(Eigen::divup(batch_size, static_cast(block_dim_2d.x)), + Eigen::divup(cell_size, static_cast(block_dim_2d.y))); + + lstm_gates_bprop<<>>( + cs_prev.data(), h_prev.data(), w.data(), wci.data(), wcf.data(), + wco.data(), b.data(), i.data(), cs.data(), f.data(), o.data(), ci.data(), + co.data(), cs_grad.data(), h_grad.data(), do_.data(), dcs.data(), + dci.data(), df.data(), di.data(), dicfo.data(), cs_prev_grad.data(), + batch_size, cell_size, use_peephole); + + if (use_peephole) { + Eigen::array p_shape({1, cell_size}); + Eigen::array p_broadcast_shape({batch_size, 1}); + cs_prev_grad.device(d) = + cs_prev_grad + di * wci.reshape(p_shape).broadcast(p_broadcast_shape) + + df * wcf.reshape(p_shape).broadcast(p_broadcast_shape); + wci_grad.device(d) = (di * cs_prev).sum(Eigen::array({0})); + wcf_grad.device(d) = (df * cs_prev).sum(Eigen::array({0})); + wco_grad.device(d) = (do_ * cs).sum(Eigen::array({0})); + } +} + +} // namespace + +#define DEFINE_GPU_SPECS(T) \ + template struct TensorZero; \ + template struct TensorUnalignedZero; \ + template struct TensorCopy; \ + template struct TensorCopyUnaligned; \ + template struct TensorCopyToUnaligned; \ + template struct TensorAdd; \ + template <> \ + void LSTMBlockCellFprop::operator()( \ + OpKernelContext* ctx, const GPUDevice& d, const T forget_bias, \ + const T cell_clip, bool use_peephole, typename TTypes::ConstMatrix x, \ + typename TTypes::ConstMatrix cs_prev, \ + typename TTypes::ConstMatrix h_prev, \ + typename TTypes::ConstMatrix w, typename TTypes::ConstVec wci, \ + typename TTypes::ConstVec wcf, typename TTypes::ConstVec wco, \ + typename TTypes::ConstVec b, typename TTypes::Matrix xh, \ + typename TTypes::Matrix i, typename TTypes::Matrix cs, \ + typename TTypes::Matrix f, typename TTypes::Matrix o, \ + typename TTypes::Matrix ci, typename TTypes::Matrix co, \ + typename TTypes::Matrix icfo, typename TTypes::Matrix h) { \ + LSTMBlockCellFpropWithCUDA(ctx, d, forget_bias, cell_clip, use_peephole, \ + x, cs_prev, h_prev, w, wci, wcf, wco, b, xh, i, \ + cs, f, o, ci, co, icfo, h, batch_size_, \ + cell_size_, input_size_); \ + } \ + template <> \ + void LSTMBlockCellBprop::operator()( \ + OpKernelContext* ctx, const GPUDevice& d, bool use_peephole, \ + typename TTypes::ConstMatrix x, \ + typename TTypes::ConstMatrix cs_prev, \ + typename TTypes::ConstMatrix h_prev, \ + typename TTypes::ConstMatrix w, typename TTypes::ConstVec wci, \ + typename TTypes::ConstVec wcf, typename TTypes::ConstVec wco, \ + typename TTypes::ConstVec b, typename TTypes::ConstMatrix i, \ + typename TTypes::ConstMatrix cs, typename TTypes::ConstMatrix f, \ + typename TTypes::ConstMatrix o, typename TTypes::ConstMatrix ci, \ + typename TTypes::ConstMatrix co, \ + typename TTypes::ConstMatrix cs_grad, \ + typename TTypes::ConstMatrix h_grad, typename TTypes::Matrix do_, \ + typename TTypes::Matrix dcs, typename TTypes::Matrix dci, \ + typename TTypes::Matrix df, typename TTypes::Matrix di, \ + typename TTypes::Matrix dicfo, \ + typename TTypes::Matrix cs_prev_grad, \ + typename TTypes::Vec wci_grad, typename TTypes::Vec wcf_grad, \ + typename TTypes::Vec wco_grad) { \ + LSTMBlockCellBpropWithCUDA( \ + ctx, d, x, cs_prev, h_prev, w, wci, wcf, wco, b, i, cs, f, o, ci, co, \ + cs_grad, h_grad, do_, dcs, dci, df, di, dicfo, cs_prev_grad, wci_grad, \ + wcf_grad, wco_grad, batch_size_, cell_size_, use_peephole); \ + } \ + template struct LSTMBlockCellFprop; \ + template struct LSTMBlockCellBprop; \ + template struct BlockLSTMBprop; DEFINE_GPU_SPECS(float); // DEFINE_GPU_SPECS(double); diff --git a/tensorflow/contrib/rnn/python/kernel_tests/benchmarking.py b/tensorflow/contrib/rnn/python/kernel_tests/benchmarking.py new file mode 100644 index 0000000000000000000000000000000000000000..a48cd58706e72516f18098e643c0fa867d33beb2 --- /dev/null +++ b/tensorflow/contrib/rnn/python/kernel_tests/benchmarking.py @@ -0,0 +1,66 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Library for benchmarking OpKernels.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import itertools +import time + +from tensorflow.python.framework import ops + + +def device(use_gpu=False): + """TensorFlow device to assign ops to.""" + if use_gpu: + return ops.device("/gpu:0") + return ops.device("/cpu:0") + + +def seconds_per_run(op, sess, num_runs=50): + """Number of seconds taken to execute 'op' once on average.""" + for _ in range(2): + sess.run(op) + + start_time = time.time() + for _ in range(num_runs): + sess.run(op) + + end_time = time.time() + time_taken = (end_time - start_time) / num_runs + return time_taken + + +def dict_product(dicts): + """Constructs iterator over outer product of entries in a dict-of-lists. + + Example: + >>> dict_products({"a": [1,2], "b": [3, 4]}) + >>> [{"a": 1, "b": 3}, + {"a": 1, "b": 4}, + {"a": 2, "b": 3}, + {"a": 2, "b": 4}] + + Args: + dicts: dictionary with string keys and list values. + + Yields: + Individual dicts from outer product. + """ + keys, values = zip(*dicts.items()) + for config_values in itertools.product(*values): + yield dict(zip(keys, config_values)) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py index f222c4745c13dc4b07fa5afa61fef5615bf0dba8..16b6d145e3fd3e4e5bb34481cc61eb5706cf1772 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_cell_test.py @@ -22,10 +22,8 @@ import functools import numpy as np -# TODO(ebrevdo): Remove once _linear is fully deprecated. -# pylint: disable=protected-access - from tensorflow.contrib import rnn as contrib_rnn +from tensorflow.contrib.rnn.python.ops import core_rnn_cell from tensorflow.core.protobuf import config_pb2 from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes @@ -41,10 +39,12 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.platform import test from tensorflow.python.framework import test_util +from tensorflow.contrib.rnn.python.ops import rnn_cell as contrib_rnn_cell + # pylint: enable=protected-access -linear = rnn_cell_impl._linear +Linear = core_rnn_cell._Linear # pylint: disable=invalid-name class RNNCellTest(test.TestCase): @@ -54,20 +54,20 @@ class RNNCellTest(test.TestCase): with variable_scope.variable_scope( "root", initializer=init_ops.constant_initializer(1.0)): x = array_ops.zeros([1, 2]) - l = linear([x], 2, False) + l = Linear([x], 2, False)([x]) sess.run([variables_lib.global_variables_initializer()]) res = sess.run([l], {x.name: np.array([[1., 2.]])}) self.assertAllClose(res[0], [[3.0, 3.0]]) # Checks prevent you from accidentally creating a shared function. with self.assertRaises(ValueError): - l1 = linear([x], 2, False) + l1 = Linear([x], 2, False)([x]) # But you can create a new one in a new scope and share the variables. with variable_scope.variable_scope("l1") as new_scope: - l1 = linear([x], 2, False) + l1 = Linear([x], 2, False)([x]) with variable_scope.variable_scope(new_scope, reuse=True): - linear([l1], 2, False) + Linear([l1], 2, False)([l1]) self.assertEqual(len(variables_lib.trainable_variables()), 2) def testBasicRNNCell(self): @@ -128,8 +128,8 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.175991, 0.175991]]) with variable_scope.variable_scope( "other", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros( - [1, 3]) # Test GRUCell with input_size != num_units. + # Test GRUCell with input_size != num_units. + x = array_ops.zeros([1, 3]) m = array_ops.zeros([1, 2]) g, _ = rnn_cell_impl.GRUCell(2)(x, m) sess.run([variables_lib.global_variables_initializer()]) @@ -141,58 +141,67 @@ class RNNCellTest(test.TestCase): self.assertAllClose(res[0], [[0.156736, 0.156736]]) def testBasicLSTMCell(self): - with self.test_session() as sess: - with variable_scope.variable_scope( - "root", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros([1, 2]) - m = array_ops.zeros([1, 8]) - cell = rnn_cell_impl.MultiRNNCell( - [ - rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False) - for _ in range(2) - ], - state_is_tuple=False) - g, out_m = cell(x, m) - expected_variable_names = [ - "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % - rnn_cell_impl._BIAS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % - rnn_cell_impl._WEIGHTS_VARIABLE_NAME, - "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % - rnn_cell_impl._BIAS_VARIABLE_NAME - ] - self.assertEqual( - expected_variable_names, [v.name for v in cell.trainable_variables]) - self.assertFalse(cell.non_trainable_variables) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_m], - {x.name: np.array([[1., 1.]]), - m.name: 0.1 * np.ones([1, 8])}) - self.assertEqual(len(res), 2) - variables = variables_lib.global_variables() - self.assertEqual(expected_variable_names, [v.name for v in variables]) - # The numbers in results were not calculated, this is just a smoke test. - self.assertAllClose(res[0], [[0.24024698, 0.24024698]]) - expected_mem = np.array([[ - 0.68967271, 0.68967271, 0.44848421, 0.44848421, 0.39897051, - 0.39897051, 0.24024698, 0.24024698 - ]]) - self.assertAllClose(res[1], expected_mem) - with variable_scope.variable_scope( - "other", initializer=init_ops.constant_initializer(0.5)): - x = array_ops.zeros( - [1, 3]) # Test BasicLSTMCell with input_size != num_units. - m = array_ops.zeros([1, 4]) - g, out_m = rnn_cell_impl.BasicLSTMCell(2, state_is_tuple=False)(x, m) - sess.run([variables_lib.global_variables_initializer()]) - res = sess.run( - [g, out_m], - {x.name: np.array([[1., 1., 1.]]), - m.name: 0.1 * np.ones([1, 4])}) - self.assertEqual(len(res), 2) + for dtype in [dtypes.float16, dtypes.float32]: + np_dtype = dtype.as_numpy_dtype + with self.test_session(graph=ops.Graph()) as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2], dtype=dtype) + m = array_ops.zeros([1, 8], dtype=dtype) + cell = rnn_cell_impl.MultiRNNCell( + [ + rnn_cell_impl.BasicLSTMCell( + 2, state_is_tuple=False) + for _ in range(2) + ], + state_is_tuple=False) + self.assertEqual(cell.dtype, None) + g, out_m = cell(x, m) + # Layer infers the input type. + self.assertEqual(cell.dtype, dtype.name) + expected_variable_names = [ + "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_0/basic_lstm_cell/%s:0" % + rnn_cell_impl._BIAS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % + rnn_cell_impl._WEIGHTS_VARIABLE_NAME, + "root/multi_rnn_cell/cell_1/basic_lstm_cell/%s:0" % + rnn_cell_impl._BIAS_VARIABLE_NAME + ] + self.assertEqual( + expected_variable_names, + [v.name for v in cell.trainable_variables]) + self.assertFalse(cell.non_trainable_variables) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run( + [g, out_m], + {x.name: np.array([[1., 1.]]), + m.name: 0.1 * np.ones([1, 8])}) + self.assertEqual(len(res), 2) + variables = variables_lib.global_variables() + self.assertEqual(expected_variable_names, [v.name for v in variables]) + # The numbers in results were not calculated, this is just a + # smoke test. + self.assertAllClose( + res[0], np.array([[0.240, 0.240]], dtype=np_dtype), 1e-2) + expected_mem = np.array( + [[0.689, 0.689, 0.448, 0.448, 0.398, 0.398, 0.240, 0.240]], + dtype=np_dtype) + self.assertAllClose(res[1], expected_mem, 1e-2) + with variable_scope.variable_scope( + "other", initializer=init_ops.constant_initializer(0.5)): + # Test BasicLSTMCell with input_size != num_units. + x = array_ops.zeros([1, 3], dtype=dtype) + m = array_ops.zeros([1, 4], dtype=dtype) + g, out_m = rnn_cell_impl.BasicLSTMCell( + 2, state_is_tuple=False)(x, m) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run( + [g, out_m], + {x.name: np.array([[1., 1., 1.]], dtype=np_dtype), + m.name: 0.1 * np.ones([1, 4], dtype=np_dtype)}) + self.assertEqual(len(res), 2) def testBasicLSTMCellDimension0Error(self): """Tests that dimension 0 in both(x and m) shape must be equal.""" @@ -352,6 +361,45 @@ class RNNCellTest(test.TestCase): self.assertEquals(variables[2].op.name, "root/lstm_cell/projection/kernel") + def testLSTMCellLayerNorm(self): + with self.test_session() as sess: + num_units = 2 + num_proj = 3 + batch_size = 1 + input_size = 4 + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([batch_size, input_size]) + c = array_ops.zeros([batch_size, num_units]) + h = array_ops.zeros([batch_size, num_proj]) + state = rnn_cell_impl.LSTMStateTuple(c, h) + cell = contrib_rnn_cell.LayerNormLSTMCell( + num_units=num_units, + num_proj=num_proj, + forget_bias=1.0, + layer_norm=True, + norm_gain=1.0, + norm_shift=0.0) + g, out_m = cell(x, state) + sess.run([variables_lib.global_variables_initializer()]) + res = sess.run([g, out_m], { + x.name: np.ones((batch_size, input_size)), + c.name: 0.1 * np.ones((batch_size, num_units)), + h.name: 0.1 * np.ones((batch_size, num_proj)) + }) + self.assertEqual(len(res), 2) + # The numbers in results were not calculated, this is mostly just a + # smoke test. + self.assertEqual(res[0].shape, (batch_size, num_proj)) + self.assertEqual(res[1][0].shape, (batch_size, num_units)) + self.assertEqual(res[1][1].shape, (batch_size, num_proj)) + # Different inputs so different outputs and states + for i in range(1, batch_size): + self.assertTrue( + float(np.linalg.norm((res[0][0, :] - res[0][i, :]))) < 1e-6) + self.assertTrue( + float(np.linalg.norm((res[1][0, :] - res[1][i, :]))) < 1e-6) + def testOutputProjectionWrapper(self): with self.test_session() as sess: with variable_scope.variable_scope( @@ -441,6 +489,17 @@ class RNNCellTest(test.TestCase): outputs, _ = cell(x, m) self.assertTrue("cpu:14159" in outputs.device.lower()) + def _retrieve_cpu_gpu_stats(self, run_metadata): + cpu_stats = None + gpu_stats = None + step_stats = run_metadata.step_stats + for ds in step_stats.dev_stats: + if "cpu:0" in ds.device[-5:].lower(): + cpu_stats = ds.node_stats + if "gpu:0" == ds.device[-5:].lower(): + gpu_stats = ds.node_stats + return cpu_stats, gpu_stats + def testDeviceWrapperDynamicExecutionNodesAreAllProperlyLocated(self): if not test.is_gpu_available(): # Can't perform this test w/o a GPU @@ -462,10 +521,7 @@ class RNNCellTest(test.TestCase): sess.run([variables_lib.global_variables_initializer()]) _ = sess.run(outputs, options=opts, run_metadata=run_metadata) - step_stats = run_metadata.step_stats - ix = 0 if gpu_dev in step_stats.dev_stats[0].device else 1 - gpu_stats = step_stats.dev_stats[ix].node_stats - cpu_stats = step_stats.dev_stats[1 - ix].node_stats + cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) self.assertFalse([s for s in cpu_stats if "gru_cell" in s.node_name]) self.assertTrue([s for s in gpu_stats if "gru_cell" in s.node_name]) @@ -829,7 +885,8 @@ def basic_rnn_cell(inputs, state, num_units, scope=None): else: with variable_scope.variable_scope(scope, "basic_rnn_cell", [inputs, state]): - output = math_ops.tanh(linear([inputs, state], num_units, True)) + output = math_ops.tanh( + Linear([inputs, state], num_units, True)([inputs, state])) return output, output diff --git a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py index 40a3fb2fb0b174681252265d593de2935ee2efa2..9cea2ec79a982e4fb362ec564eb72b3894917842 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/core_rnn_test.py @@ -25,10 +25,12 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.contrib import rnn as rnn_lib from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops as ops_lib from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients_impl @@ -167,7 +169,7 @@ class RNNTest(test.TestCase): self.assertEqual(out.get_shape(), inp.get_shape()) self.assertEqual(out.dtype, inp.dtype) - with self.test_session(use_gpu=False) as sess: + with self.test_session(use_gpu=True) as sess: input_value = np.random.randn(batch_size, input_size) values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) @@ -202,7 +204,7 @@ class RNNTest(test.TestCase): self.assertEqual(out.get_shape().as_list(), inp.get_shape().as_list()) self.assertEqual(out.dtype, inp.dtype) - with self.test_session(use_gpu=False) as sess: + with self.test_session(use_gpu=True) as sess: input_value = np.random.randn(batch_size, input_size) values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value}) full_dropout_values = sess.run(dropped_outputs, @@ -213,7 +215,7 @@ class RNNTest(test.TestCase): for d_v in full_dropout_values[:-1]: # Add 1.0 to dropped_out (all zeros) self.assertAllClose(d_v, np.ones_like(input_value)) - def _testDynamicCalculation(self, use_gpu): + def testDynamicCalculation(self): cell = Plus1RNNCell() sequence_length = array_ops.placeholder(dtypes.int64) batch_size = 2 @@ -228,7 +230,7 @@ class RNNTest(test.TestCase): cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32) self.assertEqual(len(dynamic_outputs), len(inputs)) - with self.test_session(use_gpu=use_gpu) as sess: + with self.test_session(use_gpu=True) as sess: input_value = np.random.randn(batch_size, input_size) dynamic_values = sess.run( dynamic_outputs, @@ -259,10 +261,6 @@ class RNNTest(test.TestCase): np.vstack((1.0 * (1 + 1) * np.ones((input_size)), 1.0 * (2 + 1) * np.ones((input_size))))) - def testDynamicCalculation(self): - self._testDynamicCalculation(True) - self._testDynamicCalculation(False) - def _testScope(self, factory, prefix="prefix", use_outer_scope=True): with self.test_session(use_gpu=True, graph=ops_lib.Graph()): if use_outer_scope: @@ -307,12 +305,12 @@ class LSTMTest(test.TestCase): self._seed = 23489 np.random.seed(self._seed) - def _testNoProjNoSharding(self, use_gpu): + def testNoProjNoSharding(self): num_units = 3 input_size = 5 batch_size = 2 max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) cell = rnn_cell.LSTMCell( @@ -330,12 +328,12 @@ class LSTMTest(test.TestCase): input_value = np.random.randn(batch_size, input_size) sess.run(outputs, feed_dict={inputs[0]: input_value}) - def _testCellClipping(self, use_gpu): + def testCellClipping(self): num_units = 3 input_size = 5 batch_size = 2 max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) cell = rnn_cell.LSTMCell( @@ -361,12 +359,12 @@ class LSTMTest(test.TestCase): # if cell c is clipped to 0, tanh(c) = 0 => m==0 self.assertAllEqual(value, np.zeros((batch_size, num_units))) - def _testNoProjNoShardingSimpleStateSaver(self, use_gpu): + def testNoProjNoShardingSimpleStateSaver(self): num_units = 3 input_size = 5 batch_size = 2 max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) state_saver = TestStateSaver(batch_size, 2 * num_units) @@ -491,13 +489,13 @@ class LSTMTest(test.TestCase): self.assertAllEqual(last_states[i], named_saved_states[flat_state_names[i]]) - def _testProjNoSharding(self, use_gpu): + def testProjNoSharding(self): num_units = 3 input_size = 5 batch_size = 2 num_proj = 4 max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) inputs = max_length * [ @@ -582,7 +580,7 @@ class LSTMTest(test.TestCase): state_tuple_v = sess.run(state_tuple, feed_dict={inputs[0]: input_value}) self.assertAllEqual(state_notuple_v, np.hstack(state_tuple_v)) - def _testProjSharding(self, use_gpu): + def testProjSharding(self): num_units = 3 input_size = 5 batch_size = 2 @@ -590,7 +588,7 @@ class LSTMTest(test.TestCase): num_proj_shards = 3 num_unit_shards = 2 max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) @@ -616,7 +614,7 @@ class LSTMTest(test.TestCase): input_value = np.random.randn(batch_size, input_size) sess.run(outputs, feed_dict={inputs[0]: input_value}) - def _testDoubleInput(self, use_gpu): + def testDoubleInput(self): num_units = 3 input_size = 5 batch_size = 2 @@ -624,7 +622,7 @@ class LSTMTest(test.TestCase): num_proj_shards = 3 num_unit_shards = 2 max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer(-1, 1, seed=self._seed) inputs = max_length * [ array_ops.placeholder( @@ -653,7 +651,7 @@ class LSTMTest(test.TestCase): values = sess.run(outputs, feed_dict={inputs[0]: input_value}) self.assertEqual(values[0].dtype, input_value.dtype) - def _testShardNoShardEquivalentOutput(self, use_gpu): + def testShardNoShardEquivalentOutput(self): num_units = 3 input_size = 5 batch_size = 2 @@ -661,7 +659,7 @@ class LSTMTest(test.TestCase): num_proj_shards = 3 num_unit_shards = 2 max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: inputs = max_length * [ array_ops.placeholder( dtypes.float32, shape=(None, input_size)) @@ -708,7 +706,7 @@ class LSTMTest(test.TestCase): for (s_noshard, s_shard) in zip(state_values_noshard, state_values_shard): self.assertAllClose(s_noshard, s_shard, atol=1e-3) - def _testDoubleInputWithDropoutAndDynamicCalculation(self, use_gpu): + def testDoubleInputWithDropoutAndDynamicCalculation(self): """Smoke test for using LSTM with doubles, dropout, dynamic calculation.""" num_units = 3 @@ -718,7 +716,7 @@ class LSTMTest(test.TestCase): num_proj_shards = 3 num_unit_shards = 2 max_length = 8 - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: sequence_length = array_ops.placeholder(dtypes.int64) initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) @@ -843,44 +841,13 @@ class LSTMTest(test.TestCase): for out0, out1 in zip(outputs0_values, outputs1_values): self.assertAllEqual(out0, out1) - def testNoProjNoShardingSimpleStateSaver(self): - self._testNoProjNoShardingSimpleStateSaver(use_gpu=False) - self._testNoProjNoShardingSimpleStateSaver(use_gpu=True) - - def testNoProjNoSharding(self): - self._testNoProjNoSharding(use_gpu=False) - self._testNoProjNoSharding(use_gpu=True) - - def testCellClipping(self): - self._testCellClipping(use_gpu=False) - self._testCellClipping(use_gpu=True) - - def testProjNoSharding(self): - self._testProjNoSharding(use_gpu=False) - self._testProjNoSharding(use_gpu=True) - - def testProjSharding(self): - self._testProjSharding(use_gpu=False) - self._testProjSharding(use_gpu=True) - - def testShardNoShardEquivalentOutput(self): - self._testShardNoShardEquivalentOutput(use_gpu=False) - self._testShardNoShardEquivalentOutput(use_gpu=True) - - def testDoubleInput(self): - self._testDoubleInput(use_gpu=False) - self._testDoubleInput(use_gpu=True) - - def testDoubleInputWithDropoutAndDynamicCalculation(self): - self._testDoubleInputWithDropoutAndDynamicCalculation(use_gpu=False) - self._testDoubleInputWithDropoutAndDynamicCalculation(use_gpu=True) - def testDynamicRNNAllowsUnknownTimeDimension(self): inputs = array_ops.placeholder(dtypes.float32, shape=[1, None, 20]) cell = rnn_cell.GRUCell(30) # Smoke test, this should not raise an error rnn.dynamic_rnn(cell, inputs, dtype=dtypes.float32) + @test_util.run_in_graph_and_eager_modes() def testDynamicRNNWithTupleStates(self): num_units = 3 input_size = 5 @@ -888,13 +855,20 @@ class LSTMTest(test.TestCase): num_proj = 4 max_length = 8 sequence_length = [4, 6] + in_graph_mode = context.in_graph_mode() with self.test_session(graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None, input_size)) - ] + if in_graph_mode: + inputs = max_length * [ + array_ops.placeholder( + dtypes.float32, shape=(None, input_size)) + ] + else: + inputs = max_length * [ + constant_op.constant( + np.random.randn(batch_size, input_size).astype(np.float32)) + ] inputs_c = array_ops.stack(inputs) cell = rnn_cell.LSTMCell( num_units, @@ -924,21 +898,34 @@ class LSTMTest(test.TestCase): self.assertEqual(state_dynamic[0], state_dynamic.c) self.assertEqual(state_dynamic[1], state_dynamic.h) - variables_lib.global_variables_initializer().run() - - input_value = np.random.randn(batch_size, input_size) - outputs_static_v = sess.run(outputs_static, - feed_dict={inputs[0]: input_value}) - outputs_dynamic_v = sess.run(outputs_dynamic, - feed_dict={inputs[0]: input_value}) - self.assertAllEqual(outputs_static_v, outputs_dynamic_v) - - state_static_v = sess.run(state_static, - feed_dict={inputs[0]: input_value}) - state_dynamic_v = sess.run(state_dynamic, - feed_dict={inputs[0]: input_value}) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) + if in_graph_mode: + variables_lib.global_variables_initializer().run() + input_value = np.random.randn(batch_size, input_size) + outputs_static = sess.run( + outputs_static, feed_dict={ + inputs[0]: input_value + }) + outputs_dynamic = sess.run( + outputs_dynamic, feed_dict={ + inputs[0]: input_value + }) + state_static = sess.run( + state_static, feed_dict={ + inputs[0]: input_value + }) + state_dynamic = sess.run( + state_dynamic, feed_dict={ + inputs[0]: input_value + }) + + if in_graph_mode: + self.assertAllEqual(outputs_static, outputs_dynamic) + else: + self.assertAllEqual( + array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy()) + self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic)) + @test_util.run_in_graph_and_eager_modes() def testDynamicRNNWithNestedTupleStates(self): num_units = 3 input_size = 5 @@ -946,13 +933,20 @@ class LSTMTest(test.TestCase): num_proj = 4 max_length = 8 sequence_length = [4, 6] + in_graph_mode = context.in_graph_mode() with self.test_session(graph=ops_lib.Graph()) as sess: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - inputs = max_length * [ - array_ops.placeholder( - dtypes.float32, shape=(None, input_size)) - ] + if in_graph_mode: + inputs = max_length * [ + array_ops.placeholder( + dtypes.float32, shape=(None, input_size)) + ] + else: + inputs = max_length * [ + constant_op.constant( + np.random.randn(batch_size, input_size).astype(np.float32)) + ] inputs_c = array_ops.stack(inputs) def _cell(i): @@ -993,43 +987,58 @@ class LSTMTest(test.TestCase): sequence_length=sequence_length, scope=scope) - variables_lib.global_variables_initializer().run() - - input_value = np.random.randn(batch_size, input_size) - outputs_static_v = sess.run(outputs_static, - feed_dict={inputs[0]: input_value}) - outputs_dynamic_v = sess.run(outputs_dynamic, - feed_dict={inputs[0]: input_value}) - self.assertAllEqual(outputs_static_v, outputs_dynamic_v) - - state_static_v = sess.run(nest.flatten(state_static), - feed_dict={inputs[0]: input_value}) - state_dynamic_v = sess.run(nest.flatten(state_dynamic), - feed_dict={inputs[0]: input_value}) - self.assertAllEqual(np.hstack(state_static_v), np.hstack(state_dynamic_v)) + if in_graph_mode: + input_value = np.random.randn(batch_size, input_size) + variables_lib.global_variables_initializer().run() + outputs_static = sess.run( + outputs_static, feed_dict={ + inputs[0]: input_value + }) + outputs_dynamic = sess.run( + outputs_dynamic, feed_dict={ + inputs[0]: input_value + }) + state_static = sess.run( + nest.flatten(state_static), feed_dict={ + inputs[0]: input_value + }) + state_dynamic = sess.run( + nest.flatten(state_dynamic), feed_dict={ + inputs[0]: input_value + }) + + if in_graph_mode: + self.assertAllEqual(outputs_static, outputs_dynamic) + else: + self.assertAllEqual( + array_ops.stack(outputs_static).numpy(), outputs_dynamic.numpy()) + state_static = [s.numpy() for s in nest.flatten(state_static)] + state_dynamic = [s.numpy() for s in nest.flatten(state_dynamic)] + self.assertAllEqual(np.hstack(state_static), np.hstack(state_dynamic)) - def _testDynamicEquivalentToStaticRNN(self, use_gpu, use_sequence_length): + def _testDynamicEquivalentToStaticRNN(self, use_sequence_length): time_steps = 8 num_units = 3 num_proj = 4 input_size = 5 batch_size = 2 - input_values = np.random.randn(time_steps, batch_size, input_size) + input_values = np.random.randn(time_steps, batch_size, input_size).astype( + np.float32) if use_sequence_length: sequence_length = np.random.randint(0, time_steps, size=batch_size) else: sequence_length = None - ########### Step 1: Run static graph and generate readouts - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: - concat_inputs = array_ops.placeholder( - dtypes.float32, shape=(time_steps, batch_size, input_size)) - inputs = array_ops.unstack(concat_inputs) + in_graph_mode = context.in_graph_mode() + + # TODO(b/68017812): Eager ignores operation seeds, so we need to create a + # single cell and reuse it across the static and dynamic RNNs. Remove this + # special case once is fixed. + if not in_graph_mode: initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - cell = rnn_cell.LSTMCell( num_units, use_peepholes=True, @@ -1037,63 +1046,85 @@ class LSTMTest(test.TestCase): num_proj=num_proj, state_is_tuple=False) + ########### Step 1: Run static graph and generate readouts + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: + if in_graph_mode: + concat_inputs = array_ops.placeholder( + dtypes.float32, shape=(time_steps, batch_size, input_size)) + else: + concat_inputs = constant_op.constant(input_values) + inputs = array_ops.unstack(concat_inputs) + initializer = init_ops.random_uniform_initializer( + -0.01, 0.01, seed=self._seed) + + # TODO(akshayka): Remove special case once b/68017812 is fixed. + if in_graph_mode: + cell = rnn_cell.LSTMCell( + num_units, + use_peepholes=True, + initializer=initializer, + num_proj=num_proj, + state_is_tuple=False) + with variable_scope.variable_scope("dynamic_scope"): outputs_static, state_static = rnn.static_rnn( cell, inputs, sequence_length=sequence_length, dtype=dtypes.float32) - feeds = {concat_inputs: input_values} - - # Initialize - variables_lib.global_variables_initializer().run(feed_dict=feeds) - - # Generate gradients of sum of outputs w.r.t. inputs - static_gradients = gradients_impl.gradients( - outputs_static + [state_static], [concat_inputs]) - - # Generate gradients of individual outputs w.r.t. inputs - static_individual_gradients = nest.flatten([ - gradients_impl.gradients(y, [concat_inputs]) - for y in [outputs_static[0], outputs_static[-1], state_static] - ]) - - # Generate gradients of individual variables w.r.t. inputs - trainable_variables = ops_lib.get_collection( - ops_lib.GraphKeys.TRAINABLE_VARIABLES) - assert len(trainable_variables) > 1, ("Count of trainable variables: %d" % - len(trainable_variables)) - # pylint: disable=bad-builtin - static_individual_variable_gradients = nest.flatten([ - gradients_impl.gradients(y, trainable_variables) - for y in [outputs_static[0], outputs_static[-1], state_static] - ]) - - # Test forward pass - values_static = sess.run(outputs_static, feed_dict=feeds) - (state_value_static,) = sess.run((state_static,), feed_dict=feeds) - - # Test gradients to inputs and variables w.r.t. outputs & final state - static_grad_values = sess.run(static_gradients, feed_dict=feeds) - - static_individual_grad_values = sess.run(static_individual_gradients, - feed_dict=feeds) - - static_individual_var_grad_values = sess.run( - static_individual_variable_gradients, feed_dict=feeds) + if in_graph_mode: + # Generate gradients and run sessions to obtain outputs + feeds = {concat_inputs: input_values} + # Initialize + variables_lib.global_variables_initializer().run(feed_dict=feeds) + # Generate gradients of sum of outputs w.r.t. inputs + static_gradients = gradients_impl.gradients( + outputs_static + [state_static], [concat_inputs]) + # Generate gradients of individual outputs w.r.t. inputs + static_individual_gradients = nest.flatten([ + gradients_impl.gradients(y, [concat_inputs]) + for y in [outputs_static[0], outputs_static[-1], state_static] + ]) + # Generate gradients of individual variables w.r.t. inputs + trainable_variables = ops_lib.get_collection( + ops_lib.GraphKeys.TRAINABLE_VARIABLES) + assert len(trainable_variables) > 1, ( + "Count of trainable variables: %d" % len(trainable_variables)) + # pylint: disable=bad-builtin + static_individual_variable_gradients = nest.flatten([ + gradients_impl.gradients(y, trainable_variables) + for y in [outputs_static[0], outputs_static[-1], state_static] + ]) + # Test forward pass + values_static = sess.run(outputs_static, feed_dict=feeds) + (state_value_static,) = sess.run((state_static,), feed_dict=feeds) + + # Test gradients to inputs and variables w.r.t. outputs & final state + static_grad_values = sess.run(static_gradients, feed_dict=feeds) + + static_individual_grad_values = sess.run(static_individual_gradients, + feed_dict=feeds) + + static_individual_var_grad_values = sess.run( + static_individual_variable_gradients, feed_dict=feeds) ########## Step 2: Run dynamic graph and generate readouts - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: - concat_inputs = array_ops.placeholder( - dtypes.float32, shape=(time_steps, batch_size, input_size)) - inputs = array_ops.unstack(concat_inputs) + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: + if in_graph_mode: + concat_inputs = array_ops.placeholder( + dtypes.float32, shape=(time_steps, batch_size, input_size)) + else: + concat_inputs = constant_op.constant(input_values) initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=self._seed) - cell = rnn_cell.LSTMCell( - num_units, - use_peepholes=True, - initializer=initializer, - num_proj=num_proj, - state_is_tuple=False) + # TODO(akshayka): Remove this special case once b/68017812 is + # fixed. + if in_graph_mode: + cell = rnn_cell.LSTMCell( + num_units, + use_peepholes=True, + initializer=initializer, + num_proj=num_proj, + state_is_tuple=False) with variable_scope.variable_scope("dynamic_scope"): outputs_dynamic, state_dynamic = rnn.dynamic_rnn( @@ -1104,81 +1135,86 @@ class LSTMTest(test.TestCase): dtype=dtypes.float32) split_outputs_dynamic = array_ops.unstack(outputs_dynamic, time_steps) - feeds = {concat_inputs: input_values} + if in_graph_mode: + feeds = {concat_inputs: input_values} - # Initialize - variables_lib.global_variables_initializer().run(feed_dict=feeds) + # Initialize + variables_lib.global_variables_initializer().run(feed_dict=feeds) + + # Generate gradients of sum of outputs w.r.t. inputs + dynamic_gradients = gradients_impl.gradients( + split_outputs_dynamic + [state_dynamic], [concat_inputs]) - # Generate gradients of sum of outputs w.r.t. inputs - dynamic_gradients = gradients_impl.gradients( - split_outputs_dynamic + [state_dynamic], [concat_inputs]) - - # Generate gradients of several individual outputs w.r.t. inputs - dynamic_individual_gradients = nest.flatten([ - gradients_impl.gradients(y, [concat_inputs]) - for y in - [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] - ]) - - # Generate gradients of individual variables w.r.t. inputs - trainable_variables = ops_lib.get_collection( - ops_lib.GraphKeys.TRAINABLE_VARIABLES) - assert len(trainable_variables) > 1, ("Count of trainable variables: %d" % - len(trainable_variables)) - dynamic_individual_variable_gradients = nest.flatten([ - gradients_impl.gradients(y, trainable_variables) - for y in - [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] - ]) - - # Test forward pass - values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds) - (state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds) - - # Test gradients to inputs and variables w.r.t. outputs & final state - dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds) - - dynamic_individual_grad_values = sess.run(dynamic_individual_gradients, - feed_dict=feeds) - - dynamic_individual_var_grad_values = sess.run( - dynamic_individual_variable_gradients, feed_dict=feeds) + # Generate gradients of several individual outputs w.r.t. inputs + dynamic_individual_gradients = nest.flatten([ + gradients_impl.gradients(y, [concat_inputs]) + for y in + [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] + ]) + + # Generate gradients of individual variables w.r.t. inputs + trainable_variables = ops_lib.get_collection( + ops_lib.GraphKeys.TRAINABLE_VARIABLES) + assert len(trainable_variables) > 1, ( + "Count of trainable variables: %d" % len(trainable_variables)) + dynamic_individual_variable_gradients = nest.flatten([ + gradients_impl.gradients(y, trainable_variables) + for y in + [split_outputs_dynamic[0], split_outputs_dynamic[-1], state_dynamic] + ]) + + # Test forward pass + values_dynamic = sess.run(split_outputs_dynamic, feed_dict=feeds) + (state_value_dynamic,) = sess.run((state_dynamic,), feed_dict=feeds) + + # Test gradients to inputs and variables w.r.t. outputs & final state + dynamic_grad_values = sess.run(dynamic_gradients, feed_dict=feeds) + + dynamic_individual_grad_values = sess.run(dynamic_individual_gradients, + feed_dict=feeds) + + dynamic_individual_var_grad_values = sess.run( + dynamic_individual_variable_gradients, feed_dict=feeds) ######### Step 3: Comparisons + if not in_graph_mode: + values_static = outputs_static + values_dynamic = split_outputs_dynamic + state_value_static = state_static + state_value_dynamic = state_dynamic + self.assertEqual(len(values_static), len(values_dynamic)) for (value_static, value_dynamic) in zip(values_static, values_dynamic): self.assertAllEqual(value_static, value_dynamic) self.assertAllEqual(state_value_static, state_value_dynamic) - self.assertAllEqual(static_grad_values, dynamic_grad_values) + if in_graph_mode: - self.assertEqual( - len(static_individual_grad_values), len(dynamic_individual_grad_values)) - self.assertEqual( - len(static_individual_var_grad_values), - len(dynamic_individual_var_grad_values)) + self.assertAllEqual(static_grad_values, dynamic_grad_values) - for i, (a, b) in enumerate( - zip(static_individual_grad_values, dynamic_individual_grad_values)): - tf_logging.info("Comparing individual gradients iteration %d" % i) - self.assertAllEqual(a, b) + self.assertEqual( + len(static_individual_grad_values), + len(dynamic_individual_grad_values)) + self.assertEqual( + len(static_individual_var_grad_values), + len(dynamic_individual_var_grad_values)) - for i, (a, b) in enumerate( - zip(static_individual_var_grad_values, - dynamic_individual_var_grad_values)): - tf_logging.info("Comparing individual variable gradients iteration %d" % - i) - self.assertAllEqual(a, b) + for i, (a, b) in enumerate( + zip(static_individual_grad_values, dynamic_individual_grad_values)): + tf_logging.info("Comparing individual gradients iteration %d" % i) + self.assertAllEqual(a, b) + for i, (a, b) in enumerate( + zip(static_individual_var_grad_values, + dynamic_individual_var_grad_values)): + tf_logging.info("Comparing individual variable gradients iteration %d" % + i) + self.assertAllEqual(a, b) + + @test_util.run_in_graph_and_eager_modes() def testDynamicEquivalentToStaticRNN(self): - self._testDynamicEquivalentToStaticRNN( - use_gpu=False, use_sequence_length=False) - self._testDynamicEquivalentToStaticRNN( - use_gpu=True, use_sequence_length=False) - self._testDynamicEquivalentToStaticRNN( - use_gpu=False, use_sequence_length=True) - self._testDynamicEquivalentToStaticRNN( - use_gpu=True, use_sequence_length=True) + self._testDynamicEquivalentToStaticRNN(use_sequence_length=False) + self._testDynamicEquivalentToStaticRNN(use_sequence_length=False) class BidirectionalRNNTest(test.TestCase): @@ -1188,7 +1224,6 @@ class BidirectionalRNNTest(test.TestCase): np.random.seed(self._seed) def _createBidirectionalRNN(self, - use_gpu, use_shape, use_sequence_length, scope=None): @@ -1227,10 +1262,10 @@ class BidirectionalRNNTest(test.TestCase): return input_value, inputs, outputs, state_fw, state_bw, sequence_length - def _testBidirectionalRNN(self, use_gpu, use_shape): - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + def _testBidirectionalRNN(self, use_shape): + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: input_value, inputs, outputs, state_fw, state_bw, sequence_length = ( - self._createBidirectionalRNN(use_gpu, use_shape, True)) + self._createBidirectionalRNN(use_shape, True)) variables_lib.global_variables_initializer().run() # Run with pre-specified sequence length of 2, 3 out, s_fw, s_bw = sess.run( @@ -1272,10 +1307,10 @@ class BidirectionalRNNTest(test.TestCase): # exactly the same self.assertAllClose(s_fw, s_bw) - def _testBidirectionalRNNWithoutSequenceLength(self, use_gpu, use_shape): - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + def _testBidirectionalRNNWithoutSequenceLength(self, use_shape): + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: input_value, inputs, outputs, state_fw, state_bw, _ = ( - self._createBidirectionalRNN(use_gpu, use_shape, False)) + self._createBidirectionalRNN(use_shape, False)) variables_lib.global_variables_initializer().run() out, s_fw, s_bw = sess.run([outputs, state_fw, state_bw], feed_dict={inputs[0]: input_value}) @@ -1302,23 +1337,14 @@ class BidirectionalRNNTest(test.TestCase): self.assertAllClose(s_fw, s_bw) def testBidirectionalRNN(self): - self._testBidirectionalRNN(use_gpu=False, use_shape=False) - self._testBidirectionalRNN(use_gpu=True, use_shape=False) - self._testBidirectionalRNN(use_gpu=False, use_shape=True) - self._testBidirectionalRNN(use_gpu=True, use_shape=True) + self._testBidirectionalRNN(use_shape=False) + self._testBidirectionalRNN(use_shape=True) def testBidirectionalRNNWithoutSequenceLength(self): - self._testBidirectionalRNNWithoutSequenceLength( - use_gpu=False, use_shape=False) - self._testBidirectionalRNNWithoutSequenceLength( - use_gpu=True, use_shape=False) - self._testBidirectionalRNNWithoutSequenceLength( - use_gpu=False, use_shape=True) - self._testBidirectionalRNNWithoutSequenceLength( - use_gpu=True, use_shape=True) + self._testBidirectionalRNNWithoutSequenceLength(use_shape=False) + self._testBidirectionalRNNWithoutSequenceLength(use_shape=True) def _createBidirectionalDynamicRNN(self, - use_gpu, use_shape, use_state_tuple, use_time_major, @@ -1366,11 +1392,11 @@ class BidirectionalRNNTest(test.TestCase): return input_value, inputs, outputs, state_fw, state_bw, sequence_length - def _testBidirectionalDynamicRNN(self, use_gpu, use_shape, use_state_tuple, + def _testBidirectionalDynamicRNN(self, use_shape, use_state_tuple, use_time_major, use_sequence_length): - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: input_value, inputs, outputs, state_fw, state_bw, sequence_length = ( - self._createBidirectionalDynamicRNN(use_gpu, use_shape, + self._createBidirectionalDynamicRNN(use_shape, use_state_tuple, use_time_major, use_sequence_length)) variables_lib.global_variables_initializer().run() @@ -1435,14 +1461,13 @@ class BidirectionalRNNTest(test.TestCase): def testBidirectionalDynamicRNN(self): # Generate 2^5 option values # from [True, True, True, True, True] to [False, False, False, False, False] - options = itertools.product([True, False], repeat=5) + options = itertools.product([True, False], repeat=4) for option in options: self._testBidirectionalDynamicRNN( - use_gpu=option[0], - use_shape=option[1], - use_state_tuple=option[2], - use_time_major=option[3], - use_sequence_length=option[4]) + use_shape=option[0], + use_state_tuple=option[1], + use_time_major=option[2], + use_sequence_length=option[3]) def _testScope(self, factory, prefix="prefix", use_outer_scope=True): # REMARKS: factory(scope) is a function accepting a scope @@ -1471,7 +1496,7 @@ class BidirectionalRNNTest(test.TestCase): def factory(scope): return self._createBidirectionalRNN( - use_gpu=True, use_shape=True, use_sequence_length=True, scope=scope) + use_shape=True, use_sequence_length=True, scope=scope) self._testScope(factory, use_outer_scope=True) self._testScope(factory, use_outer_scope=False) @@ -1483,7 +1508,6 @@ class BidirectionalRNNTest(test.TestCase): def factory(scope): return self._createBidirectionalDynamicRNN( - use_gpu=True, use_shape=True, use_state_tuple=True, use_sequence_length=True, @@ -1761,7 +1785,7 @@ class GRUTest(test.TestCase): self._seed = 23489 np.random.seed(self._seed) - def _testDynamic(self, use_gpu): + def testDynamic(self): time_steps = 8 num_units = 3 input_size = 5 @@ -1771,7 +1795,7 @@ class GRUTest(test.TestCase): sequence_length = np.random.randint(0, time_steps, size=batch_size) - with self.test_session(use_gpu=use_gpu, graph=ops_lib.Graph()) as sess: + with self.test_session(use_gpu=True, graph=ops_lib.Graph()) as sess: concat_inputs = array_ops.placeholder( dtypes.float32, shape=(time_steps, batch_size, input_size)) @@ -1792,10 +1816,6 @@ class GRUTest(test.TestCase): sess.run([outputs_dynamic, state_dynamic], feed_dict=feeds) - def testDynamic(self): - self._testDynamic(use_gpu=False) - self._testDynamic(use_gpu=True) - def _testScope(self, factory, prefix="prefix", use_outer_scope=True): with self.test_session(use_gpu=True, graph=ops_lib.Graph()): if use_outer_scope: @@ -2203,6 +2223,17 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): return run_metadata + def _retrieve_cpu_gpu_stats(self, run_metadata): + cpu_stats = None + gpu_stats = None + step_stats = run_metadata.step_stats + for ds in step_stats.dev_stats: + if "cpu:0" in ds.device[-5:].lower(): + cpu_stats = ds.node_stats + if "gpu:0" == ds.device[-5:].lower(): + gpu_stats = ds.node_stats + return cpu_stats, gpu_stats + def testRNNOnCPUCellOnGPU(self): if not test.is_gpu_available(): return # Test requires access to a GPU @@ -2210,10 +2241,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): gpu_dev = test.gpu_device_name() run_metadata = self._execute_rnn_on( rnn_device="/cpu:0", cell_device=gpu_dev) - step_stats = run_metadata.step_stats - ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1 - gpu_stats = step_stats.dev_stats[ix].node_stats - cpu_stats = step_stats.dev_stats[1 - ix].node_stats + cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) def _assert_in(op_str, in_stats, out_stats): self.assertTrue(any(op_str in s.node_name for s in in_stats)) @@ -2236,10 +2264,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): run_metadata = self._execute_rnn_on( rnn_device="/cpu:0", cell_device="/cpu:0", input_device=gpu_dev) - step_stats = run_metadata.step_stats - ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1 - gpu_stats = step_stats.dev_stats[ix].node_stats - cpu_stats = step_stats.dev_stats[1 - ix].node_stats + cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) def _assert_in(op_str, in_stats, out_stats): self.assertTrue(any(op_str in s.node_name for s in in_stats)) @@ -2255,10 +2280,7 @@ class TensorArrayOnCorrectDeviceTest(test.TestCase): gpu_dev = test.gpu_device_name() run_metadata = self._execute_rnn_on( input_device=gpu_dev) - step_stats = run_metadata.step_stats - ix = 0 if (gpu_dev in step_stats.dev_stats[0].device) else 1 - gpu_stats = step_stats.dev_stats[ix].node_stats - cpu_stats = step_stats.dev_stats[1 - ix].node_stats + cpu_stats, gpu_stats = self._retrieve_cpu_gpu_stats(run_metadata) def _assert_in(op_str, in_stats, out_stats): self.assertTrue(any(op_str in s.node_name for s in in_stats)) diff --git a/tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py index 4239e32ab93043c5054e5382e67e79047b9644bb..b865466cc75aa67fcd192f7726f65141409b896a 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/gru_ops_test.py @@ -18,10 +18,9 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import time - import numpy as np +from tensorflow.contrib.rnn.python.kernel_tests import benchmarking from tensorflow.contrib.rnn.python.ops import gru_ops from tensorflow.python.client import session from tensorflow.python.framework import dtypes @@ -333,20 +332,6 @@ class GRUBlockCellTest(test.TestCase): #### Benchmarking GRUBlockCell vs GRUCell. -def time_taken_by_op(op, sess, num_runs=50): - """Time taken by the Op.""" - for _ in range(2): - sess.run([op]) - - start_time = time.time() - for _ in range(num_runs): - sess.run([op]) - - end_time = time.time() - time_taken = end_time - start_time - return time_taken - - def training_gru_block_vs_gru_cell(batch_size, cell_size, input_size, @@ -357,7 +342,7 @@ def training_gru_block_vs_gru_cell(batch_size, ops.reset_default_graph() with session.Session(graph=ops.Graph()) as sess: # Specify the device which is been used. - with ops.device("/cpu:0" if not use_gpu else "/device:GPU:0"): + with benchmarking.device(use_gpu): # Random initializers. seed = 1994 @@ -387,7 +372,8 @@ def training_gru_block_vs_gru_cell(batch_size, learning_rate).minimize(cost) # time for a training step. - basic_time_training = time_taken_by_op(optimizer, sess, iters) + basic_time_training = benchmarking.seconds_per_run( + optimizer, sess, iters) # Output from the basic GRU cell implementation. with vs.variable_scope("block", initializer=initializer): @@ -406,7 +392,8 @@ def training_gru_block_vs_gru_cell(batch_size, learning_rate).minimize(cost) # time for a training step. - block_time_training = time_taken_by_op(optimizer, sess, iters) + block_time_training = benchmarking.seconds_per_run( + optimizer, sess, iters) performance_training = ( basic_time_training - block_time_training) * 100 / basic_time_training @@ -429,7 +416,7 @@ def inference_gru_block_vs_gru_cell(batch_size, """Benchmark inference speed between GRUBlockCell vs GRUCell.""" ops.reset_default_graph() with session.Session(graph=ops.Graph()) as sess: - with ops.device("/cpu:0" if not use_gpu else "/device:GPU:0"): + with benchmarking.device(use_gpu): # Random initializers. seed = 1994 @@ -451,7 +438,8 @@ def inference_gru_block_vs_gru_cell(batch_size, time_major=True, dtype=dtypes.float32) sess.run([variables.global_variables_initializer()]) - basic_time_inference = time_taken_by_op(outputs_dynamic, sess, iters) + basic_time_inference = benchmarking.seconds_per_run( + outputs_dynamic, sess, iters) # Output from the block GRU cell implementation. with vs.variable_scope("block", initializer=initializer): @@ -463,7 +451,8 @@ def inference_gru_block_vs_gru_cell(batch_size, time_major=True, dtype=dtypes.float32) sess.run([variables.global_variables_initializer()]) - block_time_inference = time_taken_by_op(outputs_dynamic, sess, iters) + block_time_inference = benchmarking.seconds_per_run( + outputs_dynamic, sess, iters) performance_inference = (basic_time_inference - block_time_inference ) * 100 / basic_time_inference @@ -484,7 +473,7 @@ def single_bprop_step_gru_block_vs_gru_cell(batch_size, """Benchmark single bprop step speed between GRUBlockCell vs GRUCell.""" ops.reset_default_graph() with session.Session(graph=ops.Graph()) as sess: - with ops.device("/cpu:0" if not use_gpu else "/device:GPU:0"): + with benchmarking.device(use_gpu): initializer = init_ops.random_uniform_initializer(-1, 1, seed=1989) # Inputs x = vs.get_variable("x", [batch_size, input_size]) @@ -496,7 +485,8 @@ def single_bprop_step_gru_block_vs_gru_cell(batch_size, array_ops.identity(h)) sess.run([variables.global_variables_initializer()]) grad_output_wrt_input = gradients_impl.gradients([output], h) - basic_time_bprop = time_taken_by_op(grad_output_wrt_input, sess, iters) + basic_time_bprop = benchmarking.seconds_per_run(grad_output_wrt_input, + sess, iters) # Output from the block GRU cell implementation. with vs.variable_scope("block", initializer=initializer): @@ -504,7 +494,8 @@ def single_bprop_step_gru_block_vs_gru_cell(batch_size, array_ops.identity(h)) sess.run([variables.global_variables_initializer()]) grad_output_wrt_input = gradients_impl.gradients([output], h) - block_time_bprop = time_taken_by_op(grad_output_wrt_input, sess, iters) + block_time_bprop = benchmarking.seconds_per_run(grad_output_wrt_input, + sess, iters) performance_inference = ( basic_time_bprop - block_time_bprop) * 100 / basic_time_bprop @@ -526,23 +517,29 @@ class BenchmarkGRUBlock(test.Benchmark): print("batch_size, cell_size, input_size, time_steps, GPU, " "basic_time_training, block_time_training, performance_training[%]") iters = 10 - for use_gpu in [True, False]: - for batch_size in [1, 32, 128]: - for cell_size in [128, 512]: - for input_size in [128, 512]: - for time_steps in [50]: - basic_time, block_time = training_gru_block_vs_gru_cell( - batch_size, cell_size, input_size, time_steps, use_gpu, iters) - self.report_benchmark( - name="GRUCell_training_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % - (batch_size, cell_size, input_size, time_steps, use_gpu), - iters=iters, - wall_time=basic_time) - self.report_benchmark( - name="GRUBlockCell_training_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % - (batch_size, cell_size, input_size, time_steps, use_gpu), - iters=iters, - wall_time=block_time) + + for config in benchmarking.dict_product({ + "use_gpu": [True, False], + "batch_size": [1, 32, 128], + "cell_size": [128, 512], + "input_size": [128, 512], + "time_steps": [50] + }): + basic_time, block_time = training_gru_block_vs_gru_cell( + config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"], iters) + self.report_benchmark( + name="GRUCell_training_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"]), + iters=iters, + wall_time=basic_time) + self.report_benchmark( + name="GRUBlockCell_training_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"]), + iters=iters, + wall_time=block_time) def benchmarkInferenceBlockGRUVsGRUCell(self): print("--------------------------------------------------------------") @@ -551,23 +548,28 @@ class BenchmarkGRUBlock(test.Benchmark): "batch_size, cell_size, input_size, time_steps, GPU, " "basic_time_inference, block_time_inference, performance_inference[%]") iters = 10 - for use_gpu in [True, False]: - for batch_size in [1, 32, 128]: - for cell_size in [128, 512]: - for input_size in [128, 512]: - for time_steps in [50]: - basic_time, block_time = inference_gru_block_vs_gru_cell( - batch_size, cell_size, input_size, time_steps, use_gpu, iters) - self.report_benchmark( - name="GRUCell_inference_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % - (batch_size, cell_size, input_size, time_steps, use_gpu), - iters=iters, - wall_time=basic_time) - self.report_benchmark( - name="GRUBlockCell_inference_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" - % (batch_size, cell_size, input_size, time_steps, use_gpu), - iters=iters, - wall_time=block_time) + for config in benchmarking.dict_product({ + "use_gpu": [True, False], + "batch_size": [1, 32, 128], + "cell_size": [128, 512], + "input_size": [128, 512], + "time_steps": [50] + }): + basic_time, block_time = inference_gru_block_vs_gru_cell( + config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"], iters) + self.report_benchmark( + name="GRUCell_inference_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"]), + iters=iters, + wall_time=basic_time) + self.report_benchmark( + name="GRUBlockCell_inference_time_BS%i_CS%i_IS%i_TS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["time_steps"], config["use_gpu"]), + iters=iters, + wall_time=block_time) def benchmarkSingleBpropStepBlockGRUVsGRUCell(self): print("--------------------------------------------------------------") @@ -575,22 +577,27 @@ class BenchmarkGRUBlock(test.Benchmark): print("batch_size, cell_size, input_size, GPU, basic_time, " "block_time, performance_inference[%]") iters = 10 - for use_gpu in [True, False]: - for batch_size in [1, 32, 128]: - for cell_size in [128, 512]: - for input_size in [128, 512]: - basic_time, block_time = single_bprop_step_gru_block_vs_gru_cell( - batch_size, cell_size, input_size, use_gpu, iters) - self.report_benchmark( - name="GRUCell_Bprop_single_step_time_BS%i_CS%i_IS%i_gpu_%s" % - (batch_size, cell_size, input_size, use_gpu), - iters=iters, - wall_time=basic_time) - self.report_benchmark( - name="GRUBlockCell_Bprop_single_step_time_BS%i_CS%i_IS%i_gpu_%s" - % (batch_size, cell_size, input_size, use_gpu), - iters=iters, - wall_time=block_time) + for config in benchmarking.dict_product({ + "use_gpu": [True, False], + "batch_size": [1, 32, 128], + "cell_size": [128, 512], + "input_size": [128, 512] + }): + basic_time, block_time = single_bprop_step_gru_block_vs_gru_cell( + config["batch_size"], config["cell_size"], config["input_size"], + config["use_gpu"], iters) + self.report_benchmark( + name="GRUCell_Bprop_single_step_time_BS%i_CS%i_IS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["use_gpu"]), + iters=iters, + wall_time=basic_time) + self.report_benchmark( + name="GRUBlockCell_Bprop_single_step_time_BS%i_CS%i_IS%i_gpu_%s" % + (config["batch_size"], config["cell_size"], config["input_size"], + config["use_gpu"]), + iters=iters, + wall_time=block_time) print("--------------------------------------------------------------") diff --git a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py index 0ec37411f5f3d9b6687c077bf967b046068644ab..a288072ae5da0751f1999128029f38bea933490e 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/lstm_ops_test.py @@ -20,7 +20,9 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.rnn.python.kernel_tests import benchmarking from tensorflow.contrib.rnn.python.ops import lstm_ops +from tensorflow.python.client import session from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -36,6 +38,111 @@ from tensorflow.python.platform import test block_lstm = lstm_ops._block_lstm # pylint: disable=protected-access +def blocks_match(sess, use_peephole): + batch_size = 2 + input_size = 3 + cell_size = 4 + sequence_length = 4 + + inputs = [] + for _ in range(sequence_length): + inp = ops.convert_to_tensor( + np.random.randn(batch_size, input_size), dtype=dtypes.float32) + inputs.append(inp) + + initializer = init_ops.random_uniform_initializer(-0.01, 0.01, seed=19890212) + + with variable_scope.variable_scope("test", initializer=initializer): + # magic naming so that the cells pick up these variables and resuse them + if use_peephole: + wci = variable_scope.get_variable( + "rnn/lstm_cell/w_i_diag", shape=[cell_size], dtype=dtypes.float32) + wcf = variable_scope.get_variable( + "rnn/lstm_cell/w_f_diag", shape=[cell_size], dtype=dtypes.float32) + wco = variable_scope.get_variable( + "rnn/lstm_cell/w_o_diag", shape=[cell_size], dtype=dtypes.float32) + + w = variable_scope.get_variable( + "rnn/lstm_cell/kernel", + shape=[input_size + cell_size, cell_size * 4], + dtype=dtypes.float32) + b = variable_scope.get_variable( + "rnn/lstm_cell/bias", + shape=[cell_size * 4], + dtype=dtypes.float32, + initializer=init_ops.zeros_initializer()) + + if use_peephole: + wci_block = variable_scope.get_variable( + "rnn/lstm_cell/lstm_block_wrapper/w_i_diag", + initializer=wci.initialized_value()) + wcf_block = variable_scope.get_variable( + "rnn/lstm_cell/lstm_block_wrapper/w_f_diag", + initializer=wcf.initialized_value()) + wco_block = variable_scope.get_variable( + "rnn/lstm_cell/lstm_block_wrapper/w_o_diag", + initializer=wco.initialized_value()) + w_block = variable_scope.get_variable( + "rnn/lstm_cell/lstm_block_wrapper/kernel", + initializer=w.initialized_value()) + b_block = variable_scope.get_variable( + "rnn/lstm_cell/lstm_block_wrapper/bias", + initializer=b.initialized_value()) + + basic_cell = rnn_cell.LSTMCell( + cell_size, use_peepholes=use_peephole, state_is_tuple=True, reuse=True) + basic_outputs_op, basic_state_op = rnn.static_rnn( + basic_cell, inputs, dtype=dtypes.float32) + + if use_peephole: + _, _, _, _, _, _, block_outputs_op = block_lstm( + ops.convert_to_tensor(sequence_length, dtype=dtypes.int64), + inputs, + w, + b, + wci=wci, + wcf=wcf, + wco=wco, + cell_clip=0, + use_peephole=True) + else: + _, _, _, _, _, _, block_outputs_op = block_lstm( + ops.convert_to_tensor(sequence_length, dtype=dtypes.int64), + inputs, + w, + b, + cell_clip=0) + + with variable_scope.variable_scope("rnn/lstm_cell", reuse=True): + fused_cell = lstm_ops.LSTMBlockFusedCell( + cell_size, cell_clip=0, use_peephole=use_peephole) + fused_outputs_op, fused_state_op = fused_cell( + inputs, dtype=dtypes.float32) + + sess.run([variables.global_variables_initializer()]) + basic_outputs, basic_state = sess.run([basic_outputs_op, basic_state_op[0]]) + basic_grads = sess.run(gradients_impl.gradients(basic_outputs_op, inputs)) + xs = [w, b] + if use_peephole: + xs += [wci, wcf, wco] + basic_wgrads = sess.run(gradients_impl.gradients(basic_outputs_op, xs)) + + block_outputs = sess.run(block_outputs_op) + block_grads = sess.run(gradients_impl.gradients(block_outputs_op, inputs)) + block_wgrads = sess.run(gradients_impl.gradients(block_outputs_op, xs)) + + xs = [w_block, b_block] + if use_peephole: + xs += [wci_block, wcf_block, wco_block] + fused_outputs, fused_state = sess.run([fused_outputs_op, fused_state_op[0]]) + fused_grads = sess.run(gradients_impl.gradients(fused_outputs_op, inputs)) + fused_wgrads = sess.run(gradients_impl.gradients(fused_outputs_op, xs)) + + return (basic_state, fused_state, basic_outputs, block_outputs, + fused_outputs, basic_grads, block_grads, fused_grads, basic_wgrads, + block_wgrads, fused_wgrads) + + class LSTMBlockCellTest(test.TestCase): def testNoneDimsWithDynamicRNN(self): @@ -225,164 +332,39 @@ class LSTMBlockCellTest(test.TestCase): def testLSTMBasicToBlock(self): with self.test_session(use_gpu=True) as sess: - batch_size = 2 - input_size = 3 - cell_size = 4 - sequence_length = 5 - - inputs = [] - for _ in range(sequence_length): - inp = ops.convert_to_tensor( - np.random.randn(batch_size, input_size), dtype=dtypes.float32) - inputs.append(inp) - - initializer = init_ops.random_uniform_initializer( - -0.01, 0.01, seed=19890212) - with variable_scope.variable_scope("basic", initializer=initializer): - cell = rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True) - outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) - - sess.run([variables.global_variables_initializer()]) - basic_outputs, basic_state = sess.run([outputs, state[0]]) - basic_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - basic_wgrads = sess.run( - gradients_impl.gradients(outputs, variables.trainable_variables())) - - with variable_scope.variable_scope("block", initializer=initializer): - w = variable_scope.get_variable( - "w", - shape=[input_size + cell_size, cell_size * 4], - dtype=dtypes.float32) - b = variable_scope.get_variable( - "b", - shape=[cell_size * 4], - dtype=dtypes.float32, - initializer=init_ops.zeros_initializer()) - - _, _, _, _, _, _, outputs = block_lstm( - ops.convert_to_tensor( - sequence_length, dtype=dtypes.int64), - inputs, - w, - b, - cell_clip=0) - - sess.run([variables.global_variables_initializer()]) - block_outputs = sess.run(outputs) - block_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - block_wgrads = sess.run(gradients_impl.gradients(outputs, [w, b])) + (basic_state, fused_state, basic_outputs, block_outputs, fused_outputs, + basic_grads, block_grads, fused_grads, basic_wgrads, block_wgrads, + fused_wgrads) = blocks_match( + sess, use_peephole=False) self.assertAllClose(basic_outputs, block_outputs) self.assertAllClose(basic_grads, block_grads) for basic, block in zip(basic_wgrads, block_wgrads): - self.assertAllClose(basic, block, rtol=1e-2, atol=1e-2) - - with variable_scope.variable_scope("fused", initializer=initializer): - cell = lstm_ops.LSTMBlockFusedCell( - cell_size, cell_clip=0, use_peephole=False) - outputs, state = cell(inputs, dtype=dtypes.float32) - - sess.run([variables.global_variables_initializer()]) - fused_outputs, fused_state = sess.run([outputs, state[0]]) - fused_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - fused_vars = [ - v for v in variables.trainable_variables() - if v.name.startswith("fused/") - ] - fused_wgrads = sess.run(gradients_impl.gradients(outputs, fused_vars)) + self.assertAllClose(basic, block, rtol=1e-6, atol=1e-6) self.assertAllClose(basic_outputs, fused_outputs) self.assertAllClose(basic_state, fused_state) self.assertAllClose(basic_grads, fused_grads) - for basic, fused in zip(basic_wgrads, fused_wgrads): - self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2) + for basic, fused in zip(block_wgrads, fused_wgrads): + self.assertAllClose(basic, fused, rtol=1e-6, atol=1e-6) def testLSTMBasicToBlockPeeping(self): with self.test_session(use_gpu=True) as sess: - batch_size = 2 - input_size = 3 - cell_size = 4 - sequence_length = 5 - - inputs = [] - for _ in range(sequence_length): - inp = ops.convert_to_tensor( - np.random.randn(batch_size, input_size), dtype=dtypes.float32) - inputs.append(inp) - - initializer = init_ops.random_uniform_initializer( - -0.01, 0.01, seed=19890212) - with variable_scope.variable_scope("basic", initializer=initializer): - cell = rnn_cell.LSTMCell( - cell_size, use_peepholes=True, state_is_tuple=True) - outputs, state = rnn.static_rnn(cell, inputs, dtype=dtypes.float32) - - sess.run([variables.global_variables_initializer()]) - basic_outputs, basic_state = sess.run([outputs, state[0]]) - basic_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - basic_wgrads = sess.run( - gradients_impl.gradients(outputs, variables.trainable_variables())) - - with variable_scope.variable_scope("block", initializer=initializer): - w = variable_scope.get_variable( - "w", - shape=[input_size + cell_size, cell_size * 4], - dtype=dtypes.float32) - b = variable_scope.get_variable( - "b", - shape=[cell_size * 4], - dtype=dtypes.float32, - initializer=init_ops.zeros_initializer()) - - wci = variable_scope.get_variable( - "wci", shape=[cell_size], dtype=dtypes.float32) - wcf = variable_scope.get_variable( - "wcf", shape=[cell_size], dtype=dtypes.float32) - wco = variable_scope.get_variable( - "wco", shape=[cell_size], dtype=dtypes.float32) - - _, _, _, _, _, _, outputs = block_lstm( - ops.convert_to_tensor( - sequence_length, dtype=dtypes.int64), - inputs, - w, - b, - wci=wci, - wcf=wcf, - wco=wco, - cell_clip=0, - use_peephole=True) - - sess.run([variables.global_variables_initializer()]) - block_outputs = sess.run(outputs) - block_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - block_wgrads = sess.run( - gradients_impl.gradients(outputs, [w, b, wci, wcf, wco])) + (basic_state, fused_state, basic_outputs, block_outputs, fused_outputs, + basic_grads, block_grads, fused_grads, basic_wgrads, block_wgrads, + fused_wgrads) = blocks_match( + sess, use_peephole=True) self.assertAllClose(basic_outputs, block_outputs) self.assertAllClose(basic_grads, block_grads) for basic, block in zip(basic_wgrads, block_wgrads): - self.assertAllClose(basic, block, rtol=1e-2, atol=1e-2) - - with variable_scope.variable_scope("fused", initializer=initializer): - cell = lstm_ops.LSTMBlockFusedCell( - cell_size, cell_clip=0, use_peephole=True) - outputs, state = cell(inputs, dtype=dtypes.float32) - - sess.run([variables.global_variables_initializer()]) - fused_outputs, fused_state = sess.run([outputs, state[0]]) - fused_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - fused_vars = [ - v for v in variables.trainable_variables() - if v.name.startswith("fused/") - ] - fused_wgrads = sess.run(gradients_impl.gradients(outputs, fused_vars)) + self.assertAllClose(basic, block, rtol=1e-6, atol=1e-6) self.assertAllClose(basic_outputs, fused_outputs) self.assertAllClose(basic_state, fused_state) self.assertAllClose(basic_grads, fused_grads) - for basic, fused in zip(basic_wgrads, fused_wgrads): - self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2) + for basic, fused in zip(block_wgrads, fused_wgrads): + self.assertAllClose(basic, fused, rtol=1e-6, atol=1e-6) def testLSTMFusedSequenceLengths(self): """Verify proper support for sequence lengths in LSTMBlockFusedCell.""" @@ -401,45 +383,40 @@ class LSTMBlockCellTest(test.TestCase): initializer = init_ops.random_uniform_initializer( -0.01, 0.01, seed=19890213) - with variable_scope.variable_scope("basic", initializer=initializer): - cell = rnn_cell.BasicLSTMCell(cell_size, state_is_tuple=True) - outputs, state = rnn.static_rnn( - cell, inputs, dtype=dtypes.float32, sequence_length=seq_lengths) - sess.run([variables.global_variables_initializer()]) - basic_outputs, basic_state = sess.run([outputs, state[0]]) - basic_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - basic_wgrads = sess.run( - gradients_impl.gradients(outputs, variables.trainable_variables())) - with variable_scope.variable_scope("fused", initializer=initializer): + with variable_scope.variable_scope( + "lstm_block_wrapper", initializer=initializer): + # magic naming so that the cells pick up these variables and resuse them + variable_scope.get_variable( + "kernel", + shape=[input_size + cell_size, cell_size * 4], + dtype=dtypes.float32) + + variable_scope.get_variable( + "bias", + shape=[cell_size * 4], + dtype=dtypes.float32, + initializer=init_ops.zeros_initializer()) + + with variable_scope.variable_scope( + variable_scope.get_variable_scope(), reuse=True): cell = lstm_ops.LSTMBlockFusedCell( cell_size, cell_clip=0, use_peephole=False) - outputs, state = cell( - inputs, dtype=dtypes.float32, sequence_length=seq_lengths) - sess.run([variables.global_variables_initializer()]) - fused_outputs, fused_state = sess.run([outputs, state[0]]) - fused_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - fused_vars = [ - v for v in variables.trainable_variables() - if v.name.startswith("fused/") - ] - fused_wgrads = sess.run(gradients_impl.gradients(outputs, fused_vars)) + fused_outputs_op, fused_state_op = cell( + inputs, dtype=dtypes.float32, sequence_length=seq_lengths) - self.assertAllClose(basic_outputs, fused_outputs) - self.assertAllClose(basic_state, fused_state) - self.assertAllClose(basic_grads, fused_grads) - for basic, fused in zip(basic_wgrads, fused_wgrads): - self.assertAllClose(basic, fused, rtol=1e-2, atol=1e-2) + cell_vars = [ + v for v in variables.trainable_variables() + if v.name.endswith("kernel") or v.name.endswith("bias") + ] # Verify that state propagation works if we turn our sequence into # tiny (single-time) subsequences, i.e. unfuse the cell + unfused_outputs_op = [] + state = None with variable_scope.variable_scope( - "unfused", initializer=initializer) as vs: - cell = lstm_ops.LSTMBlockFusedCell( - cell_size, cell_clip=0, use_peephole=False) - outputs = [] - state = None + variable_scope.get_variable_scope(), reuse=True): for i, inp in enumerate(inputs): lengths = [int(i < l) for l in seq_lengths.eval()] output, state = cell( @@ -447,25 +424,136 @@ class LSTMBlockCellTest(test.TestCase): initial_state=state, dtype=dtypes.float32, sequence_length=lengths) - vs.reuse_variables() - outputs.append(output[0]) - outputs = array_ops.stack(outputs) - - sess.run([variables.global_variables_initializer()]) - unfused_outputs, unfused_state = sess.run([outputs, state[0]]) - unfused_grads = sess.run(gradients_impl.gradients(outputs, inputs)) - unfused_vars = [ - v for v in variables.trainable_variables() - if v.name.startswith("unfused/") - ] - unfused_wgrads = sess.run( - gradients_impl.gradients(outputs, unfused_vars)) - - self.assertAllClose(basic_outputs, unfused_outputs) - self.assertAllClose(basic_state, unfused_state) - self.assertAllClose(basic_grads, unfused_grads) - for basic, unfused in zip(basic_wgrads, unfused_wgrads): - self.assertAllClose(basic, unfused, rtol=1e-2, atol=1e-2) + unfused_outputs_op.append(output[0]) + unfused_outputs_op = array_ops.stack(unfused_outputs_op) + + sess.run([variables.global_variables_initializer()]) + unfused_outputs, unfused_state = sess.run([unfused_outputs_op, state[0]]) + unfused_grads = sess.run( + gradients_impl.gradients(unfused_outputs_op, inputs)) + unfused_wgrads = sess.run( + gradients_impl.gradients(unfused_outputs_op, cell_vars)) + + fused_outputs, fused_state = sess.run( + [fused_outputs_op, fused_state_op[0]]) + fused_grads = sess.run(gradients_impl.gradients(fused_outputs_op, inputs)) + fused_wgrads = sess.run( + gradients_impl.gradients(fused_outputs_op, cell_vars)) + + self.assertAllClose(fused_outputs, unfused_outputs) + self.assertAllClose(fused_state, unfused_state) + self.assertAllClose(fused_grads, unfused_grads) + for fused, unfused in zip(fused_wgrads, unfused_wgrads): + self.assertAllClose(fused, unfused, rtol=1e-6, atol=1e-6) + +#### Benchmarking. + + +class BenchmarkLSTMBlock(test.Benchmark): + + def benchmarkLSTMBlockCellFpropWithDynamicRNN(self): + print("BlockLSTMCell forward propagation via dynamic_rnn().") + print("--------------------------------------------------------------") + print("LSTMBlockCell Seconds per inference.") + print("batch_size,cell_size,input_size,time_steps,use_gpu,wall_time") + iters = 10 + for config in benchmarking.dict_product({ + "batch_size": [1, 8, 13, 32, 67, 128], + "cell_size": [128, 250, 512, 650, 1024, 1350], + "time_steps": [40], + "use_gpu": [True, False] + }): + with ops.Graph().as_default(): + with benchmarking.device(use_gpu=config["use_gpu"]): + inputs = variable_scope.get_variable( + "x", + [config["time_steps"], config["batch_size"], config["cell_size"]]) + cell = lstm_ops.LSTMBlockCell(config["cell_size"]) + outputs = rnn.dynamic_rnn( + cell, inputs, time_major=True, dtype=dtypes.float32) + init_op = variables.global_variables_initializer() + + with session.Session() as sess: + sess.run(init_op) + wall_time = benchmarking.seconds_per_run(outputs, sess, iters) + + # Print to stdout. If the TEST_REPORT_FILE_PREFIX environment variable + # is set, this will produce a copy-paste-able CSV file. + print(",".join( + map(str, [ + config["batch_size"], config["cell_size"], config["cell_size"], + config["time_steps"], config["use_gpu"], wall_time + ]))) + benchmark_name_template = "_".join([ + "LSTMBlockCell_fprop", "BS%(batch_size)i", "CS%(cell_size)i", + "IS%(cell_size)i", "TS%(time_steps)i", "gpu_%(use_gpu)s" + ]) + + self.report_benchmark( + name=benchmark_name_template % config, + iters=iters, + wall_time=wall_time, + extras=config) + + def benchmarkLSTMBlockCellBpropWithDynamicRNN(self): + print("BlockLSTMCell backward propagation via dynamic_rnn().") + print("--------------------------------------------------------------") + print("LSTMBlockCell Seconds per inference.") + print("batch_size,cell_size,input_size,time_steps,use_gpu,wall_time") + iters = 10 + for config in benchmarking.dict_product({ + "batch_size": [1, 8, 13, 32, 67, 128], + "cell_size": [128, 250, 512, 650, 1024, 1350], + "time_steps": [40], + "use_gpu": [True, False] + }): + with ops.Graph().as_default(): + with benchmarking.device(use_gpu=config["use_gpu"]): + time_steps = config["time_steps"] + batch_size = config["batch_size"] + cell_size = input_size = config["cell_size"] + inputs = variable_scope.get_variable( + "x", [time_steps, batch_size, cell_size], + trainable=False, + dtype=dtypes.float32) + with variable_scope.variable_scope( + "rnn", reuse=variable_scope.AUTO_REUSE): + w = variable_scope.get_variable( + "rnn/lstm_cell/kernel", + shape=[input_size + cell_size, cell_size * 4], + dtype=dtypes.float32) + b = variable_scope.get_variable( + "rnn/lstm_cell/bias", + shape=[cell_size * 4], + dtype=dtypes.float32, + initializer=init_ops.zeros_initializer()) + cell = lstm_ops.LSTMBlockCell(cell_size) + outputs = rnn.dynamic_rnn( + cell, inputs, time_major=True, dtype=dtypes.float32) + grads = gradients_impl.gradients(outputs, [inputs, w, b]) + init_op = variables.global_variables_initializer() + + with session.Session() as sess: + sess.run(init_op) + wall_time = benchmarking.seconds_per_run(grads, sess, iters) + + # Print to stdout. If the TEST_REPORT_FILE_PREFIX environment variable + # is set, this will produce a copy-paste-able CSV file. + print(",".join( + map(str, [ + batch_size, cell_size, cell_size, time_steps, config["use_gpu"], + wall_time + ]))) + benchmark_name_template = "_".join([ + "LSTMBlockCell_bprop", "BS%(batch_size)i", "CS%(cell_size)i", + "IS%(cell_size)i", "TS%(time_steps)i", "gpu_%(use_gpu)s" + ]) + + self.report_benchmark( + name=benchmark_name_template % config, + iters=iters, + wall_time=wall_time, + extras=config) if __name__ == "__main__": diff --git a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py index ebd4564f1204cd69527633e16e67cda3f3a8407e..b4a5f2d7ebaaa7fd916fb7129db7e2bdbee19706 100644 --- a/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py +++ b/tensorflow/contrib/rnn/python/kernel_tests/rnn_cell_test.py @@ -37,6 +37,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import rnn from tensorflow.python.ops import rnn_cell +from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -1275,6 +1276,49 @@ class LayerNormBasicLSTMCellTest(test.TestCase): self.assertAllClose(res[2].c, expected_c1, 1e-5) self.assertAllClose(res[2].h, expected_h1, 1e-5) + + def testBasicLSTMCellWithStateTupleLayerNorm(self): + """The results of LSTMCell and LayerNormBasicLSTMCell + should be same. """ + with self.test_session() as sess: + with variable_scope.variable_scope( + "root", initializer=init_ops.constant_initializer(0.5)): + x = array_ops.zeros([1, 2]) + c0 = array_ops.zeros([1, 2]) + h0 = array_ops.zeros([1, 2]) + state0 = rnn_cell_impl.LSTMStateTuple(c0, h0) + c1 = array_ops.zeros([1, 2]) + h1 = array_ops.zeros([1, 2]) + state1 = rnn_cell_impl.LSTMStateTuple(c1, h1) + cell = rnn_cell_impl.MultiRNNCell( + [contrib_rnn_cell.LayerNormLSTMCell( + 2, + layer_norm=True, + norm_gain=1.0, + norm_shift=0.0) for _ in range(2)]) + h, (s0, s1) = cell(x, (state0, state1)) + sess.run([variables.global_variables_initializer()]) + res = sess.run([h, s0, s1], { + x.name: np.array([[1., 1.]]), + c0.name: 0.1 * np.asarray([[0, 1]]), + h0.name: 0.1 * np.asarray([[2, 3]]), + c1.name: 0.1 * np.asarray([[4, 5]]), + h1.name: 0.1 * np.asarray([[6, 7]]), + }) + + expected_h = np.array([[-0.38079708, 0.38079708]]) + expected_h0 = np.array([[-0.38079708, 0.38079708]]) + expected_c0 = np.array([[-1.0, 1.0]]) + expected_h1 = np.array([[-0.38079708, 0.38079708]]) + expected_c1 = np.array([[-1.0, 1.0]]) + + self.assertEqual(len(res), 3) + self.assertAllClose(res[0], expected_h, 1e-5) + self.assertAllClose(res[1].c, expected_c0, 1e-5) + self.assertAllClose(res[1].h, expected_h0, 1e-5) + self.assertAllClose(res[2].c, expected_c1, 1e-5) + self.assertAllClose(res[2].h, expected_h1, 1e-5) + def testBasicLSTMCellWithDropout(self): def _is_close(x, y, digits=4): diff --git a/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py b/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py index 6b6bd503ceec8d0d7cd2bca5b7ec548fcf08445c..8109ebc718353300f94536c5d7ae3332da584a1d 100644 --- a/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/core_rnn_cell.py @@ -24,17 +24,169 @@ from __future__ import print_function import math +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope as vs from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import nest -RNNCell = rnn_cell_impl.RNNCell # pylint: disable=invalid-name -_linear = rnn_cell_impl._linear # pylint: disable=invalid-name, protected-access -_like_rnncell = rnn_cell_impl._like_rnncell # pylint: disable=invalid-name, protected-access + +# pylint: disable=protected-access,invalid-name +RNNCell = rnn_cell_impl.RNNCell +_like_rnncell = rnn_cell_impl._like_rnncell +_WEIGHTS_VARIABLE_NAME = rnn_cell_impl._WEIGHTS_VARIABLE_NAME +_BIAS_VARIABLE_NAME = rnn_cell_impl._BIAS_VARIABLE_NAME +# pylint: enable=protected-access,invalid-name + + +class _Linear(object): + """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. + + Args: + args: a 2D Tensor or a list of 2D, batch, n, Tensors. + output_size: int, second dimension of weight variable. + dtype: data type for variables. + build_bias: boolean, whether to build a bias variable. + bias_initializer: starting value to initialize the bias + (default is all zeros). + kernel_initializer: starting value to initialize the weight. + + Raises: + ValueError: if inputs_shape is wrong. + """ + + def __init__(self, + args, + output_size, + build_bias, + bias_initializer=None, + kernel_initializer=None): + self._build_bias = build_bias + + if args is None or (nest.is_sequence(args) and not args): + raise ValueError("`args` must be specified") + if not nest.is_sequence(args): + args = [args] + self._is_sequence = False + else: + self._is_sequence = True + + # Calculate the total size of arguments on dimension 1. + total_arg_size = 0 + shapes = [a.get_shape() for a in args] + for shape in shapes: + if shape.ndims != 2: + raise ValueError("linear is expecting 2D arguments: %s" % shapes) + if shape[1].value is None: + raise ValueError("linear expects shape[1] to be provided for shape %s, " + "but saw %s" % (shape, shape[1])) + else: + total_arg_size += shape[1].value + + dtype = [a.dtype for a in args][0] + + scope = vs.get_variable_scope() + with vs.variable_scope(scope) as outer_scope: + self._weights = vs.get_variable( + _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size], + dtype=dtype, + initializer=kernel_initializer) + if build_bias: + with vs.variable_scope(outer_scope) as inner_scope: + inner_scope.set_partitioner(None) + if bias_initializer is None: + bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) + self._biases = vs.get_variable( + _BIAS_VARIABLE_NAME, [output_size], + dtype=dtype, + initializer=bias_initializer) + + def __call__(self, args): + if not self._is_sequence: + args = [args] + + if len(args) == 1: + res = math_ops.matmul(args[0], self._weights) + else: + # Explicitly creating a one for a minor performance improvement. + one = constant_op.constant(1, dtype=dtypes.int32) + res = math_ops.matmul(array_ops.concat(args, one), self._weights) + if self._build_bias: + res = nn_ops.bias_add(res, self._biases) + return res + + +# TODO(xpan): Remove this function in a follow up. +def _linear(args, + output_size, + bias, + bias_initializer=None, + kernel_initializer=None): + """Linear map: sum_i(args[i] * W[i]), where W[i] is a variable. + + Args: + args: a 2D Tensor or a list of 2D, batch, n, Tensors. + output_size: int, second dimension of W[i]. + bias: boolean, whether to add a bias term or not. + bias_initializer: starting value to initialize the bias + (default is all zeros). + kernel_initializer: starting value to initialize the weight. + + Returns: + A 2D Tensor with shape `[batch, output_size]` equal to + sum_i(args[i] * W[i]), where W[i]s are newly created matrices. + + Raises: + ValueError: if some of the arguments has unspecified or wrong shape. + """ + if args is None or (nest.is_sequence(args) and not args): + raise ValueError("`args` must be specified") + if not nest.is_sequence(args): + args = [args] + + # Calculate the total size of arguments on dimension 1. + total_arg_size = 0 + shapes = [a.get_shape() for a in args] + for shape in shapes: + if shape.ndims != 2: + raise ValueError("linear is expecting 2D arguments: %s" % shapes) + if shape[1].value is None: + raise ValueError("linear expects shape[1] to be provided for shape %s, " + "but saw %s" % (shape, shape[1])) + else: + total_arg_size += shape[1].value + + dtype = [a.dtype for a in args][0] + + # Now the computation. + scope = vs.get_variable_scope() + with vs.variable_scope(scope) as outer_scope: + weights = vs.get_variable( + _WEIGHTS_VARIABLE_NAME, [total_arg_size, output_size], + dtype=dtype, + initializer=kernel_initializer) + if len(args) == 1: + res = math_ops.matmul(args[0], weights) + else: + res = math_ops.matmul(array_ops.concat(args, 1), weights) + if not bias: + return res + with vs.variable_scope(outer_scope) as inner_scope: + inner_scope.set_partitioner(None) + if bias_initializer is None: + bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) + biases = vs.get_variable( + _BIAS_VARIABLE_NAME, [output_size], + dtype=dtype, + initializer=bias_initializer) + return nn_ops.bias_add(res, biases) class EmbeddingWrapper(RNNCell): @@ -154,6 +306,7 @@ class InputProjectionWrapper(RNNCell): self._cell = cell self._num_proj = num_proj self._activation = activation + self._linear = None @property def state_size(self): @@ -170,7 +323,9 @@ class InputProjectionWrapper(RNNCell): def call(self, inputs, state): """Run the input projection and then the cell.""" # Default scope: "InputProjectionWrapper" - projected = _linear(inputs, self._num_proj, True) + if self._linear is None: + self._linear = _Linear(inputs, self._num_proj, True) + projected = self._linear(inputs) if self._activation: projected = self._activation(projected) return self._cell(projected, state) @@ -208,6 +363,7 @@ class OutputProjectionWrapper(RNNCell): self._cell = cell self._output_size = output_size self._activation = activation + self._linear = None @property def state_size(self): @@ -224,7 +380,9 @@ class OutputProjectionWrapper(RNNCell): def call(self, inputs, state): """Run the cell and output projection on inputs, starting from state.""" output, res_state = self._cell(inputs, state) - projected = _linear(output, self._output_size, True) + if self._linear is None: + self._linear = _Linear(output, self._output_size, True) + projected = self._linear(output) if self._activation: projected = self._activation(projected) return projected, res_state diff --git a/tensorflow/contrib/rnn/python/ops/lstm_ops.py b/tensorflow/contrib/rnn/python/ops/lstm_ops.py index f591f7c84e50660ccddbe13e31a32f6bc273c460..df910a3423083972bdee42bec10733e37b8e5f96 100644 --- a/tensorflow/contrib/rnn/python/ops/lstm_ops.py +++ b/tensorflow/contrib/rnn/python/ops/lstm_ops.py @@ -92,7 +92,7 @@ def _lstm_block_cell(x, wco: A `Tensor`. Must have the same type as `x`. The weight matrix for output gate peephole connection. forget_bias: An optional `float`. Defaults to `1`. The forget gate bias. - cell_clip: An optional `float`. Defaults to `3`. + cell_clip: An optional `float`. Defaults to `-1` (no clipping). Value to clip the 'cs' value to. Disable by setting to negative value. use_peephole: An optional `bool`. Defaults to `False`. Whether to use peephole weights. @@ -116,8 +116,8 @@ def _lstm_block_cell(x, if cell_size is None: raise ValueError("cell_size from `cs_prev` should not be None.") wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size]) - wco = wci wcf = wci + wco = wci # pylint: disable=protected-access return gen_lstm_ops.lstm_block_cell( @@ -126,11 +126,11 @@ def _lstm_block_cell(x, h_prev=h_prev, w=w, wci=wci, - wco=wco, wcf=wcf, + wco=wco, b=b, forget_bias=forget_bias, - cell_clip=cell_clip, + cell_clip=cell_clip if cell_clip is not None else -1, use_peephole=use_peephole, name=name) # pylint: enable=protected-access @@ -162,7 +162,7 @@ def _block_lstm(seq_len_max, wcf: A `Tensor`. Must have the same type as `x`. wco: A `Tensor`. Must have the same type as `x`. forget_bias: An optional `float`. Defaults to `1`. - cell_clip: An optional `float`. Defaults to `3`. + cell_clip: An optional `float`. Defaults to `-1` (no clipping). use_peephole: An optional `bool`. Defaults to `False`. name: A name for the operation (optional). @@ -201,8 +201,8 @@ def _block_lstm(seq_len_max, h_prev = zero_state if wci is None: wci = array_ops.constant(0, dtype=dtypes.float32, shape=[cell_size]) - wco = wci wcf = wci + wco = wci # pylint: disable=protected-access i, cs, f, o, ci, co, h = gen_lstm_ops.block_lstm( @@ -212,11 +212,11 @@ def _block_lstm(seq_len_max, h_prev=h_prev, w=w, wci=wci, - wco=wco, wcf=wcf, + wco=wco, b=b, forget_bias=forget_bias, - cell_clip=cell_clip, + cell_clip=cell_clip if cell_clip is not None else -1, name=name, use_peephole=use_peephole) @@ -233,7 +233,7 @@ _lstm_block_cell_grad_outputs = ["cs_prev_grad", "dicfo"] @ops.RegisterGradient("LSTMBlockCell") def _LSTMBlockCellGrad(op, *grad): """Gradient for LSTMBlockCell.""" - (x, cs_prev, h_prev, w, wci, wco, wcf, b) = op.inputs + (x, cs_prev, h_prev, w, wci, wcf, wco, b) = op.inputs (i, cs, f, o, ci, co, _) = op.outputs (_, cs_grad, _, _, _, _, h_grad) = grad @@ -293,13 +293,13 @@ def _LSTMBlockCellGrad(op, *grad): @ops.RegisterGradient("BlockLSTM") def _BlockLSTMGrad(op, *grad): """Gradient for BlockLSTM.""" - seq_len_max, x, cs_prev, h_prev, w, wci, wco, wcf, b = op.inputs + seq_len_max, x, cs_prev, h_prev, w, wci, wcf, wco, b = op.inputs i, cs, f, o, ci, co, h = op.outputs cs_grad = grad[1] h_grad = grad[6] - (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wco_grad, wcf_grad, + (x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad, wco_grad, b_grad) = gen_lstm_ops.block_lstm_grad( seq_len_max, x, @@ -307,8 +307,8 @@ def _BlockLSTMGrad(op, *grad): h_prev, w, wci, - wco, wcf, + wco, b, i, cs, @@ -321,8 +321,10 @@ def _BlockLSTMGrad(op, *grad): h_grad, use_peephole=op.get_attr("use_peephole")) - return [None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wco_grad, - wcf_grad, b_grad] + return [ + None, x_grad, cs_prev_grad, h_prev_grad, w_grad, wci_grad, wcf_grad, + wco_grad, b_grad + ] class LSTMBlockCell(rnn_cell_impl.RNNCell): @@ -341,7 +343,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): def __init__(self, num_units, forget_bias=1.0, - clip_cell=True, + cell_clip=None, use_peephole=False, reuse=None): """Initialize the basic LSTM cell. @@ -349,8 +351,7 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): Args: num_units: int, The number of units in the LSTM cell. forget_bias: float, The bias added to forget gates (see above). - clip_cell: boolean, whether to apply cell clipping. See - `_lstm_block_cell()` for details. + cell_clip: An optional `float`. Defaults to `-1` (no clipping). use_peephole: Whether to use peephole connections or not. reuse: (optional) boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the @@ -363,13 +364,13 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): self._num_units = num_units self._forget_bias = forget_bias self._use_peephole = use_peephole - self._clip_cell = clip_cell + self._cell_clip = cell_clip if cell_clip is not None else -1 self._names = { "W": "kernel", "b": "bias", "wci": "w_i_diag", - "wco": "w_o_diag", "wcf": "w_f_diag", + "wco": "w_o_diag", "scope": "lstm_cell" } @@ -397,10 +398,10 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): initializer=init_ops.constant_initializer(0.0)) if self._use_peephole: wci = vs.get_variable(self._names["wci"], [self._num_units]) - wco = vs.get_variable(self._names["wco"], [self._num_units]) wcf = vs.get_variable(self._names["wcf"], [self._num_units]) + wco = vs.get_variable(self._names["wco"], [self._num_units]) else: - wci = wco = wcf = array_ops.zeros([self._num_units]) + wci = wcf = wco = array_ops.zeros([self._num_units]) (cs_prev, h_prev) = states_prev (_, cs, _, _, _, _, h) = _lstm_block_cell( x, @@ -409,10 +410,10 @@ class LSTMBlockCell(rnn_cell_impl.RNNCell): w, b, wci=wci, - wco=wco, wcf=wcf, + wco=wco, forget_bias=self._forget_bias, - cell_clip=None if self._clip_cell else -1, + cell_clip=self._cell_clip, use_peephole=self._use_peephole) new_state = rnn_cell_impl.LSTMStateTuple(cs, h) @@ -594,12 +595,12 @@ class LSTMBlockFusedCell(LSTMBlockWrapper): Args: num_units: int, The number of units in the LSTM cell. forget_bias: float, The bias added to forget gates (see above). - cell_clip: clip the cell to this value. Defaults to `3`. + cell_clip: clip the cell to this value. Default is no cell clipping. use_peephole: Whether to use peephole connections or not. """ self._num_units = num_units self._forget_bias = forget_bias - self._cell_clip = cell_clip + self._cell_clip = cell_clip if cell_clip is not None else -1 self._use_peephole = use_peephole @property @@ -645,10 +646,10 @@ class LSTMBlockFusedCell(LSTMBlockWrapper): dtype=dtype) if self._use_peephole: wci = vs.get_variable("w_i_diag", [self._num_units], dtype=dtype) - wco = vs.get_variable("w_o_diag", [self._num_units], dtype=dtype) wcf = vs.get_variable("w_f_diag", [self._num_units], dtype=dtype) + wco = vs.get_variable("w_o_diag", [self._num_units], dtype=dtype) else: - wci = wco = wcf = array_ops.zeros([self._num_units], dtype=dtype) + wci = wcf = wco = array_ops.zeros([self._num_units], dtype=dtype) if sequence_length is None: max_seq_len = math_ops.to_int64(time_len) @@ -662,8 +663,8 @@ class LSTMBlockFusedCell(LSTMBlockWrapper): h_prev=initial_output, w=w, wci=wci, - wco=wco, wcf=wcf, + wco=wco, b=b, forget_bias=self._forget_bias, cell_clip=self._cell_clip, diff --git a/tensorflow/contrib/rnn/python/ops/rnn_cell.py b/tensorflow/contrib/rnn/python/ops/rnn_cell.py index 7b28222257f27b7e95f4215f5331eb475110dbb2..5e85c125df8ca0d632fa9b0db86d942bb354631e 100644 --- a/tensorflow/contrib/rnn/python/ops/rnn_cell.py +++ b/tensorflow/contrib/rnn/python/ops/rnn_cell.py @@ -23,6 +23,7 @@ import math from tensorflow.contrib.compiler import jit from tensorflow.contrib.layers.python.layers import layers +from tensorflow.contrib.rnn.python.ops import core_rnn_cell from tensorflow.python.framework import dtypes from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops @@ -35,6 +36,7 @@ from tensorflow.python.ops import nn_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import rnn_cell_impl from tensorflow.python.ops import variable_scope as vs +from tensorflow.python.ops import partitioned_variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util import nest @@ -75,6 +77,18 @@ def _get_sharded_variable(name, shape, dtype, num_shards): return shards +def _norm(g, b, inp, scope): + shape = inp.get_shape()[-1:] + gamma_init = init_ops.constant_initializer(g) + beta_init = init_ops.constant_initializer(b) + with vs.variable_scope(scope): + # Initialize beta and gamma for use by layer_norm. + vs.get_variable("gamma", shape=shape, initializer=gamma_init) + vs.get_variable("beta", shape=shape, initializer=beta_init) + normalized = layers.layer_norm(inp, reuse=True, scope=scope) + return normalized + + class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): """Long short-term memory unit (LSTM) recurrent network cell. @@ -101,13 +115,24 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): The class uses optional peep-hole connections, and an optional projection layer. + + Layer normalization implementation is based on: + + https://arxiv.org/abs/1607.06450. + + "Layer Normalization" + Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton + + and is applied before the internal nonlinearities. + """ def __init__(self, num_units, use_peepholes=False, initializer=None, num_proj=None, proj_clip=None, num_unit_shards=1, num_proj_shards=1, forget_bias=1.0, state_is_tuple=True, - activation=math_ops.tanh, reuse=None): + activation=math_ops.tanh, reuse=None, + layer_norm=False, norm_gain=1.0, norm_shift=0.0): """Initialize the parameters for an LSTM cell. Args: @@ -134,6 +159,13 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. + layer_norm: If `True`, layer normalization will be applied. + norm_gain: float, The layer normalization gain initial value. If + `layer_norm` has been set to `False`, this argument will be ignored. + norm_shift: float, The layer normalization shift initial value. If + `layer_norm` has been set to `False`, this argument will be ignored. + + """ super(CoupledInputForgetGateLSTMCell, self).__init__(_reuse=reuse) if not state_is_tuple: @@ -151,6 +183,9 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): self._state_is_tuple = state_is_tuple self._activation = activation self._reuse = reuse + self._layer_norm = layer_norm + self._norm_gain = norm_gain + self._norm_shift = norm_shift if num_proj: self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj) @@ -219,9 +254,20 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): # j = new_input, f = forget_gate, o = output_gate cell_inputs = array_ops.concat([inputs, m_prev], 1) - lstm_matrix = nn_ops.bias_add(math_ops.matmul(cell_inputs, concat_w), b) + lstm_matrix = math_ops.matmul(cell_inputs, concat_w) + + # If layer nomalization is applied, do not add bias + if not self._layer_norm: + lstm_matrix = nn_ops.bias_add(lstm_matrix, b) + j, f, o = array_ops.split(value=lstm_matrix, num_or_size_splits=3, axis=1) + # Apply layer normalization + if self._layer_norm: + j = _norm(self._norm_gain, self._norm_shift, j, "transform") + f = _norm(self._norm_gain, self._norm_shift, f, "forget") + o = _norm(self._norm_gain, self._norm_shift, o, "output") + # Diagonal connections if self._use_peepholes: w_f_diag = vs.get_variable( @@ -235,6 +281,10 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell): f_act = sigmoid(f + self._forget_bias) c = (f_act * c_prev + (1 - f_act) * self._activation(j)) + # Apply layer normalization + if self._layer_norm: + c = _norm(self._norm_gain, self._norm_shift, c, "state") + if self._use_peepholes: m = sigmoid(o + w_o_diag * c) * self._activation(c) else: @@ -525,7 +575,7 @@ class GridLSTMCell(rnn_cell_impl.RNNCell): self._state_tuple_type = collections.namedtuple( "GridLSTMStateTuple", state_names.strip(",")) self._state_size = self._state_tuple_type( - *([num_units, num_units] * self._total_blocks)) + *([num_units, num_units] * self._total_blocks)) else: self._state_tuple_type = None self._state_size = num_units * self._total_blocks * 2 @@ -1017,7 +1067,7 @@ class BidirectionalGridLSTMCell(GridLSTMCell): # pylint: disable=protected-access -_linear = rnn_cell_impl._linear +_Linear = core_rnn_cell._Linear # pylint: disable=invalid-name # pylint: enable=protected-access @@ -1079,6 +1129,9 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell): self._attn_size = attn_size self._attn_length = attn_length self._reuse = reuse + self._linear1 = None + self._linear2 = None + self._linear3 = None @property def state_size(self): @@ -1110,7 +1163,9 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell): input_size = self._input_size if input_size is None: input_size = inputs.get_shape().as_list()[1] - inputs = _linear([inputs, attns], input_size, True) + if self._linear1 is None: + self._linear1 = _Linear([inputs, attns], input_size, True) + inputs = self._linear1([inputs, attns]) cell_output, new_state = self._cell(inputs, state) if self._state_is_tuple: new_state_cat = array_ops.concat(nest.flatten(new_state), 1) @@ -1118,7 +1173,9 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell): new_state_cat = new_state new_attns, new_attn_states = self._attention(new_state_cat, attn_states) with vs.variable_scope("attn_output_projection"): - output = _linear([cell_output, new_attns], self._attn_size, True) + if self._linear2 is None: + self._linear2 = _Linear([cell_output, new_attns], self._attn_size, True) + output = self._linear2([cell_output, new_attns]) new_attn_states = array_ops.concat( [new_attn_states, array_ops.expand_dims(output, 1)], 1) new_attn_states = array_ops.reshape( @@ -1141,7 +1198,9 @@ class AttentionCellWrapper(rnn_cell_impl.RNNCell): hidden = array_ops.reshape(attn_states, [-1, self._attn_length, 1, self._attn_size]) hidden_features = conv2d(hidden, k, [1, 1, 1, 1], "SAME") - y = _linear(query, self._attn_vec_size, True) + if self._linear3 is None: + self._linear3 = _Linear(query, self._attn_vec_size, True) + y = self._linear3(query) y = array_ops.reshape(y, [-1, 1, 1, self._attn_vec_size]) s = reduce_sum(v * tanh(hidden_features + y), [2, 3]) a = softmax(s) @@ -1291,8 +1350,8 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): self._keep_prob = dropout_keep_prob self._seed = dropout_prob_seed self._layer_norm = layer_norm - self._g = norm_gain - self._b = norm_shift + self._norm_gain = norm_gain + self._norm_shift = norm_shift self._reuse = reuse @property @@ -1303,24 +1362,25 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): def output_size(self): return self._num_units - def _norm(self, inp, scope): + def _norm(self, inp, scope, dtype=dtypes.float32): shape = inp.get_shape()[-1:] - gamma_init = init_ops.constant_initializer(self._g) - beta_init = init_ops.constant_initializer(self._b) + gamma_init = init_ops.constant_initializer(self._norm_gain) + beta_init = init_ops.constant_initializer(self._norm_shift) with vs.variable_scope(scope): # Initialize beta and gamma for use by layer_norm. - vs.get_variable("gamma", shape=shape, initializer=gamma_init) - vs.get_variable("beta", shape=shape, initializer=beta_init) + vs.get_variable("gamma", shape=shape, initializer=gamma_init, dtype=dtype) + vs.get_variable("beta", shape=shape, initializer=beta_init, dtype=dtype) normalized = layers.layer_norm(inp, reuse=True, scope=scope) return normalized def _linear(self, args): out_size = 4 * self._num_units proj_size = args.get_shape()[-1] - weights = vs.get_variable("kernel", [proj_size, out_size]) + dtype = args.dtype + weights = vs.get_variable("kernel", [proj_size, out_size], dtype=dtype) out = math_ops.matmul(args, weights) if not self._layer_norm: - bias = vs.get_variable("bias", [out_size]) + bias = vs.get_variable("bias", [out_size], dtype=dtype) out = nn_ops.bias_add(out, bias) return out @@ -1329,13 +1389,14 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): c, h = state args = array_ops.concat([inputs, h], 1) concat = self._linear(args) + dtype = args.dtype i, j, f, o = array_ops.split(value=concat, num_or_size_splits=4, axis=1) if self._layer_norm: - i = self._norm(i, "input") - j = self._norm(j, "transform") - f = self._norm(f, "forget") - o = self._norm(o, "output") + i = self._norm(i, "input", dtype=dtype) + j = self._norm(j, "transform", dtype=dtype) + f = self._norm(f, "forget", dtype=dtype) + o = self._norm(o, "output", dtype=dtype) g = self._activation(j) if (not isinstance(self._keep_prob, float)) or self._keep_prob < 1: @@ -1344,7 +1405,7 @@ class LayerNormBasicLSTMCell(rnn_cell_impl.RNNCell): new_c = (c * math_ops.sigmoid(f + self._forget_bias) + math_ops.sigmoid(i) * g) if self._layer_norm: - new_c = self._norm(new_c, "state") + new_c = self._norm(new_c, "state", dtype=dtype) new_h = self._activation(new_c) * math_ops.sigmoid(o) new_state = rnn_cell_impl.LSTMStateTuple(new_c, new_h) @@ -1537,6 +1598,7 @@ class UGRNNCell(rnn_cell_impl.RNNCell): self._forget_bias = forget_bias self._activation = activation self._reuse = reuse + self._linear = None @property def state_size(self): @@ -1573,7 +1635,9 @@ class UGRNNCell(rnn_cell_impl.RNNCell): with vs.variable_scope(vs.get_variable_scope(), initializer=self._initializer): cell_inputs = array_ops.concat([inputs, state], 1) - rnn_matrix = _linear(cell_inputs, 2 * self._num_units, True) + if self._linear is None: + self._linear = _Linear(cell_inputs, 2 * self._num_units, True) + rnn_matrix = self._linear(cell_inputs) [g_act, c_act] = array_ops.split( axis=1, num_or_size_splits=2, value=rnn_matrix) @@ -1638,6 +1702,8 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell): self._num_input_proj = num_in_proj self._y_activation = y_activation self._reuse = reuse + self._linear1 = None + self._linear2 = None @property def state_size(self): @@ -1680,7 +1746,9 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell): if input_size.value != self._num_units: if self._num_input_proj: with vs.variable_scope("in_projection"): - inputs = _linear(inputs, self._num_units, True) + if self._linear1 is None: + self._linear1 = _Linear(inputs, self._num_units, True) + inputs = self._linear1(inputs) else: raise ValueError("Must have input size == output size for " "Intersection RNN. To fix, num_in_proj should " @@ -1688,7 +1756,9 @@ class IntersectionRNNCell(rnn_cell_impl.RNNCell): n_dim = i_dim = self._num_units cell_inputs = array_ops.concat([inputs, state], 1) - rnn_matrix = _linear(cell_inputs, 2*n_dim + 2*i_dim, True) + if self._linear2 is None: + self._linear2 = _Linear(cell_inputs, 2*n_dim + 2*i_dim, True) + rnn_matrix = self._linear2(cell_inputs) gh_act = rnn_matrix[:, :n_dim] # b x n h_act = rnn_matrix[:, n_dim:2*n_dim] # b x n @@ -1825,6 +1895,9 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell): self._period_init_min = period_init_min self._period_init_max = period_init_max self._reuse = reuse + self._linear1 = None + self._linear2 = None + self._linear3 = None @property def state_size(self): @@ -1872,14 +1945,18 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell): in_mask_gates.append(c_prev) with vs.variable_scope("mask_gates"): + if self._linear1 is None: + self._linear1 = _Linear(in_mask_gates, 2 * self._num_units, True) + mask_gates = math_ops.sigmoid( - _linear(in_mask_gates, 2 * self._num_units, True)) + self._linear1(in_mask_gates)) [input_gate, forget_gate] = array_ops.split( axis=1, num_or_size_splits=2, value=mask_gates) with vs.variable_scope("new_input"): - new_input = math_ops.tanh( - _linear([x, h_prev], self._num_units, True)) + if self._linear2 is None: + self._linear2 = _Linear([x, h_prev], self._num_units, True) + new_input = math_ops.tanh(self._linear2([x, h_prev])) new_c = (c_prev * forget_gate + input_gate * new_input) @@ -1888,8 +1965,9 @@ class PhasedLSTMCell(rnn_cell_impl.RNNCell): in_out_gate.append(new_c) with vs.variable_scope("output_gate"): - output_gate = math_ops.sigmoid( - _linear(in_out_gate, self._num_units, True)) + if self._linear3 is None: + self._linear3 = _Linear(in_out_gate, self._num_units, True) + output_gate = math_ops.sigmoid(self._linear3(in_out_gate)) new_h = math_ops.tanh(new_c) * output_gate @@ -2056,9 +2134,11 @@ def _conv(args, shape_length = len(shapes[0]) for shape in shapes: if len(shape) not in [3,4,5]: - raise ValueError("Conv Linear expects 3D, 4D or 5D arguments: %s" % str(shapes)) + raise ValueError("Conv Linear expects 3D, 4D " + "or 5D arguments: %s" % str(shapes)) if len(shape) != len(shapes[0]): - raise ValueError("Conv Linear expects all args to be of same Dimensiton: %s" % str(shapes)) + raise ValueError("Conv Linear expects all args " + "to be of same Dimension: %s" % str(shapes)) else: total_arg_size_depth += shape[-1] dtype = [a.dtype for a in args][0] @@ -2076,7 +2156,7 @@ def _conv(args, # Now the computation. kernel = vs.get_variable( - "kernel", + "kernel", filter_size + [total_arg_size_depth, num_features], dtype=dtype) if len(args) == 1: @@ -2159,6 +2239,8 @@ class GLSTMCell(rnn_cell_impl.RNNCell): else: self._state_size = rnn_cell_impl.LSTMStateTuple(num_units, num_units) self._output_size = num_units + self._linear1 = None + self._linear2 = None @property def state_size(self): @@ -2227,7 +2309,9 @@ class GLSTMCell(rnn_cell_impl.RNNCell): self._group_shape[0]), self._get_input_for_group(m_prev, group_id, self._group_shape[0])], axis=1) - R_k = _linear(x_g_id, 4 * self._group_shape[1], bias=False) + if self._linear1 is None: + self._linear1 = _Linear(x_g_id, 4 * self._group_shape[1], False) + R_k = self._linear1(x_g_id) # pylint: disable=invalid-name i_k, j_k, f_k, o_k = array_ops.split(R_k, 4, 1) i_parts.append(i_k) @@ -2267,7 +2351,270 @@ class GLSTMCell(rnn_cell_impl.RNNCell): if self._num_proj is not None: with vs.variable_scope("projection"): - m = _linear(m, self._num_proj, bias=False) + if self._linear2 is None: + self._linear2 = _Linear(m, self._num_proj, False) + m = self._linear2(m) new_state = rnn_cell_impl.LSTMStateTuple(c, m) return m, new_state + + +class LayerNormLSTMCell(rnn_cell_impl.RNNCell): + """Long short-term memory unit (LSTM) recurrent network cell. + + The default non-peephole implementation is based on: + + http://www.bioinf.jku.at/publications/older/2604.pdf + + S. Hochreiter and J. Schmidhuber. + "Long Short-Term Memory". Neural Computation, 9(8):1735-1780, 1997. + + The peephole implementation is based on: + + https://research.google.com/pubs/archive/43905.pdf + + Hasim Sak, Andrew Senior, and Francoise Beaufays. + "Long short-term memory recurrent neural network architectures for + large scale acoustic modeling." INTERSPEECH, 2014. + + The class uses optional peep-hole connections, optional cell clipping, and + an optional projection layer. + + Layer normalization implementation is based on: + + https://arxiv.org/abs/1607.06450. + + "Layer Normalization" + Jimmy Lei Ba, Jamie Ryan Kiros, Geoffrey E. Hinton + + and is applied before the internal nonlinearities. + + """ + + def __init__(self, num_units, + use_peepholes=False, cell_clip=None, + initializer=None, num_proj=None, proj_clip=None, + forget_bias=1.0, + activation=None, layer_norm=False, + norm_gain=1.0, norm_shift=0.0, reuse=None): + """Initialize the parameters for an LSTM cell. + + Args: + num_units: int, The number of units in the LSTM cell + use_peepholes: bool, set True to enable diagonal/peephole connections. + cell_clip: (optional) A float value, if provided the cell state is clipped + by this value prior to the cell output activation. + initializer: (optional) The initializer to use for the weight and + projection matrices. + num_proj: (optional) int, The output dimensionality for the projection + matrices. If None, no projection is performed. + proj_clip: (optional) A float value. If `num_proj > 0` and `proj_clip` is + provided, then the projected values are clipped elementwise to within + `[-proj_clip, proj_clip]`. + forget_bias: Biases of the forget gate are initialized by default to 1 + in order to reduce the scale of forgetting at the beginning of + the training. Must set it manually to `0.0` when restoring from + CudnnLSTM trained checkpoints. + activation: Activation function of the inner states. Default: `tanh`. + layer_norm: If `True`, layer normalization will be applied. + norm_gain: float, The layer normalization gain initial value. If + `layer_norm` has been set to `False`, this argument will be ignored. + norm_shift: float, The layer normalization shift initial value. If + `layer_norm` has been set to `False`, this argument will be ignored. + reuse: (optional) Python boolean describing whether to reuse variables + in an existing scope. If not `True`, and the existing scope already has + the given variables, an error is raised. + + When restoring from CudnnLSTM-trained checkpoints, must use + CudnnCompatibleLSTMCell instead. + """ + super(LayerNormLSTMCell, self).__init__(_reuse=reuse) + + self._num_units = num_units + self._use_peepholes = use_peepholes + self._cell_clip = cell_clip + self._initializer = initializer + self._num_proj = num_proj + self._proj_clip = proj_clip + self._forget_bias = forget_bias + self._activation = activation or math_ops.tanh + self._layer_norm = layer_norm + self._norm_gain = norm_gain + self._norm_shift = norm_shift + + if num_proj: + self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_proj)) + self._output_size = num_proj + else: + self._state_size = (rnn_cell_impl.LSTMStateTuple(num_units, num_units)) + self._output_size = num_units + + @property + def state_size(self): + return self._state_size + + @property + def output_size(self): + return self._output_size + + + def _linear(self, + args, + output_size, + bias, + bias_initializer=None, + kernel_initializer=None, + layer_norm=False): + """Linear map: sum_i(args[i] * W[i]), where W[i] is a Variable. + + Args: + args: a 2D Tensor or a list of 2D, batch x n, Tensors. + output_size: int, second dimension of W[i]. + bias: boolean, whether to add a bias term or not. + bias_initializer: starting value to initialize the bias + (default is all zeros). + kernel_initializer: starting value to initialize the weight. + layer_norm: boolean, whether to apply layer normalization. + + + Returns: + A 2D Tensor with shape [batch x output_size] taking value + sum_i(args[i] * W[i]), where each W[i] is a newly created Variable. + + Raises: + ValueError: if some of the arguments has unspecified or wrong shape. + """ + if args is None or (nest.is_sequence(args) and not args): + raise ValueError("`args` must be specified") + if not nest.is_sequence(args): + args = [args] + + # Calculate the total size of arguments on dimension 1. + total_arg_size = 0 + shapes = [a.get_shape() for a in args] + for shape in shapes: + if shape.ndims != 2: + raise ValueError("linear is expecting 2D arguments: %s" % shapes) + if shape[1].value is None: + raise ValueError("linear expects shape[1] to be provided for shape %s, " + "but saw %s" % (shape, shape[1])) + else: + total_arg_size += shape[1].value + + dtype = [a.dtype for a in args][0] + + # Now the computation. + scope = vs.get_variable_scope() + with vs.variable_scope(scope) as outer_scope: + weights = vs.get_variable( + "kernel", [total_arg_size, output_size], + dtype=dtype, + initializer=kernel_initializer) + if len(args) == 1: + res = math_ops.matmul(args[0], weights) + else: + res = math_ops.matmul(array_ops.concat(args, 1), weights) + if not bias: + return res + with vs.variable_scope(outer_scope) as inner_scope: + inner_scope.set_partitioner(None) + if bias_initializer is None: + bias_initializer = init_ops.constant_initializer(0.0, dtype=dtype) + biases = vs.get_variable( + "bias", [output_size], + dtype=dtype, + initializer=bias_initializer) + + if not layer_norm: + res = nn_ops.bias_add(res, biases) + + return res + + def call(self, inputs, state): + """Run one step of LSTM. + + Args: + inputs: input Tensor, 2D, batch x num_units. + state: this must be a tuple of state Tensors, + both `2-D`, with column sizes `c_state` and + `m_state`. + + Returns: + A tuple containing: + + - A `2-D, [batch x output_dim]`, Tensor representing the output of the + LSTM after reading `inputs` when previous state was `state`. + Here output_dim is: + num_proj if num_proj was set, + num_units otherwise. + - Tensor(s) representing the new state of LSTM after reading `inputs` when + the previous state was `state`. Same type and shape(s) as `state`. + + Raises: + ValueError: If input size cannot be inferred from inputs via + static shape inference. + """ + num_proj = self._num_units if self._num_proj is None else self._num_proj + sigmoid = math_ops.sigmoid + + (c_prev, m_prev) = state + + dtype = inputs.dtype + input_size = inputs.get_shape().with_rank(2)[1] + if input_size.value is None: + raise ValueError("Could not infer input size from inputs.get_shape()[-1]") + scope = vs.get_variable_scope() + with vs.variable_scope(scope, initializer=self._initializer) as unit_scope: + + # i = input_gate, j = new_input, f = forget_gate, o = output_gate + lstm_matrix = self._linear([inputs, m_prev], 4 * self._num_units, bias=True, + bias_initializer=None, layer_norm=self._layer_norm) + i, j, f, o = array_ops.split( + value=lstm_matrix, num_or_size_splits=4, axis=1) + + if self._layer_norm: + i = _norm(self._norm_gain, self._norm_shift, i, "input") + j = _norm(self._norm_gain, self._norm_shift, j, "transform") + f = _norm(self._norm_gain, self._norm_shift, f, "forget") + o = _norm(self._norm_gain, self._norm_shift, o, "output") + + # Diagonal connections + if self._use_peepholes: + with vs.variable_scope(unit_scope) as projection_scope: + w_f_diag = vs.get_variable( + "w_f_diag", shape=[self._num_units], dtype=dtype) + w_i_diag = vs.get_variable( + "w_i_diag", shape=[self._num_units], dtype=dtype) + w_o_diag = vs.get_variable( + "w_o_diag", shape=[self._num_units], dtype=dtype) + + if self._use_peepholes: + c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev + + sigmoid(i + w_i_diag * c_prev) * self._activation(j)) + else: + c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) * + self._activation(j)) + + if self._layer_norm: + c = _norm(self._norm_gain, self._norm_shift, c, "state") + + if self._cell_clip is not None: + # pylint: disable=invalid-unary-operand-type + c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip) + # pylint: enable=invalid-unary-operand-type + if self._use_peepholes: + m = sigmoid(o + w_o_diag * c) * self._activation(c) + else: + m = sigmoid(o) * self._activation(c) + + if self._num_proj is not None: + with vs.variable_scope("projection") as proj_scope: + m = self._linear(m, self._num_proj, bias=False) + + if self._proj_clip is not None: + # pylint: disable=invalid-unary-operand-type + m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip) + # pylint: enable=invalid-unary-operand-type + + new_state = (rnn_cell_impl.LSTMStateTuple(c, m)) + return m, new_state diff --git a/tensorflow/contrib/saved_model/BUILD b/tensorflow/contrib/saved_model/BUILD index a82ee6ac41ed3f81bd96c61dafb2144c41b07065..20be819e07d0e47a0b24b5cc2548727322093e50 100644 --- a/tensorflow/contrib/saved_model/BUILD +++ b/tensorflow/contrib/saved_model/BUILD @@ -37,9 +37,14 @@ py_library( visibility = ["//visibility:public"], deps = [ "//tensorflow/core:protos_all_py", + "//tensorflow/python:framework_ops", "//tensorflow/python:lib", "//tensorflow/python:util", + "//tensorflow/python/saved_model:builder", "//tensorflow/python/saved_model:constants", + "//tensorflow/python/saved_model:signature_constants", + "//tensorflow/python/saved_model:signature_def_utils", + "//tensorflow/python/saved_model:tag_constants", ], ) @@ -85,10 +90,11 @@ py_test( deps = [ ":saved_model_py", "//tensorflow/python:client_testlib", - "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python/saved_model", + "//tensorflow/python:variables", + "//tensorflow/python/saved_model:loader", "//tensorflow/python/saved_model:signature_constants", + "//tensorflow/python/saved_model:tag_constants", ], ) diff --git a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py b/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py index 282dd7dc3b0bad67e41cc2a02ac6fa2d2fd60397..d2e14f73e47c7c22f3a6e8ff12bfe05463755726 100644 --- a/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py +++ b/tensorflow/contrib/saved_model/python/saved_model/signature_def_utils_test.py @@ -94,7 +94,7 @@ class SignatureDefUtilsTest(test.TestCase): def testGetSignatureDefByKeyRegression(self): input1 = constant_op.constant("a", name="input-1") - output1 = constant_op.constant("b", name="output-1") + output1 = constant_op.constant(7.2, name="output-1") meta_graph_def = meta_graph_pb2.MetaGraphDef() self._add_to_signature_def_map(meta_graph_def, { @@ -123,7 +123,7 @@ class SignatureDefUtilsTest(test.TestCase): def testGetSignatureDefByKeyClassification(self): input1 = constant_op.constant("a", name="input-1") output1 = constant_op.constant("b", name="output-1") - output2 = constant_op.constant("c", name="output-2") + output2 = constant_op.constant(3.0, name="output-2") meta_graph_def = meta_graph_pb2.MetaGraphDef() self._add_to_signature_def_map(meta_graph_def, { diff --git a/tensorflow/contrib/seq2seq/BUILD b/tensorflow/contrib/seq2seq/BUILD index f1e39a137322711efacda02abd3c13f528981bc1..ab80c68b1a8e4ff151494e393b68c460846fa8fe 100644 --- a/tensorflow/contrib/seq2seq/BUILD +++ b/tensorflow/contrib/seq2seq/BUILD @@ -33,18 +33,31 @@ tf_custom_op_py_library( "//tensorflow/contrib/distributions:distributions_py", "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/rnn:rnn_py", + "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", + "//tensorflow/python:check_ops", + "//tensorflow/python:clip_ops", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:embedding_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:functional_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:layers", + "//tensorflow/python:layers_base", "//tensorflow/python:math_ops", + "//tensorflow/python:nn_ops", "//tensorflow/python:platform", + "//tensorflow/python:random_ops", "//tensorflow/python:rnn", "//tensorflow/python:rnn_cell", "//tensorflow/python:script_ops", "//tensorflow/python:tensor_array_ops", + "//tensorflow/python:tensor_util", "//tensorflow/python:util", "//tensorflow/python:variable_scope", + "//tensorflow/python/ops/distributions", "//third_party/py/numpy", + "@six_archive//:six", ], ) diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc index aab0f3f4947388741765b268094b4136d356a457..64973ccccdc962757a727d7183bd70e94edcfd1b 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.cc @@ -49,40 +49,46 @@ class GatherTreeOp : public OpKernel { const Device& device = ctx->eigen_device(); const Tensor& step_ids = ctx->input(0); const Tensor& parent_ids = ctx->input(1); - const Tensor& sequence_length = ctx->input(2); + const Tensor& max_sequence_lengths = ctx->input(2); + const Tensor& end_token = ctx->input(3); const TensorShape& step_ids_shape = step_ids.shape(); OP_REQUIRES( ctx, step_ids_shape.dims() == 3, errors::InvalidArgument("step_ids must be a 3-tensor, saw shape: ", step_ids_shape.DebugString())); - OP_REQUIRES( - ctx, TensorShapeUtils::IsMatrix(sequence_length.shape()), - errors::InvalidArgument("sequence_length must be a matrix, saw shape: ", - sequence_length.shape().DebugString())); - OP_REQUIRES(ctx, sequence_length.dim_size(0) == step_ids_shape.dim_size(1), - errors::InvalidArgument( - "Inconsistent batch sizes: sequence_length.shape[0] (", - sequence_length.dim_size(0), ") != ", "step_ids.shape[1] (", - step_ids_shape.dim_size(1), ")")); - OP_REQUIRES(ctx, sequence_length.dim_size(1) == step_ids_shape.dim_size(2), + OP_REQUIRES(ctx, TensorShapeUtils::IsVector(max_sequence_lengths.shape()), errors::InvalidArgument( - "Inconsistent batch sizes: sequence_length.shape[1] (", - sequence_length.dim_size(1), ") != ", "step_ids.shape[2] (", - step_ids_shape.dim_size(2), ")")); + "max_sequence_lengths must be a vector, saw shape: ", + max_sequence_lengths.shape().DebugString())); + OP_REQUIRES( + ctx, TensorShapeUtils::IsScalar(end_token.shape()), + errors::InvalidArgument("end_token must be a scalar, saw shape: ", + end_token.shape().DebugString())); OP_REQUIRES( ctx, step_ids_shape == parent_ids.shape(), errors::InvalidArgument( "step_ids.shape must match parent_ids.shape. but shapes are: ", step_ids_shape.DebugString(), " and ", parent_ids.shape().DebugString())); + OP_REQUIRES( + ctx, + step_ids_shape.dim_size(1) == max_sequence_lengths.shape().dim_size(0), + errors::InvalidArgument("batch size dimensions step_ids.shape[1] and " + "max_seqeuence_lengths.shape[0] must match. " + "but shapes are: ", + step_ids_shape.DebugString(), " and ", + max_sequence_lengths.shape().DebugString())); Tensor* beams; OP_REQUIRES_OK(ctx, ctx->allocate_output(0, step_ids_shape, &beams)); typename TTypes::ConstTensor step_ids_t = step_ids.tensor(); typename TTypes::ConstTensor parent_ids_t = parent_ids.tensor(); - typename TTypes::ConstMatrix seq_len_t = sequence_length.matrix(); + typename TTypes::ConstVec max_seq_lens_t = + max_sequence_lengths.vec(); + typename TTypes::ConstScalar end_token_t = end_token.scalar(); typename TTypes::Tensor beams_t = beams->tensor(); + const T end_token_value = end_token_t(); functor::GatherTree()(ctx, device, step_ids_t, parent_ids_t, - seq_len_t, beams_t); + max_seq_lens_t, end_token_value, beams_t); } }; @@ -99,27 +105,29 @@ namespace functor { template <> struct GatherTree { void operator()(OpKernelContext* ctx, const CPUDevice& d, - typename TTypes::ConstTensor step_ids, - typename TTypes::ConstTensor parent_ids, - typename TTypes::ConstMatrix sequence_length, - typename TTypes::Tensor beams) { - const int64 max_time = parent_ids.dimension(0); - const int64 batch_size = parent_ids.dimension(1); - const int64 beam_width = parent_ids.dimension(2); - beams.setConstant(-1); - - auto DoWork = [&, ctx](int start_batch_beam, int limit_batch_beam) { + TTypes::ConstTensor step_ids, + TTypes::ConstTensor parent_ids, + TTypes::ConstVec max_sequence_lengths, + const int32 end_token, TTypes::Tensor beams) { + const int32 max_time = parent_ids.dimension(0); + const int32 batch_size = parent_ids.dimension(1); + const int32 beam_width = parent_ids.dimension(2); + beams.setConstant(end_token); + + auto DoWork = [&, ctx, end_token](int start_batch_beam, + int limit_batch_beam) { for (int32 i = start_batch_beam; i < limit_batch_beam; ++i) { const int32 batch = i / beam_width; const int32 beam = i % beam_width; - int32 seq_len_b = sequence_length(batch, beam); - if (seq_len_b <= 0) { + const int32 max_seq_len_b = + Eigen::numext::mini(max_time, max_sequence_lengths(batch)); + if (max_seq_len_b <= 0) { continue; } - beams(seq_len_b - 1, batch, beam) = - step_ids(seq_len_b - 1, batch, beam); - int32 parent = parent_ids(seq_len_b - 1, batch, beam); - for (int32 level = seq_len_b - 2; level >= 0; --level) { + beams(max_seq_len_b - 1, batch, beam) = + step_ids(max_seq_len_b - 1, batch, beam); + int32 parent = parent_ids(max_seq_len_b - 1, batch, beam); + for (int32 level = max_seq_len_b - 2; level >= 0; --level) { if (parent < 0 || parent > beam_width) { ctx->SetStatus( errors::InvalidArgument("Saw invalid parent id ", parent, @@ -130,6 +138,17 @@ struct GatherTree { beams(level, batch, beam) = step_ids(level, batch, parent); parent = parent_ids(level, batch, parent); } + // Not necessary when using a BeamSearchDecoder, but necessary + // when a user feeds in possibly broken trajectory (i.e., non-eos + // entries in a beam following eos entries). + bool finished = false; + for (int32 time = 0; time < max_seq_len_b; ++time) { + if (finished) { + beams(time, batch, beam) = end_token; + } else if (beams(time, batch, beam) == end_token) { + finished = true; + } + } } }; // Guesstimate of cost; ~5 lookup/store/compare per inner beam @@ -137,7 +156,7 @@ struct GatherTree { const int64 batch_beam_cost = Eigen::TensorOpCost::DivCost() + 6 * Eigen::TensorOpCost::AddCost() + - max_time * (5 * Eigen::TensorOpCost::AddCost()); + 2 * max_time * (5 * Eigen::TensorOpCost::AddCost()); auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads()); Shard(worker_threads.num_threads, worker_threads.workers, batch_size * beam_width, batch_beam_cost, DoWork); @@ -148,24 +167,26 @@ struct GatherTree { #if GOOGLE_CUDA namespace functor { -#define DECLARE_GPU_SPEC(T) \ - template <> \ - void GatherTree::operator()( \ - OpKernelContext* ctx, const GPUDevice& d, \ - typename TTypes::ConstTensor step_ids, \ - typename TTypes::ConstTensor parent_ids, \ - typename TTypes::ConstMatrix sequence_length, \ - typename TTypes::Tensor beams); \ +#define DECLARE_GPU_SPEC(T) \ + template <> \ + void GatherTree::operator()( \ + OpKernelContext* ctx, const GPUDevice& d, \ + typename TTypes::ConstTensor step_ids, \ + typename TTypes::ConstTensor parent_ids, \ + TTypes::ConstVec max_sequence_lengths, const T end_token, \ + typename TTypes::Tensor beams); \ extern template struct GatherTree; DECLARE_GPU_SPEC(int32); #undef DECLARE_GPU_SPEC } // end namespace functor -#define REGISTER_GPU_KERNEL(T) \ - REGISTER_KERNEL_BUILDER( \ - Name("GatherTree").Device(DEVICE_GPU).TypeConstraint("T"), \ - GatherTreeOp); +#define REGISTER_GPU_KERNEL(T) \ + REGISTER_KERNEL_BUILDER(Name("GatherTree") \ + .Device(DEVICE_GPU) \ + .TypeConstraint("T") \ + .HostMemory("end_token"), \ + GatherTreeOp); REGISTER_GPU_KERNEL(int32); #undef REGISTER_GPU_KERNEL diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h index 124d07264e75ac4ce7739dd3291abdabbb40a58f..693b02dc437afdf14c38e4224c5469bb3e569540 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops.h @@ -31,8 +31,8 @@ struct GatherTree { void operator()(OpKernelContext* ctx, const Device& d, typename TTypes::ConstTensor step_ids, typename TTypes::ConstTensor parent_ids, - typename TTypes::ConstMatrix sequence_length, - typename TTypes::Tensor beams); + TTypes::ConstVec max_sequence_lengths, + const T end_token, typename TTypes::Tensor beams); }; } // namespace functor diff --git a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc index ee68b55d20214c207597750e083a63e94ebdc0a0..bc28d492fe1a25afe0d0783539aa9e759e7b703f 100644 --- a/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc +++ b/tensorflow/contrib/seq2seq/kernels/beam_search_ops_gpu.cu.cc @@ -29,30 +29,50 @@ template __global__ void GatherTreeOpKernel(const int32 batch_size, const int32 max_time, const int32 beam_width, const T* step_ids, const T* parent_ids, - const T* sequence_length, T* beams) { + const int32* max_sequence_lengths, + const T end_token, T* beams) { CUDA_1D_KERNEL_LOOP(i, batch_size * beam_width) { const int32 batch = i / beam_width; const int32 beam = i % beam_width; - const int32 seq_len_b = ldg(sequence_length + batch * beam_width + beam); - if (seq_len_b <= 0) continue; + const int32 max_seq_len_b = + Eigen::numext::mini(max_time, ldg(max_sequence_lengths + batch)); + if (max_seq_len_b <= 0) { + continue; + } #define GET_IX(time_ix, beam_ix) \ (batch_size * beam_width * (time_ix) + beam_width * batch + (beam_ix)) - const int32 initial_beam_ix = GET_IX(seq_len_b - 1, beam); + const int32 initial_beam_ix = GET_IX(max_seq_len_b - 1, beam); beams[initial_beam_ix] = ldg(step_ids + initial_beam_ix); int32 parent = ldg(parent_ids + initial_beam_ix); - for (int32 level = seq_len_b - 2; level >= 0; --level) { + bool found_bad = false; + for (int32 level = max_seq_len_b - 2; level >= 0; --level) { const int32 level_beam_ix = GET_IX(level, beam); const int32 level_parent_ix = GET_IX(level, parent); if (parent < 0 || parent > beam_width) { beams[level_beam_ix] = -1; parent = -1; + found_bad = true; } else { beams[level_beam_ix] = ldg(step_ids + level_parent_ix); parent = ldg(parent_ids + level_parent_ix); } } + // Not necessary when using a BeamSearchDecoder, but necessary + // when a user feeds in possibly broken trajectory (i.e., non-eos + // entries in a beam following eos entries). + if (!found_bad) { + bool finished = false; + for (int32 time = 0; time < max_seq_len_b; ++time) { + const int32 level_beam_ix = GET_IX(time, beam); + if (finished) { + beams[level_beam_ix] = end_token; + } else if (beams[level_beam_ix] == end_token) { + finished = true; + } + } + } #undef GET_IX } } @@ -62,20 +82,23 @@ struct GatherTree { void operator()(OpKernelContext* ctx, const GPUDevice& d, typename TTypes::ConstTensor step_ids, typename TTypes::ConstTensor parent_ids, - typename TTypes::ConstMatrix sequence_length, - typename TTypes::Tensor beams) { + TTypes::ConstVec max_sequence_length, + const T end_token, typename TTypes::Tensor beams) { const int32 max_time = parent_ids.dimension(0); const int32 batch_size = parent_ids.dimension(1); const int32 beam_width = parent_ids.dimension(2); - // First kernel launch to zero things out - beams.device(d) = beams.constant(T(-1)); + // First kernel launch to "zero" things out + beams.device(d) = beams.constant(end_token); CudaLaunchConfig config = GetCudaLaunchConfig(batch_size * beam_width, d); // clang-format off GatherTreeOpKernel <<>>( batch_size, max_time, beam_width, - step_ids.data(), parent_ids.data(), sequence_length.data(), + step_ids.data(), + parent_ids.data(), + max_sequence_length.data(), + end_token, beams.data()); // clang-format on } diff --git a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc index 6c445cd4606381ed56d91000bc5e42d874ca0c5c..71539b6f592f0c8e53c4bb3801d1e35f34814966 100644 --- a/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc +++ b/tensorflow/contrib/seq2seq/ops/beam_search_ops.cc @@ -25,27 +25,27 @@ using shape_inference::ShapeHandle; REGISTER_OP("GatherTree") .Input("step_ids: T") .Input("parent_ids: T") - .Input("sequence_length: T") + .Input("max_sequence_lengths: int32") + .Input("end_token: T") .Output("beams: T") .Attr("T: {int32}") .SetShapeFn([](InferenceContext* c) { - ShapeHandle step_ids, parent_ids, sequence_length; + ShapeHandle step_ids, parent_ids, max_sequence_lengths, end_token; // step_ids, parent_ids, and output are all shaped: // [max_time, batch_size, beam_width]. - // sequence_length is shaped [batch_size, beam_width]. + // max_sequence_length is shaped [batch_size] and end_token is a scalar. TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &step_ids)); TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &parent_ids)); - TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &sequence_length)); - - DimensionHandle batch_size = c->Dim(step_ids, 1); - DimensionHandle beam_width = c->Dim(step_ids, 2); - + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &max_sequence_lengths)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &end_token)); TF_RETURN_IF_ERROR(c->Merge(step_ids, parent_ids, &step_ids)); + DimensionHandle batch_size = c->Dim(step_ids, 1); TF_RETURN_IF_ERROR( - c->Merge(batch_size, c->Dim(sequence_length, 0), &batch_size)); - TF_RETURN_IF_ERROR( - c->Merge(beam_width, c->Dim(sequence_length, 1), &beam_width)); + c->Merge(batch_size, c->Dim(max_sequence_lengths, 0), &batch_size)); + ShapeHandle step_ids_prefix = c->Matrix(c->Dim(step_ids, 0), batch_size); + TF_RETURN_IF_ERROR(c->MergePrefix(step_ids, step_ids_prefix, &step_ids, + &step_ids_prefix)); c->set_output(0, step_ids); return tensorflow::Status::OK(); @@ -53,15 +53,19 @@ REGISTER_OP("GatherTree") .Doc(R"doc( Calculates the full beams from the per-step ids and parent beam ids. -This op implements the following mathematical equations: +On CPU, if an out of bound parent id is found, an error is returned. +On GPU, if an out of bound parent id is found, a -1 is stored in the +corresponding output value and the execution for that beam returns early. + +For a given beam, past the time step containing the first decoded `end_token` +all values are filled in with `end_token`. -```python -TODO(ebrevdo): fill in -``` +TODO(ebrevdo): fill in the remainder of this docstring. step_ids: `[max_time, batch_size, beam_width]`. parent_ids: `[max_time, batch_size, beam_width]`. -sequence_length: `[batch_size, beam_width]`. +max_sequence_lengths: `[batch_size]`. +end_token: `[]`. beams: `[max_time, batch_size, beam_width]`. )doc"); diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py index 2caeb9eb614382c815984391df87a70516f519b2..d2beac5f31460ec1c0d978a9f6fcd0e0f09cb9b4 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_decoder_test.py @@ -54,15 +54,18 @@ class TestGatherTree(test.TestCase): [[0, 0, 0], [1, 2, 0], [2, 1, 1]]], dtype=np.int32).transpose([1, 0, 2]) - # sequence_lengths is shaped (batch_size = 2, beam_width = 3) - sequence_lengths = [[3, 3, 3], [3, 3, 3]] + # sequence_lengths is shaped (batch_size = 3) + max_sequence_lengths = [3, 3] expected_result = np.array( [[[2, 2, 2], [6, 5, 6], [7, 8, 9]], [[2, 4, 4], [7, 6, 6], [8, 9, 10]]]).transpose([1, 0, 2]) res = beam_search_ops.gather_tree( - predicted_ids, parent_ids, sequence_lengths) + predicted_ids, + parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=11) with self.test_session() as sess: res_ = sess.run(res) @@ -80,8 +83,7 @@ class TestEosMasking(test.TestCase): ]) eos_token = 0 - previously_finished = constant_op.constant( - [[0, 1, 0], [0, 1, 1]], dtype=dtypes.float32) + previously_finished = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool) masked = beam_search_decoder._mask_probs(probs, eos_token, previously_finished) diff --git a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py index 50cccf392fdac75f551b180987aff0b31da0893e..277c5b6ef76bce8d59e47cf0026c6e2b1d5cf1e2 100644 --- a/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py +++ b/tensorflow/contrib/seq2seq/python/kernel_tests/beam_search_ops_test.py @@ -19,6 +19,8 @@ from __future__ import division from __future__ import print_function # pylint: enable=unused-import +import itertools + import numpy as np from tensorflow.contrib.seq2seq.python.ops import beam_search_ops @@ -34,31 +36,37 @@ class GatherTreeTest(test.TestCase): def testGatherTreeOne(self): # (max_time = 4, batch_size = 1, beams = 3) + end_token = 10 step_ids = _transpose_batch_time( [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]]) - sequence_length = [[3, 3, 3]] - expected_result = _transpose_batch_time( - [[[2, 2, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) + max_sequence_lengths = [3] + expected_result = _transpose_batch_time([[[2, 2, 2], [6, 5, 6], [7, 8, 9], + [10, 10, 10]]]) beams = beam_search_ops.gather_tree( - step_ids=step_ids, parent_ids=parent_ids, - sequence_length=sequence_length) + step_ids=step_ids, + parent_ids=parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=end_token) with self.test_session(use_gpu=True): self.assertAllEqual(expected_result, beams.eval()) def testBadParentValuesOnCPU(self): # (batch_size = 1, max_time = 4, beams = 3) # bad parent in beam 1 time 1 + end_token = 10 step_ids = _transpose_batch_time( [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]]) - sequence_length = [[3, 3, 3]] + max_sequence_lengths = [3] with ops.device("/cpu:0"): beams = beam_search_ops.gather_tree( - step_ids=step_ids, parent_ids=parent_ids, - sequence_length=sequence_length) + step_ids=step_ids, + parent_ids=parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=end_token) with self.test_session(): with self.assertRaisesOpError( r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"): @@ -71,82 +79,63 @@ class GatherTreeTest(test.TestCase): return # (max_time = 4, batch_size = 1, beams = 3) # bad parent in beam 1 time 1; appears as a negative index at time 0 + end_token = 10 step_ids = _transpose_batch_time( [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) parent_ids = _transpose_batch_time( [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]]) - sequence_length = [[3, 3, 3]] - expected_result = _transpose_batch_time( - [[[2, -1, 2], [6, 5, 6], [7, 8, 9], [-1, -1, -1]]]) + max_sequence_lengths = [3] + expected_result = _transpose_batch_time([[[2, -1, 2], [6, 5, 6], [7, 8, 9], + [10, 10, 10]]]) with ops.device("/device:GPU:0"): beams = beam_search_ops.gather_tree( - step_ids=step_ids, parent_ids=parent_ids, - sequence_length=sequence_length) + step_ids=step_ids, + parent_ids=parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=end_token) with self.test_session(use_gpu=True): self.assertAllEqual(expected_result, beams.eval()) def testGatherTreeBatch(self): - # sequence_length is [batch_size, beam_width] = [4, 5] - sequence_length = [[0] * 5, [1] * 5, [2] * 5, [3] * 5] + batch_size = 10 + beam_width = 15 + max_time = 8 + max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0] + end_token = 5 with self.test_session(use_gpu=True): - # (max_time = 4, batch_size = 4, beam_width = 5) - step_ids = _transpose_batch_time( - [[[3, 4, 0, 4, 0], - [4, 2, 0, 3, 1], - [1, 1, 3, 2, 2], - [3, 1, 2, 3, 4]], - [[3, 4, 0, 4, 0], - [4, 2, 0, 3, 1], - [1, 1, 3, 2, 2], - [3, 1, 2, 3, 4]], - [[1, 2, 3, 4, 2], - [2, 1, 1, 3, 2], - [3, 0, 1, 0, 0], - [3, 4, 0, 2, 4]], - [[0, 2, 2, 3, 1], - [3, 2, 2, 2, 3], - [3, 4, 3, 0, 3], - [1, 2, 2, 2, 4]]]) - parent_ids = _transpose_batch_time( - [[[4, 2, 4, 3, 4], - [3, 4, 0, 2, 0], - [3, 1, 3, 2, 2], - [0, 2, 1, 4, 2]], - [[4, 2, 4, 3, 4], - [3, 4, 0, 2, 0], - [3, 1, 3, 2, 2], - [0, 2, 1, 4, 2]], - [[3, 0, 0, 4, 0], - [1, 2, 4, 2, 2], - [4, 4, 0, 3, 0], - [2, 4, 4, 3, 0]], - [[3, 1, 4, 1, 3], - [3, 2, 4, 0, 4], - [1, 0, 1, 4, 2], - [0, 3, 2, 0, 1]]]) - expected_beams = _transpose_batch_time( - [[[-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1]], - [[3, 4, 0, 4, 0], - [-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1]], - [[2, 3, 2, 3, 3], - [2, 1, 1, 3, 2], - [-1, -1, -1, -1, -1], - [-1, -1, -1, -1, -1]], - [[2, 3, 2, 1, 1], - [2, 3, 2, 3, 2], - [3, 4, 3, 0, 3], - [-1, -1, -1, -1, -1]]]) + step_ids = np.random.randint( + 0, high=end_token + 1, size=(max_time, batch_size, beam_width)) + parent_ids = np.random.randint( + 0, high=beam_width - 1, size=(max_time, batch_size, beam_width)) beams = beam_search_ops.gather_tree( - step_ids=step_ids, parent_ids=parent_ids, - sequence_length=sequence_length) - self.assertAllEqual(expected_beams, beams.eval()) + step_ids=step_ids.astype(np.int32), + parent_ids=parent_ids.astype(np.int32), + max_sequence_lengths=max_sequence_lengths, + end_token=end_token) + + self.assertEqual((max_time, batch_size, beam_width), beams.shape) + beams_value = beams.eval() + for b in range(batch_size): + # Past max_sequence_lengths[b], we emit all end tokens. + b_value = beams_value[max_sequence_lengths[b]:, b, :] + self.assertAllClose(b_value, end_token * np.ones_like(b_value)) + for batch, beam in itertools.product( + range(batch_size), range(beam_width)): + v = np.squeeze(beams_value[:, batch, beam]) + if end_token in v: + found_bad = np.where(v == -1)[0] + self.assertEqual(0, len(found_bad)) + found = np.where(v == end_token)[0] + found = found[0] # First occurrence of end_token. + # If an end_token is found, everything before it should be a + # valid id and everything after it should be -1. + if found > 0: + self.assertAllEqual( + v[:found - 1] >= 0, np.ones_like(v[:found - 1], dtype=bool)) + self.assertAllClose(v[found + 1:], + end_token * np.ones_like(v[found + 1:])) if __name__ == "__main__": diff --git a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py index 259c8e08ad914a6e57992426f9d0174de0e58388..839df079ee743c67b3eb6180bbf419f07ecb5435 100644 --- a/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py +++ b/tensorflow/contrib/seq2seq/python/ops/attention_wrapper.py @@ -342,7 +342,7 @@ class LuongAttention(_BaseAttentionMechanism): num_units: The depth of the attention mechanism. memory: The memory to query; usually the output of an RNN encoder. This tensor should be shaped `[batch_size, max_time, ...]`. - memory_sequence_length (optional): Sequence lengths for the batch entries + memory_sequence_length: (optional) Sequence lengths for the batch entries in memory. If provided, the memory tensor rows are masked with zeros for values past the respective sequence lengths. scale: Python boolean. Whether to scale the energy term. @@ -350,7 +350,7 @@ class LuongAttention(_BaseAttentionMechanism): probabilities. The default is @{tf.nn.softmax}. Other options include @{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}. Its signature should be: `probabilities = probability_fn(score)`. - score_mask_value: (optional): The mask value for score before passing into + score_mask_value: (optional) The mask value for score before passing into `probability_fn`. The default is -inf. Only used if `memory_sequence_length` is not None. name: Name to use when creating ops. @@ -1009,6 +1009,37 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): name=None): """Construct the `AttentionWrapper`. + **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in + `AttentionWrapper`, then you must ensure that: + + - The encoder output has been tiled to `beam_width` via + @{tf.contrib.seq2seq.tile_batch} (NOT `tf.tile`). + - The `batch_size` argument passed to the `zero_state` method of this + wrapper is equal to `true_batch_size * beam_width`. + - The initial state created with `zero_state` above contains a + `cell_state` value containing properly tiled final state from the + encoder. + + An example: + + ``` + tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( + encoder_outputs, multiplier=beam_width) + tiled_encoder_final_state = tf.conrib.seq2seq.tile_batch( + encoder_final_state, multiplier=beam_width) + tiled_sequence_length = tf.contrib.seq2seq.tile_batch( + sequence_length, multiplier=beam_width) + attention_mechanism = MyFavoriteAttentionMechanism( + num_units=attention_depth, + memory=tiled_inputs, + memory_sequence_length=tiled_sequence_length) + attention_cell = AttentionWrapper(cell, attention_mechanism, ...) + decoder_initial_state = attention_cell.zero_state( + dtype, batch_size=true_batch_size * beam_width) + decoder_initial_state = decoder_initial_state.clone( + cell_state=tiled_encoder_final_state) + ``` + Args: cell: An instance of `RNNCell`. attention_mechanism: A list of `AttentionMechanism` instances or a single @@ -1157,6 +1188,11 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): @property def state_size(self): + """The `state_size` property of `AttentionWrapper`. + + Returns: + An `AttentionWrapperState` tuple containing shapes used by this object. + """ return AttentionWrapperState( cell_state=self._cell.state_size, time=tensor_shape.TensorShape([]), @@ -1167,6 +1203,25 @@ class AttentionWrapper(rnn_cell_impl.RNNCell): () for _ in self._attention_mechanisms)) # sometimes a TensorArray def zero_state(self, batch_size, dtype): + """Return an initial (zero) state tuple for this `AttentionWrapper`. + + **NOTE** Please see the initializer documentation for details of how + to call `zero_state` if using an `AttentionWrapper` with a + `BeamSearchDecoder`. + + Args: + batch_size: `0D` integer tensor: the batch size. + dtype: The internal state data type. + + Returns: + An `AttentionWrapperState` tuple containing zeroed out tensors and, + possibly, empty `TensorArray` objects. + + Raises: + ValueError: (or, possibly at runtime, InvalidArgument), if + `batch_size` does not match the output size of the encoder passed + to the wrapper object at initialization time. + """ with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]): if self._initial_cell_state is not None: cell_state = self._initial_cell_state diff --git a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py index 1855ea9999b9fb695a3c66a5c67000eaebf8eb27..5be0c92243da10af438be97fab982515266be1de 100644 --- a/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/beam_search_decoder.py @@ -20,9 +20,10 @@ from __future__ import print_function import collections +import numpy as np + from tensorflow.contrib.seq2seq.python.ops import beam_search_ops from tensorflow.contrib.seq2seq.python.ops import decoder -from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape @@ -130,7 +131,39 @@ def _check_maybe(t): class BeamSearchDecoder(decoder.Decoder): - """BeamSearch sampling decoder.""" + """BeamSearch sampling decoder. + + **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in + `AttentionWrapper`, then you must ensure that: + + - The encoder output has been tiled to `beam_width` via + @{tf.contrib.seq2seq.tile_batch} (NOT `tf.tile`). + - The `batch_size` argument passed to the `zero_state` method of this + wrapper is equal to `true_batch_size * beam_width`. + - The initial state created with `zero_state` above contains a + `cell_state` value containing properly tiled final state from the + encoder. + + An example: + + ``` + tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( + encoder_outputs, multiplier=beam_width) + tiled_encoder_final_state = tf.conrib.seq2seq.tile_batch( + encoder_final_state, multiplier=beam_width) + tiled_sequence_length = tf.contrib.seq2seq.tile_batch( + sequence_length, multiplier=beam_width) + attention_mechanism = MyFavoriteAttentionMechanism( + num_units=attention_depth, + memory=tiled_inputs, + memory_sequence_length=tiled_sequence_length) + attention_cell = AttentionWrapper(cell, attention_mechanism, ...) + decoder_initial_state = attention_cell.zero_state( + dtype, batch_size=true_batch_size * beam_width) + decoder_initial_state = decoder_initial_state.clone( + cell_state=tiled_encoder_final_state) + ``` + """ def __init__(self, cell, @@ -141,7 +174,7 @@ class BeamSearchDecoder(decoder.Decoder): beam_width, output_layer=None, length_penalty_weight=0.0): - """Initialize BeamSearchDecoder. + """Initialize the BeamSearchDecoder. Args: cell: An `RNNCell` instance. @@ -220,6 +253,20 @@ class BeamSearchDecoder(decoder.Decoder): output_shape_with_unknown_batch) return nest.map_structure(lambda s: s[1:], layer_output_shape) + @property + def tracks_own_finished(self): + """The BeamSearchDecoder shuffles its beams and their finished state. + + For this reason, it conflicts with the `dynamic_decode` function's + tracking of finished states. Setting this property to true avoids + early stopping of decoding due to mismanagement of the finished state + in `dynamic_decode`. + + Returns: + `True`. + """ + return True + @property def output_size(self): # Return the cell output and the id @@ -270,15 +317,23 @@ class BeamSearchDecoder(decoder.Decoder): output. sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`. The sequence lengths determined for each beam during decode. + **NOTE** These are ignored; the updated sequence lengths are stored in + `final_state.lengths`. Returns: - outputs: An instance of FinalBeamSearchDecoderOutput where the + outputs: An instance of `FinalBeamSearchDecoderOutput` where the predicted_ids are the result of calling _gather_tree. - final_state: The same input instance of BeamSearchDecoderState. + final_state: The same input instance of `BeamSearchDecoderState`. """ + del sequence_lengths + # Get max_sequence_length across all beams for each batch. + max_sequence_lengths = math_ops.to_int32( + math_ops.reduce_max(final_state.lengths, axis=1)) predicted_ids = beam_search_ops.gather_tree( - outputs.predicted_ids, outputs.parent_ids, - sequence_length=sequence_lengths) + outputs.predicted_ids, + outputs.parent_ids, + max_sequence_lengths=max_sequence_lengths, + end_token=self._end_token) outputs = FinalBeamSearchDecoderOutput( beam_search_decoder_output=outputs, predicted_ids=predicted_ids) return outputs, final_state @@ -358,17 +413,17 @@ class BeamSearchDecoder(decoder.Decoder): We do this so that we can use nest and not run into problems with shapes. Args: - t: Tensor of dimension [batch_size*beam_width, s] - s: Tensor, Python int, or TensorShape. + t: `Tensor`, either scalar or shaped `[batch_size * beam_width] + s`. + s: `Tensor`, Python int, or `TensorShape`. Returns: - Either a reshaped version of t with dimension - [batch_size, beam_width, s] if t's first dimension is of size - batch_size*beam_width or t if not. + If `t` is a matrix or higher order tensor, then the return value is + `t` reshaped to `[batch_size, beam_width] + s`. Otherwise `t` is + returned unchanged. Raises: - TypeError: If t is an instance of TensorArray. - ValueError: If the rank of t is not statically known. + TypeError: If `t` is an instance of `TensorArray`. + ValueError: If the rank of `t` is not statically known. """ _check_maybe(t) if t.shape.ndims >= 1: @@ -379,19 +434,19 @@ class BeamSearchDecoder(decoder.Decoder): def _maybe_merge_batch_beams(self, t, s): """Splits the tensor from a batch by beams into a batch of beams. - More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We - reshape this into [batch_size, beam_width, s] + More exactly, `t` is a tensor of dimension `[batch_size * beam_width] + s`, + then we reshape it to `[batch_size, beam_width] + s`. Args: - t: Tensor of dimension [batch_size*beam_width, s] - s: Tensor, Python int, or TensorShape. + t: `Tensor` of dimension `[batch_size * beam_width] + s`. + s: `Tensor`, Python int, or `TensorShape`. Returns: - A reshaped version of t with dimension [batch_size, beam_width, s]. + A reshaped version of t with shape `[batch_size, beam_width] + s`. Raises: - TypeError: If t is an instance of TensorArray. - ValueError: If the rank of t is not statically known. + TypeError: If `t` is an instance of `TensorArray`. + ValueError: If the rank of `t` is not statically known. """ _check_maybe(t) if t.shape.ndims >= 2: @@ -489,14 +544,12 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, # Calculate the continuation lengths by adding to all continuing beams. vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1] lengths_to_add = array_ops.one_hot( - indices=array_ops.tile( - array_ops.reshape(end_token, [1, 1]), [batch_size, beam_width]), + indices=array_ops.fill([batch_size, beam_width], end_token), depth=vocab_size, - on_value=constant_op.constant(0, dtype=dtypes.int64), - off_value=constant_op.constant(1, dtype=dtypes.int64), + on_value=np.int64(0), off_value=np.int64(1), dtype=dtypes.int64) - add_mask = (1 - math_ops.to_int64(previously_finished)) - lengths_to_add = array_ops.expand_dims(add_mask, 2) * lengths_to_add + add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished)) + lengths_to_add *= array_ops.expand_dims(add_mask, 2) new_prediction_lengths = ( lengths_to_add + array_ops.expand_dims(prediction_lengths, 2)) @@ -522,6 +575,7 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, ops.convert_to_tensor(beam_width, dtype=dtypes.int32, name="beam_width"), num_available_beam) next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size) + next_beam_scores.set_shape([static_batch_size, beam_width]) word_indices.set_shape([static_batch_size, beam_width]) @@ -531,9 +585,18 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, gather_from=total_probs, batch_size=batch_size, range_size=beam_width * vocab_size, - gather_shape=[-1]) - next_word_ids = math_ops.to_int32(word_indices % vocab_size) - next_beam_ids = math_ops.to_int32(word_indices / vocab_size) + gather_shape=[-1], + name="next_beam_probs") + # Note: just doing the following + # math_ops.to_int32(word_indices % vocab_size, + # name="next_beam_word_ids") + # would be a lot cleaner but for reasons unclear, that hides the results of + # the op which prevents capturing it with tfdbg debug ops. + raw_next_word_ids = math_ops.mod(word_indices, vocab_size, + name="next_beam_word_ids") + next_word_ids = math_ops.to_int32(raw_next_word_ids) + next_beam_ids = math_ops.to_int32(word_indices / vocab_size, + name="next_beam_parent_ids") # Append new ids to current predictions previously_finished = _tensor_gather_helper( @@ -543,15 +606,15 @@ def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, range_size=beam_width, gather_shape=[-1]) next_finished = math_ops.logical_or(previously_finished, - math_ops.equal(next_word_ids, end_token)) + math_ops.equal(next_word_ids, end_token), + name="next_beam_finished") # Calculate the length of the next predictions. - # 1. Finished beams remain unchanged - # 2. Beams that are now finished (EOS predicted) remain unchanged - # 3. Beams that are not yet finished have their length increased by 1 - lengths_to_add = math_ops.to_int64( - math_ops.not_equal(next_word_ids, end_token)) - lengths_to_add = (1 - math_ops.to_int64(next_finished)) * lengths_to_add + # 1. Finished beams remain unchanged. + # 2. Beams that are now finished (EOS predicted) have their length + # increased by 1. + # 3. Beams that are not yet finished have their length increased by 1. + lengths_to_add = math_ops.to_int64(math_ops.logical_not(previously_finished)) next_prediction_len = _tensor_gather_helper( gather_indices=next_beam_ids, gather_from=beam_state.lengths, @@ -609,13 +672,20 @@ def _get_scores(log_probs, sequence_lengths, length_penalty_weight): def _length_penalty(sequence_lengths, penalty_factor): """Calculates the length penalty. See https://arxiv.org/abs/1609.08144. + Returns the length penalty tensor: + ``` + [(5+sequence_lengths)/6]**penalty_factor + ``` + where all operations are performed element-wise. + Args: - sequence_lengths: The sequence length of all hypotheses, a tensor - of shape [beam_size, vocab_size]. + sequence_lengths: `Tensor`, the sequence lengths of each hypotheses. penalty_factor: A scalar that weights the length penalty. Returns: - The length penalty factor, a tensor fo shape [beam_size]. + If the penalty is `0`, returns the scalar `1.0`. Otherwise returns + the length penalty factor, a tensor with the same shape as + `sequence_lengths`. """ penalty_factor = ops.convert_to_tensor(penalty_factor, name="penalty_factor") penalty_factor.set_shape(()) # penalty should be a scalar. @@ -637,8 +707,7 @@ def _mask_probs(probs, eos_token, finished): eos_token: An int32 id corresponding to the EOS token to allocate probability to. finished: A boolean tensor of shape `[batch_size, beam_width]` that - specifies which - elements in the beam are finished already. + specifies which elements in the beam are finished already. Returns: A tensor of shape `[batch_size, beam_width, vocab_size]`, where unfinished @@ -646,10 +715,6 @@ def _mask_probs(probs, eos_token, finished): probability on the EOS token. """ vocab_size = array_ops.shape(probs)[2] - finished_mask = array_ops.expand_dims( - math_ops.to_float(1. - math_ops.to_float(finished)), 2) - # These examples are not finished and we leave them - non_finished_examples = finished_mask * probs # All finished examples are replaced with a vector that has all # probability on EOS finished_row = array_ops.one_hot( @@ -658,8 +723,13 @@ def _mask_probs(probs, eos_token, finished): dtype=probs.dtype, on_value=0., off_value=probs.dtype.min) - finished_examples = (1. - finished_mask) * finished_row - return finished_examples + non_finished_examples + finished_probs = array_ops.tile( + array_ops.reshape(finished_row, [1, 1, -1]), + array_ops.concat([array_ops.shape(finished), [1]], 0)) + finished_mask = array_ops.tile( + array_ops.expand_dims(finished, 2), [1, 1, vocab_size]) + + return array_ops.where(finished_mask, finished_probs, probs) def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size, @@ -699,7 +769,7 @@ def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size, def _tensor_gather_helper(gather_indices, gather_from, batch_size, - range_size, gather_shape): + range_size, gather_shape, name=None): """Helper for gathering the right indices from the tensor. This works by reshaping gather_from to gather_shape (e.g. [-1]) and then @@ -717,19 +787,22 @@ def _tensor_gather_helper(gather_indices, gather_from, batch_size, There, we want to preserve the attention_size elements, so gather_shape is [batch_size * beam_width, -1]. Then, upon reshape, we still have the attention_size as desired. + name: The tensor name for set of operations. By default this is + 'tensor_gather_helper'. The final output is named 'output'. Returns: output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)] """ - range_ = array_ops.expand_dims(math_ops.range(batch_size) * range_size, 1) - gather_indices = array_ops.reshape(gather_indices + range_, [-1]) - output = array_ops.gather( - array_ops.reshape(gather_from, gather_shape), gather_indices) - final_shape = array_ops.shape(gather_from)[:1 + len(gather_shape)] - static_batch_size = tensor_util.constant_value(batch_size) - final_static_shape = (tensor_shape.TensorShape([static_batch_size]) - .concatenate( - gather_from.shape[1:1 + len(gather_shape)])) - output = array_ops.reshape(output, final_shape) - output.set_shape(final_static_shape) - return output + with ops.name_scope(name, "tensor_gather_helper"): + range_ = array_ops.expand_dims(math_ops.range(batch_size) * range_size, 1) + gather_indices = array_ops.reshape(gather_indices + range_, [-1]) + output = array_ops.gather( + array_ops.reshape(gather_from, gather_shape), gather_indices) + final_shape = array_ops.shape(gather_from)[:1 + len(gather_shape)] + static_batch_size = tensor_util.constant_value(batch_size) + final_static_shape = (tensor_shape.TensorShape([static_batch_size]) + .concatenate( + gather_from.shape[1:1 + len(gather_shape)])) + output = array_ops.reshape(output, final_shape, name="output") + output.set_shape(final_static_shape) + return output diff --git a/tensorflow/contrib/seq2seq/python/ops/decoder.py b/tensorflow/contrib/seq2seq/python/ops/decoder.py index fbe53fc60ada85c40970870c6d0bdb93d17ea6d4..f14974b9d5ca8cbcfd9f91086ca0a90ceff48f43 100644 --- a/tensorflow/contrib/seq2seq/python/ops/decoder.py +++ b/tensorflow/contrib/seq2seq/python/ops/decoder.py @@ -100,16 +100,36 @@ class Decoder(object): Returns: `(outputs, next_state, next_inputs, finished)`: `outputs` is an object - containing the decoder output, `next_state` is a (structure of) state tensors - and TensorArrays, `next_inputs` is the tensor that should be used as input for - the next step, `finished` is a boolean tensor telling whether the sequence - is complete, for each sequence in the batch. + containing the decoder output, `next_state` is a (structure of) state + tensors and TensorArrays, `next_inputs` is the tensor that should be used + as input for the next step, `finished` is a boolean tensor telling whether + the sequence is complete, for each sequence in the batch. """ raise NotImplementedError def finalize(self, outputs, final_state, sequence_lengths): raise NotImplementedError + @property + def tracks_own_finished(self): + """Describes whether the Decoder keeps track of finished states. + + Most decoders will emit a true/false `finished` value independently + at each time step. In this case, the `dynamic_decode` function keeps track + of which batch entries are already finished, and performs a logical OR to + insert new batches to the finished set. + + Some decoders, however, shuffle batches / beams between time steps and + `dynamic_decode` will mix up the finished state across these entries because + it does not track the reshuffle across time steps. In this case, it is + up to the decoder to declare that it will keep track of its own finished + state by setting this property to `True`. + + Returns: + Python bool. + """ + return False + def _create_zero_outputs(size, dtype, batch_size): """Create a zero outputs Tensor structure.""" @@ -232,7 +252,10 @@ def dynamic_decode(decoder, """ (next_outputs, decoder_state, next_inputs, decoder_finished) = decoder.step(time, inputs, state) - next_finished = math_ops.logical_or(decoder_finished, finished) + if decoder.tracks_own_finished: + next_finished = decoder_finished + else: + next_finished = math_ops.logical_or(decoder_finished, finished) if maximum_iterations is not None: next_finished = math_ops.logical_or( next_finished, time + 1 >= maximum_iterations) diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD index 8a1c9ba0a2ca01396bec662214f9a5f0d732f34b..67011c8fef6c4f54db2626ffe7ae1299bddbb352 100644 --- a/tensorflow/contrib/session_bundle/BUILD +++ b/tensorflow/contrib/session_bundle/BUILD @@ -136,7 +136,6 @@ py_test( ":gc", ":manifest_proto_py", "//tensorflow/core:protos_all_py", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", @@ -411,8 +410,6 @@ tf_cc_test( ":test_util", "//tensorflow/cc/saved_model:signature_constants", "//tensorflow/cc/saved_model:tag_constants", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core:test", diff --git a/tensorflow/contrib/signal/BUILD b/tensorflow/contrib/signal/BUILD index 8c11cf0d6450b5ea0f1d1af21c24a66c629cce90..b67090dd509f321c8d28436fa135fb871aee976d 100644 --- a/tensorflow/contrib/signal/BUILD +++ b/tensorflow/contrib/signal/BUILD @@ -5,12 +5,14 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) load("//tensorflow:tensorflow.bzl", "cuda_py_tests") +load("//tensorflow:tensorflow.bzl", "py_test") # @unused py_library( name = "signal_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]), srcs_version = "PY2AND3", deps = [ + ":test_util", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", @@ -24,13 +26,39 @@ py_library( ], ) +py_library( + name = "test_util", + srcs = ["python/kernel_tests/test_util.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:tf_optimizer", + "//tensorflow/python:training", + ], +) + cuda_py_tests( name = "mel_ops_test", srcs = ["python/kernel_tests/mel_ops_test.py"], + additional_deps = [ + ":signal_py", + ":test_util", + "//third_party/py/numpy", + "//tensorflow/python:client_testlib", + ], +) + +cuda_py_tests( + name = "mfcc_ops_test", + srcs = ["python/kernel_tests/mfcc_ops_test.py"], additional_deps = [ ":signal_py", "//third_party/py/numpy", "//tensorflow/python:client_testlib", + "//tensorflow/python:framework", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:spectral_ops_test_util", ], ) @@ -56,6 +84,7 @@ cuda_py_tests( srcs = ["python/kernel_tests/shape_ops_test.py"], additional_deps = [ ":signal_py", + ":test_util", "//third_party/py/numpy", "//tensorflow/python:array_ops", "//tensorflow/python:math_ops", @@ -93,6 +122,7 @@ cuda_py_tests( srcs = ["python/kernel_tests/window_ops_test.py"], additional_deps = [ ":signal_py", + ":test_util", "//third_party/py/numpy", "//tensorflow/python:client_testlib", "//tensorflow/python:framework", diff --git a/tensorflow/contrib/signal/__init__.py b/tensorflow/contrib/signal/__init__.py index 25123b097e380a7590ea7377d6c979e449ec96b0..6a2080bcec15a7ef29c54cc6394982b2e3709181 100644 --- a/tensorflow/contrib/signal/__init__.py +++ b/tensorflow/contrib/signal/__init__.py @@ -20,6 +20,8 @@ See the @{$python/contrib.signal} guide. @@hamming_window @@hann_window @@inverse_stft +@@inverse_stft_window_fn +@@mfccs_from_log_mel_spectrograms @@linear_to_mel_weight_matrix @@overlap_and_add @@stft @@ -27,6 +29,7 @@ See the @{$python/contrib.signal} guide. [hamming]: https://en.wikipedia.org/wiki/Window_function#Hamming_window [hann]: https://en.wikipedia.org/wiki/Window_function#Hann_window [mel]: https://en.wikipedia.org/wiki/Mel_scale +[mfcc]: https://en.wikipedia.org/wiki/Mel-frequency_cepstrum [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform """ @@ -35,12 +38,14 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.signal.python.ops.mel_ops import linear_to_mel_weight_matrix +from tensorflow.contrib.signal.python.ops.mfcc_ops import mfccs_from_log_mel_spectrograms from tensorflow.contrib.signal.python.ops.reconstruction_ops import overlap_and_add from tensorflow.contrib.signal.python.ops.shape_ops import frame # `frame` used to be named `frames`, which is a noun and not a verb. # Keep an alias to `frames` for backwards compatibility. from tensorflow.contrib.signal.python.ops.shape_ops import frame as frames from tensorflow.contrib.signal.python.ops.spectral_ops import inverse_stft +from tensorflow.contrib.signal.python.ops.spectral_ops import inverse_stft_window_fn from tensorflow.contrib.signal.python.ops.spectral_ops import stft from tensorflow.contrib.signal.python.ops.window_ops import hamming_window from tensorflow.contrib.signal.python.ops.window_ops import hann_window diff --git a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py index f107b53f01ca5422a57c6b03f6ec385d937bfead..b861476b67fc360f383465145ccd1cc620de5a99 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/mel_ops_test.py @@ -20,8 +20,10 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.signal.python.kernel_tests import test_util from tensorflow.contrib.signal.python.ops import mel_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.platform import test # mel spectrum constants and functions. @@ -159,6 +161,15 @@ class LinearToMelTest(test.TestCase): with self.assertRaises(ValueError): mel_ops.linear_to_mel_weight_matrix(dtype=dtypes.int32) + def test_constant_folding(self): + """Mel functions should be constant foldable.""" + for dtype in (dtypes.float16, dtypes.float32, dtypes.float64): + g = ops.Graph() + with g.as_default(): + mel_matrix = mel_ops.linear_to_mel_weight_matrix(dtype=dtype) + rewritten_graph = test_util.grappler_optimize(g, [mel_matrix]) + self.assertEqual(1, len(rewritten_graph.node)) + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py new file mode 100644 index 0000000000000000000000000000000000000000..c04f1cf5bad358a14a1827df05a129339502c86f --- /dev/null +++ b/tensorflow/contrib/signal/python/kernel_tests/mfcc_ops_test.py @@ -0,0 +1,54 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for mfcc_ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.signal.python.ops import mfcc_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import spectral_ops_test_util +from tensorflow.python.platform import test + + +# TODO(rjryan): We have no open source tests for MFCCs at the moment. Internally +# at Google, this code is tested against a reference implementation that follows +# HTK conventions. +class MFCCTest(test.TestCase): + + def test_error(self): + # num_mel_bins must be positive. + with self.assertRaises(ValueError): + signal = array_ops.zeros((2, 3, 0)) + mfcc_ops.mfccs_from_log_mel_spectrograms(signal) + + # signal must be float32 + with self.assertRaises(ValueError): + signal = array_ops.zeros((2, 3, 5), dtype=dtypes.float64) + mfcc_ops.mfccs_from_log_mel_spectrograms(signal) + + def test_basic(self): + """A basic test that the op runs on random input.""" + with spectral_ops_test_util.fft_kernel_label_map(): + with self.test_session(use_gpu=True): + signal = random_ops.random_normal((2, 3, 5)) + mfcc_ops.mfccs_from_log_mel_spectrograms(signal).eval() + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py index 8633ced599f137da08a4181ec9cbf4b48517199d..1c052354b8afcc5fd8a53b783cc5c676588cf48c 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/shape_ops_test.py @@ -20,9 +20,11 @@ from __future__ import print_function import numpy as np +from tensorflow.contrib.signal.python.kernel_tests import test_util from tensorflow.contrib.signal.python.ops import shape_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test @@ -334,5 +336,19 @@ class FrameTest(test.TestCase): signal, signal_shape, frames, frames.shape.as_list()) self.assertLess(error, 2e-5) + def test_constant_folding(self): + """frame should be constant foldable for constant inputs.""" + for pad_end in [False, True]: + g = ops.Graph() + with g.as_default(): + frame_length, frame_step = 32, 16 + signal_shape = (2, 128) + signal = array_ops.ones(signal_shape) + frames = shape_ops.frame(signal, frame_length, frame_step, + pad_end=pad_end) + rewritten_graph = test_util.grappler_optimize(g, [frames]) + self.assertEqual(1, len(rewritten_graph.node)) + + if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py index 305a2b2eb9858b381988335caa5cc6b2e11e2bac..03d6da7765ba5249a9fb22f56a469cf07c310479 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/spectral_ops_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.signal.python.ops import spectral_ops +from tensorflow.contrib.signal.python.ops import window_ops from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import gradients_impl @@ -59,7 +60,11 @@ class SpectralOpsTest(test.TestCase): @staticmethod def _np_inverse_stft(stft, fft_length, hop_length, window_length): - frames = np.fft.irfft(stft, fft_length)[..., :window_length] + frames = np.fft.irfft(stft, fft_length) + # Pad or truncate frames's inner dimension to window_length. + frames = frames[..., :window_length] + frames = np.pad(frames, [[0, 0]] * (frames.ndim - 1) + + [[0, max(0, window_length - frames.shape[-1])]], "constant") window = SpectralOpsTest._np_hann_periodic_window(window_length) return SpectralOpsTest._np_overlap_add(frames * window, hop_length) @@ -79,12 +84,27 @@ class SpectralOpsTest(test.TestCase): self.test_session(use_gpu=True)) as sess: actual_stft = spectral_ops.stft( signal, frame_length, frame_step, fft_length, pad_end=False) + signal_ph = array_ops.placeholder(dtype=dtypes.as_dtype(signal.dtype)) + actual_stft_from_ph = spectral_ops.stft( + signal_ph, frame_length, frame_step, fft_length, pad_end=False) actual_inverse_stft = spectral_ops.inverse_stft( actual_stft, frame_length, frame_step, fft_length) - actual_stft, actual_inverse_stft = sess.run( - [actual_stft, actual_inverse_stft]) + actual_stft, actual_stft_from_ph, actual_inverse_stft = sess.run( + [actual_stft, actual_stft_from_ph, actual_inverse_stft], + feed_dict={signal_ph: signal}) + + actual_stft_ph = array_ops.placeholder(dtype=actual_stft.dtype) + actual_inverse_stft_from_ph = sess.run( + spectral_ops.inverse_stft( + actual_stft_ph, frame_length, frame_step, fft_length), + feed_dict={actual_stft_ph: actual_stft}) + + # Confirm that there is no difference in output when shape/rank is fully + # unknown or known. + self.assertAllClose(actual_stft, actual_stft_from_ph) + self.assertAllClose(actual_inverse_stft, actual_inverse_stft_from_ph) expected_stft = SpectralOpsTest._np_stft( signal, fft_length, frame_step, frame_length) @@ -95,31 +115,6 @@ class SpectralOpsTest(test.TestCase): self.assertAllClose( expected_inverse_stft, actual_inverse_stft, 1e-4, 1e-4) - def _compare_round_trip(self, signal, frame_length, frame_step, fft_length): - with spectral_ops_test_util.fft_kernel_label_map(), ( - self.test_session(use_gpu=True)) as sess: - stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length, - pad_end=False) - inverse_stft = spectral_ops.inverse_stft(stft, frame_length, frame_step, - fft_length) - signal, inverse_stft = sess.run([signal, inverse_stft]) - - # Since the shapes can differ due to padding, pad both signals to the max - # of their lengths. - max_length = max(signal.shape[0], inverse_stft.shape[0]) - signal = np.pad(signal, (0, max_length - signal.shape[0]), "constant") - inverse_stft = np.pad(inverse_stft, - (0, max_length - inverse_stft.shape[0]), "constant") - - # Ignore the frame_length samples at either edge. - start = frame_length - end = signal.shape[0] - frame_length - ratio = signal[start:end] / inverse_stft[start:end] - - # Check that the inverse and original signal are equal up to a constant - # factor. - self.assertLess(np.var(ratio), 2e-5) - def test_shapes(self): with spectral_ops_test_util.fft_kernel_label_map(), ( self.test_session(use_gpu=True)): @@ -142,6 +137,11 @@ class SpectralOpsTest(test.TestCase): self.assertAllEqual([64, 9], stft.shape.as_list()) self.assertAllEqual([64, 9], stft.eval().shape) + stft = spectral_ops.stft(signal, frame_length=16, frame_step=8, + fft_length=8, pad_end=True) + self.assertAllEqual([64, 5], stft.shape.as_list()) + self.assertAllEqual([64, 5], stft.eval().shape) + stft = np.zeros((32, 9)).astype(np.complex64) inverse_stft = spectral_ops.inverse_stft(stft, frame_length=8, @@ -156,6 +156,7 @@ class SpectralOpsTest(test.TestCase): test_configs = [ (512, 64, 32, 64), (512, 64, 64, 64), + (512, 72, 64, 64), (512, 64, 25, 64), (512, 25, 15, 36), (123, 23, 5, 42), @@ -166,23 +167,105 @@ class SpectralOpsTest(test.TestCase): self._compare(signal, frame_length, frame_step, fft_length) def test_stft_round_trip(self): - # Tuples of (signal_length, frame_length, frame_step, fft_length). + # Tuples of (signal_length, frame_length, frame_step, fft_length, + # threshold, corrected_threshold). test_configs = [ # 87.5% overlap. - (4096, 256, 32, 256), + (4096, 256, 32, 256, 1e-5, 1e-6), # 75% overlap. - (4096, 256, 64, 256), + (4096, 256, 64, 256, 1e-5, 1e-6), # Odd frame hop. - (4096, 128, 25, 128), + (4096, 128, 25, 128, 1e-3, 1e-6), # Odd frame length. - (4096, 127, 32, 128), + (4096, 127, 32, 128, 1e-3, 1e-6), + # 50% overlap. + (4096, 128, 64, 128, 0.40, 1e-6), ] - for signal_length, frame_length, frame_step, fft_length in test_configs: - # Generate a 440Hz signal at 8kHz sample rate. - signal = math_ops.sin(2 * np.pi * 440 / 8000 * - math_ops.to_float(math_ops.range(signal_length))) - self._compare_round_trip(signal, frame_length, frame_step, fft_length) + for (signal_length, frame_length, frame_step, fft_length, threshold, + corrected_threshold) in test_configs: + # Generate a random white Gaussian signal. + signal = random_ops.random_normal([signal_length]) + + with spectral_ops_test_util.fft_kernel_label_map(), ( + self.test_session(use_gpu=True)) as sess: + stft = spectral_ops.stft(signal, frame_length, frame_step, fft_length, + pad_end=False) + inverse_stft = spectral_ops.inverse_stft(stft, frame_length, frame_step, + fft_length) + inverse_stft_corrected = spectral_ops.inverse_stft( + stft, frame_length, frame_step, fft_length, + window_fn=spectral_ops.inverse_stft_window_fn(frame_step)) + signal, inverse_stft, inverse_stft_corrected = sess.run( + [signal, inverse_stft, inverse_stft_corrected]) + + # Truncate signal to the size of inverse stft. + signal = signal[:inverse_stft.shape[0]] + + # Ignore the frame_length samples at either edge. + signal = signal[frame_length:-frame_length] + inverse_stft = inverse_stft[frame_length:-frame_length] + inverse_stft_corrected = inverse_stft_corrected[ + frame_length:-frame_length] + + # Check that the inverse and original signal are close up to a scale + # factor. + inverse_stft_scaled = inverse_stft / np.mean(np.abs(inverse_stft)) + signal_scaled = signal / np.mean(np.abs(signal)) + self.assertLess(np.std(inverse_stft_scaled - signal_scaled), threshold) + + # Check that the inverse with correction and original signal are close. + self.assertLess(np.std(inverse_stft_corrected - signal), + corrected_threshold) + + def test_inverse_stft_window_fn(self): + """Test that inverse_stft_window_fn has unit gain at each window phase.""" + # Tuples of (frame_length, frame_step). + test_configs = [ + (256, 32), + (256, 64), + (128, 25), + (127, 32), + (128, 64), + ] + + for (frame_length, frame_step) in test_configs: + hann_window = window_ops.hann_window(frame_length, dtype=dtypes.float32) + inverse_window_fn = spectral_ops.inverse_stft_window_fn(frame_step) + inverse_window = inverse_window_fn(frame_length, dtype=dtypes.float32) + + with self.test_session(use_gpu=True) as sess: + hann_window, inverse_window = sess.run([hann_window, inverse_window]) + + # Expect unit gain at each phase of the window. + product_window = hann_window * inverse_window + for i in range(frame_step): + self.assertAllClose(1.0, np.sum(product_window[i::frame_step])) + + def test_inverse_stft_window_fn_special_case(self): + """Test inverse_stft_window_fn in special overlap = 3/4 case.""" + # Cases in which frame_length is an integer multiple of 4 * frame_step are + # special because they allow exact reproduction of the waveform with a + # squared Hann window (Hann window in both forward and reverse transforms). + # In the case where frame_length = 4 * frame_step, that combination + # produces a constant gain of 1.5, and so the corrected window will be the + # Hann window / 1.5. + + # Tuples of (frame_length, frame_step). + test_configs = [ + (256, 64), + (128, 32), + ] + + for (frame_length, frame_step) in test_configs: + hann_window = window_ops.hann_window(frame_length, dtype=dtypes.float32) + inverse_window_fn = spectral_ops.inverse_stft_window_fn(frame_step) + inverse_window = inverse_window_fn(frame_length, dtype=dtypes.float32) + + with self.test_session(use_gpu=True) as sess: + hann_window, inverse_window = sess.run([hann_window, inverse_window]) + + self.assertAllClose(hann_window, inverse_window * 1.5) @staticmethod def _compute_stft_gradient(signal, frame_length=32, frame_step=16, diff --git a/tensorflow/contrib/signal/python/kernel_tests/test_util.py b/tensorflow/contrib/signal/python/kernel_tests/test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..9a3603b6a97ef7c3a4b940b83281ebceda93c9db --- /dev/null +++ b/tensorflow/contrib/signal/python/kernel_tests/test_util.py @@ -0,0 +1,46 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test utilities for tf.contrib.signal.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import rewriter_config_pb2 +from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.training import saver + + +def grappler_optimize(graph, fetches=None, rewriter_config=None): + """Tries to optimize the provided graph using grappler. + + Args: + graph: A @{tf.Graph} instance containing the graph to optimize. + fetches: An optional list of `Tensor`s to fetch (i.e. not optimize away). + Grappler uses the 'train_op' collection to look for fetches, so if not + provided this collection should be non-empty. + rewriter_config: An optional @{tf.RewriterConfig} to use when rewriting the + graph. + + Returns: + A @{tf.GraphDef} containing the rewritten graph. + """ + if rewriter_config is None: + rewriter_config = rewriter_config_pb2.RewriterConfig() + if fetches is not None: + for fetch in fetches: + graph.add_to_collection('train_op', fetch) + metagraph = saver.export_meta_graph(graph_def=graph.as_graph_def()) + return tf_optimizer.OptimizeGraph(rewriter_config, metagraph) diff --git a/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py b/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py index c3e0464596244b331906dab47cee349c1ea737b5..5a464699dac5a737e0c6e0122a4a6699e945f695 100644 --- a/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py +++ b/tensorflow/contrib/signal/python/kernel_tests/window_ops_test.py @@ -22,8 +22,10 @@ import functools import numpy as np +from tensorflow.contrib.signal.python.kernel_tests import test_util from tensorflow.contrib.signal.python.ops import window_ops from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.platform import test @@ -91,6 +93,17 @@ class WindowOpsTest(test.TestCase): functools.partial(_scipy_raised_cosine, a=0.54, b=0.46), window_ops.hamming_window) + def test_constant_folding(self): + """Window functions should be constant foldable for constant inputs.""" + for window_fn in (window_ops.hann_window, window_ops.hamming_window): + for dtype, _ in self._dtypes: + for periodic in [False, True]: + g = ops.Graph() + with g.as_default(): + window = window_fn(100, periodic=periodic, dtype=dtype) + rewritten_graph = test_util.grappler_optimize(g, [window]) + self.assertEqual(1, len(rewritten_graph.node)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/signal/python/ops/mfcc_ops.py b/tensorflow/contrib/signal/python/ops/mfcc_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..7bc7b57cd4f1033a8bda0845ccd8e777e0213d6b --- /dev/null +++ b/tensorflow/contrib/signal/python/ops/mfcc_ops.py @@ -0,0 +1,108 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Mel-Frequency Cepstral Coefficients (MFCCs) ops.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import spectral_ops + + +def mfccs_from_log_mel_spectrograms(log_mel_spectrograms, name=None): + """Computes [MFCCs][mfcc] of `log_mel_spectrograms`. + + Implemented with GPU-compatible ops and supports gradients. + + [Mel-Frequency Cepstral Coefficient (MFCC)][mfcc] calculation consists of + taking the DCT-II of a log-magnitude mel-scale spectrogram. [HTK][htk]'s MFCCs + use a particular scaling of the DCT-II which is almost orthogonal + normalization. We follow this convention. + + All `num_mel_bins` MFCCs are returned and it is up to the caller to select + a subset of the MFCCs based on their application. For example, it is typical + to only use the first few for speech recognition, as this results in + an approximately pitch-invariant representation of the signal. + + For example: + + ```python + sample_rate = 16000.0 + # A Tensor of [batch_size, num_samples] mono PCM samples in the range [-1, 1]. + pcm = tf.placeholder(tf.float32, [None, None]) + + # A 1024-point STFT with frames of 64 ms and 75% overlap. + stfts = tf.contrib.signal.stft(pcm, frame_length=1024, frame_step=256, + fft_length=1024) + spectrograms = tf.abs(stft) + + # Warp the linear scale spectrograms into the mel-scale. + num_spectrogram_bins = stfts.shape[-1].value + lower_edge_hertz, upper_edge_hertz, num_mel_bins = 80.0, 7600.0, 80 + linear_to_mel_weight_matrix = tf.contrib.signal.linear_to_mel_weight_matrix( + num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz, + upper_edge_hertz) + mel_spectrograms = tf.tensordot( + spectrograms, linear_to_mel_weight_matrix, 1) + mel_spectrograms.set_shape(spectrograms.shape[:-1].concatenate( + linear_to_mel_weight_matrix.shape[-1:])) + + # Compute a stabilized log to get log-magnitude mel-scale spectrograms. + log_mel_spectrograms = tf.log(mel_spectrograms + 1e-6) + + # Compute MFCCs from log_mel_spectrograms and take the first 13. + mfccs = tf.contrib.signal.mfccs_from_log_mel_spectrograms( + log_mel_spectrograms)[..., :13] + ``` + + Args: + log_mel_spectrograms: A `[..., num_mel_bins]` `float32` `Tensor` of + log-magnitude mel-scale spectrograms. + name: An optional name for the operation. + Returns: + A `[..., num_mel_bins]` `float32` `Tensor` of the MFCCs of + `log_mel_spectrograms`. + + Raises: + ValueError: If `num_mel_bins` is not positive. + + [mfcc]: https://en.wikipedia.org/wiki/Mel-frequency_cepstrum + [htk]: https://en.wikipedia.org/wiki/HTK_(software) + """ + with ops.name_scope(name, 'mfccs_from_log_mel_spectrograms', + [log_mel_spectrograms]): + # Compute the DCT-II of the resulting log-magnitude mel-scale spectrogram. + # The DCT used in HTK scales every basis vector by sqrt(2/N), which is the + # scaling required for an "orthogonal" DCT-II *except* in the 0th bin, where + # the true orthogonal DCT (as implemented by scipy) scales by sqrt(1/N). For + # this reason, we don't apply orthogonal normalization and scale the DCT by + # `0.5 * sqrt(2/N)` manually. + log_mel_spectrograms = ops.convert_to_tensor(log_mel_spectrograms, + dtype=dtypes.float32) + if (log_mel_spectrograms.shape.ndims and + log_mel_spectrograms.shape[-1].value is not None): + num_mel_bins = log_mel_spectrograms.shape[-1].value + if num_mel_bins == 0: + raise ValueError('num_mel_bins must be positive. Got: %s' % + log_mel_spectrograms) + else: + num_mel_bins = array_ops.shape(log_mel_spectrograms)[-1] + + dct2 = spectral_ops.dct(log_mel_spectrograms) + return dct2 * math_ops.rsqrt(num_mel_bins * 2.0) diff --git a/tensorflow/contrib/signal/python/ops/spectral_ops.py b/tensorflow/contrib/signal/python/ops/spectral_ops.py index 950d8f471c6b34ecd7488b4434776a333d2fa782..bca2e01d7bbefb18fd69a0eba27e3afb8f636724 100644 --- a/tensorflow/contrib/signal/python/ops/spectral_ops.py +++ b/tensorflow/contrib/signal/python/ops/spectral_ops.py @@ -28,6 +28,7 @@ from tensorflow.contrib.signal.python.ops import window_ops from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_util +from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import spectral_ops @@ -59,8 +60,7 @@ def stft(signals, frame_length, frame_step, fft_length=None, Raises: ValueError: If `signals` is not at least rank 1, `frame_length` is - not scalar, `frame_step` is not scalar, or `frame_length` - is greater than `fft_length`. + not scalar, or `frame_step` is not scalar. [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform """ @@ -78,15 +78,6 @@ def stft(signals, frame_length, frame_step, fft_length=None, else: fft_length = ops.convert_to_tensor(fft_length, name='fft_length') - frame_length_static = tensor_util.constant_value( - frame_length) - fft_length_static = tensor_util.constant_value(fft_length) - if (frame_length_static is not None and fft_length_static is not None and - frame_length_static > fft_length_static): - raise ValueError('frame_length (%d) may not be larger than ' - 'fft_length (%d)' % (frame_length_static, - fft_length_static)) - framed_signals = shape_ops.frame( signals, frame_length, frame_step, pad_end=pad_end) @@ -100,6 +91,67 @@ def stft(signals, frame_length, frame_step, fft_length=None, return spectral_ops.rfft(framed_signals, [fft_length]) +def inverse_stft_window_fn(frame_step, + forward_window_fn=functools.partial( + window_ops.hann_window, periodic=True), + name=None): + """Generates a window function that can be used in `inverse_stft`. + + Constructs a window that is equal to the forward window with a further + pointwise amplitude correction. `inverse_stft_window_fn` is equivalent to + `forward_window_fn` in the case where it would produce an exact inverse. + + See examples in `inverse_stft` documentation for usage. + + Args: + frame_step: An integer scalar `Tensor`. The number of samples to step. + forward_window_fn: window_fn used in the forward transform, `stft`. + name: An optional name for the operation. + + Returns: + A callable that takes a window length and a `dtype` keyword argument and + returns a `[window_length]` `Tensor` of samples in the provided datatype. + The returned window is suitable for reconstructing original waveform in + inverse_stft. + """ + with ops.name_scope(name, 'inverse_stft_window_fn', [forward_window_fn]): + frame_step = ops.convert_to_tensor(frame_step, name='frame_step') + frame_step.shape.assert_has_rank(0) + + def inverse_stft_window_fn_inner(frame_length, dtype): + """Computes a window that can be used in `inverse_stft`. + + Args: + frame_length: An integer scalar `Tensor`. The window length in samples. + dtype: Data type of waveform passed to `stft`. + + Returns: + A window suitable for reconstructing original waveform in `inverse_stft`. + + Raises: + ValueError: If `frame_length` is not scalar, `forward_window_fn` is not a + callable that takes a window length and a `dtype` keyword argument and + returns a `[window_length]` `Tensor` of samples in the provided datatype + `frame_step` is not scalar, or `frame_step` is not scalar. + """ + with ops.name_scope(name, 'inverse_stft_window_fn', [forward_window_fn]): + frame_length = ops.convert_to_tensor(frame_length, name='frame_length') + frame_length.shape.assert_has_rank(0) + + # Use equation 7 from Griffin + Lim. + forward_window = forward_window_fn(frame_length, dtype=dtype) + denom = math_ops.square(forward_window) + overlaps = -(-frame_length // frame_step) # Ceiling division. + denom = array_ops.pad(denom, [(0, overlaps * frame_step - frame_length)]) + denom = array_ops.reshape(denom, [overlaps, frame_step]) + denom = math_ops.reduce_sum(denom, 0, keep_dims=True) + denom = array_ops.tile(denom, [overlaps, 1]) + denom = array_ops.reshape(denom, [overlaps * frame_step]) + + return forward_window / denom[:frame_length] + return inverse_stft_window_fn_inner + + def inverse_stft(stfts, frame_length, frame_step, @@ -109,6 +161,38 @@ def inverse_stft(stfts, name=None): """Computes the inverse [Short-time Fourier Transform][stft] of `stfts`. + To reconstruct an original waveform, a complimentary window function should + be used in inverse_stft. Such a window function can be constructed with + tf.contrib.signal.inverse_stft_window_fn. + + Example: + + ```python + frame_length = 400 + frame_step = 160 + waveform = tf.placeholder(dtype=tf.float32, shape=[1000]) + stft = tf.contrib.signal.stft(waveform, frame_length, frame_step) + inverse_stft = tf.contrib.signal.inverse_stft( + stft, frame_length, frame_step, + window_fn=tf.contrib.signal.inverse_stft_window_fn(frame_step)) + ``` + + if a custom window_fn is used in stft, it must be passed to + inverse_stft_window_fn: + + ```python + frame_length = 400 + frame_step = 160 + window_fn = functools.partial(window_ops.hamming_window, periodic=True), + waveform = tf.placeholder(dtype=tf.float32, shape=[1000]) + stft = tf.contrib.signal.stft( + waveform, frame_length, frame_step, window_fn=window_fn) + inverse_stft = tf.contrib.signal.inverse_stft( + stft, frame_length, frame_step, + window_fn=tf.contrib.signal.inverse_stft_window_fn( + frame_step, forward_window_fn=window_fn)) + ``` + Implemented with GPU-compatible ops and supports gradients. Args: @@ -131,8 +215,7 @@ def inverse_stft(stfts, Raises: ValueError: If `stfts` is not at least rank 2, `frame_length` is not scalar, - `frame_step` is not scalar, or `fft_length` is not scalar, or - `frame_length` is greater than `fft_length`. + `frame_step` is not scalar, or `fft_length` is not scalar. [stft]: https://en.wikipedia.org/wiki/Short-time_Fourier_transform """ @@ -149,16 +232,40 @@ def inverse_stft(stfts, fft_length = ops.convert_to_tensor(fft_length, name='fft_length') fft_length.shape.assert_has_rank(0) - frame_length_static = tensor_util.constant_value( - frame_length) - fft_length_static = tensor_util.constant_value(fft_length) - if (frame_length_static is not None and fft_length_static is not None and - frame_length_static > fft_length_static): - raise ValueError('frame_length (%d) may not be larger than ' - 'fft_length (%d)' % (frame_length_static, - fft_length_static)) + real_frames = spectral_ops.irfft(stfts, [fft_length]) + + # frame_length may be larger or smaller than fft_length, so we pad or + # truncate real_frames to frame_length. + frame_length_static = tensor_util.constant_value(frame_length) + # If we don't know the shape of real_frames's inner dimension, pad and + # truncate to frame_length. + if (frame_length_static is None or + real_frames.shape.ndims is None or + real_frames.shape[-1].value is None): + real_frames = real_frames[..., :frame_length] + real_frames_rank = array_ops.rank(real_frames) + real_frames_shape = array_ops.shape(real_frames) + paddings = array_ops.concat( + [array_ops.zeros([real_frames_rank - 1, 2], + dtype=frame_length.dtype), + [[0, math_ops.maximum(0, frame_length - real_frames_shape[-1])]]], 0) + real_frames = array_ops.pad(real_frames, paddings) + # We know real_frames's last dimension and frame_length statically. If they + # are different, then pad or truncate real_frames to frame_length. + elif real_frames.shape[-1].value > frame_length_static: + real_frames = real_frames[..., :frame_length_static] + elif real_frames.shape[-1].value < frame_length_static: + pad_amount = frame_length_static - real_frames.shape[-1].value + real_frames = array_ops.pad(real_frames, + [[0, 0]] * (real_frames.shape.ndims - 1) + + [[0, pad_amount]]) - real_frames = spectral_ops.irfft(stfts, [fft_length])[..., :frame_length] + # The above code pads the inner dimension of real_frames to frame_length, + # but it does so in a way that may not be shape-inference friendly. + # Restore shape information if we are able to. + if frame_length_static is not None and real_frames.shape.ndims is not None: + real_frames.set_shape([None] * (real_frames.shape.ndims - 1) + + [frame_length_static]) # Optionally window and overlap-add the inner 2 dimensions of real_frames # into a single [samples] dimension. diff --git a/tensorflow/contrib/signal/python/ops/util_ops.py b/tensorflow/contrib/signal/python/ops/util_ops.py index eee829d799eb149bfb2af0dfe92c9fc1b55c452c..817c9b97d68515be640a3ca966bd2d43fd83b864 100644 --- a/tensorflow/contrib/signal/python/ops/util_ops.py +++ b/tensorflow/contrib/signal/python/ops/util_ops.py @@ -18,7 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import fractions + from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -51,6 +54,13 @@ def gcd(a, b, name=None): if not b.dtype.is_integer: raise ValueError('b must be an integer type. Got: %s' % b.dtype) + # TPU requires static shape inference. GCD is used for subframe size + # computation, so we should prefer static computation where possible. + const_a = tensor_util.constant_value(a) + const_b = tensor_util.constant_value(b) + if const_a is not None and const_b is not None: + return ops.convert_to_tensor(fractions.gcd(const_a, const_b)) + cond = lambda _, b: math_ops.greater(b, array_ops.zeros_like(b)) body = lambda a, b: [b, math_ops.mod(a, b)] a, b = control_flow_ops.while_loop(cond, body, [a, b], back_prop=False) diff --git a/tensorflow/contrib/signal/python/ops/window_ops.py b/tensorflow/contrib/signal/python/ops/window_ops.py index 07a847dd2a440254d50759308006c7121eee13f2..50094010dc75cf8b3c62da5e3a7ed5e995e6df41 100644 --- a/tensorflow/contrib/signal/python/ops/window_ops.py +++ b/tensorflow/contrib/signal/python/ops/window_ops.py @@ -23,6 +23,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -105,6 +106,9 @@ def _raised_cosine_window(name, default_name, window_length, periodic, window_length = ops.convert_to_tensor(window_length, dtype=dtypes.int32, name='window_length') window_length.shape.assert_has_rank(0) + window_length_const = tensor_util.constant_value(window_length) + if window_length_const == 1: + return array_ops.ones([1], dtype=dtype) periodic = math_ops.cast( ops.convert_to_tensor(periodic, dtype=dtypes.bool, name='periodic'), dtypes.int32) @@ -115,6 +119,8 @@ def _raised_cosine_window(name, default_name, window_length, periodic, count = math_ops.cast(math_ops.range(window_length), dtype) cos_arg = constant_op.constant(2 * np.pi, dtype=dtype) * count / n + if window_length_const is not None: + return math_ops.cast(a - b * math_ops.cos(cos_arg), dtype=dtype) return control_flow_ops.cond( math_ops.equal(window_length, 1), lambda: array_ops.ones([1], dtype=dtype), diff --git a/tensorflow/contrib/slim/BUILD b/tensorflow/contrib/slim/BUILD index d2664b612cdbcae3a346b68e9caee654c48a69cd..23c23af2f4815c3b1d75eb955b9026dfb9b00194 100644 --- a/tensorflow/contrib/slim/BUILD +++ b/tensorflow/contrib/slim/BUILD @@ -48,7 +48,6 @@ py_library( srcs = ["python/slim/learning.py"], srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/training:training_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:client", @@ -78,7 +77,6 @@ py_test( "//tensorflow/contrib/losses:losses_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", diff --git a/tensorflow/contrib/slim/README.md b/tensorflow/contrib/slim/README.md index c0aa6d445acfc99ef9da9a54fc269babee754951..f7a85557ca3df6325502da1052c96beff3c5ae08 100644 --- a/tensorflow/contrib/slim/README.md +++ b/tensorflow/contrib/slim/README.md @@ -237,7 +237,7 @@ One way to reduce this code duplication would be via a `for` loop: ```python net = ... for i in range(3): - net = slim.conv2d(net, 256, [3, 3], scope='conv3_' % (i+1)) + net = slim.conv2d(net, 256, [3, 3], scope='conv3_%d' % (i+1)) net = slim.max_pool2d(net, [2, 2], scope='pool2') ``` @@ -574,7 +574,7 @@ with tf.Graph().as_default(): images, labels = ... # Define the model: - predictions = vgg.vgg16(images, is_training=True) + predictions = vgg.vgg_16(images, is_training=True) # Specify the loss function: slim.losses.softmax_cross_entropy(predictions, labels) diff --git a/tensorflow/contrib/slim/python/slim/data/BUILD b/tensorflow/contrib/slim/python/slim/data/BUILD index fc71a5fe415d4d34bd38e43bf33cefffcddaea6f..5daabbd62e7e63608a7a86a8b7fb0bc0d570b28b 100644 --- a/tensorflow/contrib/slim/python/slim/data/BUILD +++ b/tensorflow/contrib/slim/python/slim/data/BUILD @@ -68,13 +68,13 @@ py_test( ":tfexample_decoder", "//tensorflow/contrib/slim:queues", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:image_ops", "//tensorflow/python:io_ops", "//tensorflow/python:parsing_ops", "//tensorflow/python:platform", + "//tensorflow/python:session", ], ) @@ -187,6 +187,7 @@ py_test( "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:image_ops", + "//tensorflow/python:lookup_ops", "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", "//third_party/py/numpy", diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py index f9449095be0cb53a8c762eaa70f005f01645743d..0544404e9e252cca6d3650b805b91be25d705eea 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder.py @@ -135,7 +135,10 @@ class BoundingBox(ItemHandler): """ sides = [] for key in self._full_keys: - side = array_ops.expand_dims(keys_to_tensors[key].values, 0) + side = keys_to_tensors[key] + if isinstance(side, sparse_tensor.SparseTensor): + side = side.values + side = array_ops.expand_dims(side, 0) sides.append(side) bounding_box = array_ops.concat(sides, 0) @@ -204,6 +207,76 @@ class Tensor(ItemHandler): return tensor +class LookupTensor(Tensor): + """An ItemHandler that returns a parsed Tensor, the result of a lookup.""" + + def __init__(self, + tensor_key, + table, + shape_keys=None, + shape=None, + default_value=''): + """Initializes the LookupTensor handler. + + See Tensor. Simply calls a vocabulary (most often, a label mapping) lookup. + + Args: + tensor_key: the name of the `TFExample` feature to read the tensor from. + table: A tf.lookup table. + shape_keys: Optional name or list of names of the TF-Example feature in + which the tensor shape is stored. If a list, then each corresponds to + one dimension of the shape. + shape: Optional output shape of the `Tensor`. If provided, the `Tensor` is + reshaped accordingly. + default_value: The value used when the `tensor_key` is not found in a + particular `TFExample`. + + Raises: + ValueError: if both `shape_keys` and `shape` are specified. + """ + self._table = table + super(LookupTensor, self).__init__(tensor_key, shape_keys, shape, + default_value) + + def tensors_to_item(self, keys_to_tensors): + unmapped_tensor = super(LookupTensor, self).tensors_to_item(keys_to_tensors) + return self._table.lookup(unmapped_tensor) + + +class BackupHandler(ItemHandler): + """An ItemHandler that tries two ItemHandlers in order.""" + + def __init__(self, handler, backup): + """Initializes the BackupHandler handler. + + If the first Handler's tensors_to_item returns a Tensor with no elements, + the second Handler is used. + + Args: + handler: The primary ItemHandler. + backup: The backup ItemHandler. + + Raises: + ValueError: if either is not an ItemHandler. + """ + if not isinstance(handler, ItemHandler): + raise ValueError('Primary handler is of type %s instead of ItemHandler' + % type(handler)) + if not isinstance(backup, ItemHandler): + raise ValueError('Backup handler is of type %s instead of ItemHandler' + % type(backup)) + self._handler = handler + self._backup = backup + super(BackupHandler, self).__init__(handler.keys + backup.keys) + + def tensors_to_item(self, keys_to_tensors): + item = self._handler.tensors_to_item(keys_to_tensors) + return control_flow_ops.cond( + pred=math_ops.equal(math_ops.reduce_prod(array_ops.shape(item)), 0), + true_fn=lambda: self._backup.tensors_to_item(keys_to_tensors), + false_fn=lambda: item) + + class SparseTensor(ItemHandler): """An ItemHandler for SparseTensors.""" diff --git a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py index 96606b9c0e5b19a360f45ffe9922874cabe621e8..d783d4fef42bb2acffe7eb8b155c5efaed7896d9 100644 --- a/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py +++ b/tensorflow/contrib/slim/python/slim/data/tfexample_decoder_test.py @@ -28,6 +28,7 @@ from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import image_ops +from tensorflow.python.ops import lookup_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import parsing_ops from tensorflow.python.platform import test @@ -692,7 +693,7 @@ class TFExampleDecoderTest(test.TestCase): else: self.assertAllClose(image, decoded_image, atol=0) - def testDecodeExampleWithBoundingBox(self): + def testDecodeExampleWithBoundingBoxSparse(self): num_bboxes = 10 np_ymin = np.random.rand(num_bboxes, 1) np_xmin = np.random.rand(num_bboxes, 1) @@ -731,6 +732,49 @@ class TFExampleDecoderTest(test.TestCase): self.assertAllClose(np_bboxes, bboxes) + def testDecodeExampleWithBoundingBoxDense(self): + num_bboxes = 10 + np_ymin = np.random.rand(num_bboxes, 1) + np_xmin = np.random.rand(num_bboxes, 1) + np_ymax = np.random.rand(num_bboxes, 1) + np_xmax = np.random.rand(num_bboxes, 1) + np_bboxes = np.hstack([np_ymin, np_xmin, np_ymax, np_xmax]) + + example = example_pb2.Example(features=feature_pb2.Features(feature={ + 'image/object/bbox/ymin': self._EncodedFloatFeature(np_ymin), + 'image/object/bbox/xmin': self._EncodedFloatFeature(np_xmin), + 'image/object/bbox/ymax': self._EncodedFloatFeature(np_ymax), + 'image/object/bbox/xmax': self._EncodedFloatFeature(np_xmax), + })) + serialized_example = example.SerializeToString() + + with self.test_session(): + serialized_example = array_ops.reshape(serialized_example, shape=[]) + + keys_to_features = { + 'image/object/bbox/ymin': parsing_ops.FixedLenSequenceFeature( + [], dtypes.float32, allow_missing=True), + 'image/object/bbox/xmin': parsing_ops.FixedLenSequenceFeature( + [], dtypes.float32, allow_missing=True), + 'image/object/bbox/ymax': parsing_ops.FixedLenSequenceFeature( + [], dtypes.float32, allow_missing=True), + 'image/object/bbox/xmax': parsing_ops.FixedLenSequenceFeature( + [], dtypes.float32, allow_missing=True), + } + + items_to_handlers = { + 'object/bbox': + tfexample_decoder.BoundingBox(['ymin', 'xmin', 'ymax', 'xmax'], + 'image/object/bbox/'), + } + + decoder = tfexample_decoder.TFExampleDecoder(keys_to_features, + items_to_handlers) + [tf_bboxes] = decoder.decode(serialized_example, ['object/bbox']) + bboxes = tf_bboxes.eval() + + self.assertAllClose(np_bboxes, bboxes) + def testDecodeExampleWithRepeatedImages(self): image_shape = (2, 3, 3) image_format = 'png' @@ -768,6 +812,87 @@ class TFExampleDecoderTest(test.TestCase): self.assertAllEqual(np.squeeze(output_image[0, :, :, :]), image) self.assertAllEqual(np.squeeze(output_image[1, :, :, :]), image) + def testDecodeExampleWithLookup(self): + + example = example_pb2.Example(features=feature_pb2.Features(feature={ + 'image/object/class/text': self._BytesFeature( + np.array(['cat', 'dog', 'guinea pig'])), + })) + serialized_example = example.SerializeToString() + # 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2 + table = lookup_ops.index_table_from_tensor( + constant_op.constant(['dog', 'guinea pig', 'cat'])) + + with self.test_session() as sess: + sess.run(lookup_ops.tables_initializer()) + + serialized_example = array_ops.reshape(serialized_example, shape=[]) + + keys_to_features = { + 'image/object/class/text': parsing_ops.VarLenFeature(dtypes.string), + } + + items_to_handlers = { + 'labels': + tfexample_decoder.LookupTensor('image/object/class/text', table), + } + + decoder = tfexample_decoder.TFExampleDecoder(keys_to_features, + items_to_handlers) + obtained_class_ids = decoder.decode(serialized_example)[0].eval() + + self.assertAllClose([2, 0, 1], obtained_class_ids) + + def testDecodeExampleWithBackupHandlerLookup(self): + + example1 = example_pb2.Example( + features=feature_pb2.Features( + feature={ + 'image/object/class/text': + self._BytesFeature(np.array(['cat', 'dog', 'guinea pig'])), + 'image/object/class/label': + self._EncodedInt64Feature(np.array([42, 10, 900])) + })) + example2 = example_pb2.Example( + features=feature_pb2.Features( + feature={ + 'image/object/class/text': + self._BytesFeature(np.array(['cat', 'dog', 'guinea pig'])), + })) + example3 = example_pb2.Example( + features=feature_pb2.Features( + feature={ + 'image/object/class/label': + self._EncodedInt64Feature(np.array([42, 10, 901])) + })) + # 'dog' -> 0, 'guinea pig' -> 1, 'cat' -> 2 + table = lookup_ops.index_table_from_tensor( + constant_op.constant(['dog', 'guinea pig', 'cat'])) + keys_to_features = { + 'image/object/class/text': parsing_ops.VarLenFeature(dtypes.string), + 'image/object/class/label': parsing_ops.VarLenFeature(dtypes.int64), + } + backup_handler = tfexample_decoder.BackupHandler( + handler=tfexample_decoder.Tensor('image/object/class/label'), + backup=tfexample_decoder.LookupTensor('image/object/class/text', table)) + items_to_handlers = { + 'labels': backup_handler, + } + decoder = tfexample_decoder.TFExampleDecoder(keys_to_features, + items_to_handlers) + obtained_class_ids_each_example = [] + with self.test_session() as sess: + sess.run(lookup_ops.tables_initializer()) + for example in [example1, example2, example3]: + serialized_example = array_ops.reshape( + example.SerializeToString(), shape=[]) + obtained_class_ids_each_example.append( + decoder.decode(serialized_example)[0].eval()) + + self.assertAllClose([42, 10, 900], obtained_class_ids_each_example[0]) + self.assertAllClose([2, 0, 1], obtained_class_ids_each_example[1]) + self.assertAllClose([42, 10, 901], obtained_class_ids_each_example[2]) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/slim/python/slim/learning.py b/tensorflow/contrib/slim/python/slim/learning.py index 5ee014a1f11a6b0d11857d209f27b134b737275d..def00b76184ba4e1fc630cd83d8e055448100562 100644 --- a/tensorflow/contrib/slim/python/slim/learning.py +++ b/tensorflow/contrib/slim/python/slim/learning.py @@ -552,7 +552,8 @@ def train(train_op, sync_optimizer=None, session_config=None, session_wrapper=None, - trace_every_n_steps=None): + trace_every_n_steps=None, + ignore_live_threads=False): """Runs a training loop using a TensorFlow supervisor. When the sync_optimizer is supplied, gradient updates are applied @@ -615,6 +616,9 @@ def train(train_op, trace_every_n_steps: produce and save a `Timeline` in Chrome trace format and add it to the summaries every `trace_every_n_steps`. If None, no trace information will be produced or saved. + ignore_live_threads: If `True` ignores threads that remain running after + a grace period when stopping the supervisor, instead of raising a + RuntimeError. Returns: the value of the loss function after training. @@ -772,7 +776,10 @@ def train(train_op, if logdir and sv.is_chief: logging.info('Finished training! Saving model to disk.') sv.saver.save(sess, sv.save_path, global_step=sv.global_step) - sv.stop(threads, close_summary_writer=True) + sv.stop( + threads, + close_summary_writer=True, + ignore_live_threads=ignore_live_threads) except errors.AbortedError: # Always re-run on AbortedError as it indicates a restart of one of the diff --git a/tensorflow/contrib/slim/python/slim/nets/BUILD b/tensorflow/contrib/slim/python/slim/nets/BUILD index e2035ab014cfd09682257fbbbf3a2868681aa850..7f03aaf085cf26e3f5f940f4388828006a02ef42 100644 --- a/tensorflow/contrib/slim/python/slim/nets/BUILD +++ b/tensorflow/contrib/slim/python/slim/nets/BUILD @@ -287,25 +287,6 @@ py_test( ], ) -py_test( - name = "resnet_is_training_test", - size = "medium", - srcs = ["resnet_is_training_test.py"], - srcs_version = "PY2AND3", - deps = [ - ":resnet_utils", - ":resnet_v1", - ":resnet_v2", - "//tensorflow/contrib/framework:framework_py", - "//tensorflow/contrib/layers:layers_py", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", - "//third_party/py/numpy", - ], -) - py_library( name = "vgg", srcs = ["vgg.py"], diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_is_training_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_is_training_test.py deleted file mode 100644 index 9a165577b699f757057aa10cc14bc1d48c02343a..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_is_training_test.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Specifying is_training in resnet_arg_scope is being deprecated. - -Test that everything behaves as expected in the meantime. - -Note: This test modifies the layers.batch_norm function. -Other tests that use layers.batch_norm may not work if added to this file. -""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib import layers -from tensorflow.contrib.framework.python.ops import add_arg_scope -from tensorflow.contrib.framework.python.ops import arg_scope -from tensorflow.contrib.slim.python.slim.nets import resnet_utils -from tensorflow.contrib.slim.python.slim.nets import resnet_v1 -from tensorflow.contrib.slim.python.slim.nets import resnet_v2 -from tensorflow.python.framework import dtypes -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.platform import test - - -def create_test_input(batch, height, width, channels): - """Create test input tensor.""" - if None in [batch, height, width, channels]: - return array_ops.placeholder(dtypes.float32, (batch, height, width, - channels)) - else: - return math_ops.to_float( - np.tile( - np.reshape( - np.reshape(np.arange(height), [height, 1]) + - np.reshape(np.arange(width), [1, width]), - [1, height, width, 1]), - [batch, 1, 1, channels])) - - -class ResnetIsTrainingTest(test.TestCase): - - def _testDeprecatingIsTraining(self, network_fn): - batch_norm_fn = layers.batch_norm - - @add_arg_scope - def batch_norm_expect_is_training(*args, **kwargs): - assert kwargs['is_training'] - return batch_norm_fn(*args, **kwargs) - - @add_arg_scope - def batch_norm_expect_is_not_training(*args, **kwargs): - assert not kwargs['is_training'] - return batch_norm_fn(*args, **kwargs) - - global_pool = True - num_classes = 10 - inputs = create_test_input(2, 224, 224, 3) - - # Default argument for resnet_arg_scope - layers.batch_norm = batch_norm_expect_is_training - with arg_scope(resnet_utils.resnet_arg_scope()): - network_fn(inputs, num_classes, global_pool=global_pool, scope='resnet1') - - layers.batch_norm = batch_norm_expect_is_training - with arg_scope(resnet_utils.resnet_arg_scope()): - network_fn( - inputs, - num_classes, - is_training=True, - global_pool=global_pool, - scope='resnet2') - - layers.batch_norm = batch_norm_expect_is_not_training - with arg_scope(resnet_utils.resnet_arg_scope()): - network_fn( - inputs, - num_classes, - is_training=False, - global_pool=global_pool, - scope='resnet3') - - # resnet_arg_scope with is_training set to True (deprecated) - layers.batch_norm = batch_norm_expect_is_training - with arg_scope(resnet_utils.resnet_arg_scope(is_training=True)): - network_fn(inputs, num_classes, global_pool=global_pool, scope='resnet4') - - layers.batch_norm = batch_norm_expect_is_training - with arg_scope(resnet_utils.resnet_arg_scope(is_training=True)): - network_fn( - inputs, - num_classes, - is_training=True, - global_pool=global_pool, - scope='resnet5') - - layers.batch_norm = batch_norm_expect_is_not_training - with arg_scope(resnet_utils.resnet_arg_scope(is_training=True)): - network_fn( - inputs, - num_classes, - is_training=False, - global_pool=global_pool, - scope='resnet6') - - # resnet_arg_scope with is_training set to False (deprecated) - layers.batch_norm = batch_norm_expect_is_not_training - with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)): - network_fn(inputs, num_classes, global_pool=global_pool, scope='resnet7') - - layers.batch_norm = batch_norm_expect_is_training - with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)): - network_fn( - inputs, - num_classes, - is_training=True, - global_pool=global_pool, - scope='resnet8') - - layers.batch_norm = batch_norm_expect_is_not_training - with arg_scope(resnet_utils.resnet_arg_scope(is_training=False)): - network_fn( - inputs, - num_classes, - is_training=False, - global_pool=global_pool, - scope='resnet9') - - layers.batch_norm = batch_norm_fn - - def testDeprecatingIsTrainingResnetV1(self): - self._testDeprecatingIsTraining(resnet_v1.resnet_v1_50) - - def testDeprecatingIsTrainingResnetV2(self): - self._testDeprecatingIsTraining(resnet_v2.resnet_v2_50) - - -if __name__ == '__main__': - test.main() diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py b/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py index 58614a998abc2a983c4cd8df934cb30090c6443f..cfafee5d8c7a8dd326f6512b9aa224c78ccfb3d4 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_utils.py @@ -41,7 +41,6 @@ from __future__ import print_function import collections from tensorflow.contrib import layers as layers_lib -from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework.python.ops import add_arg_scope from tensorflow.contrib.framework.python.ops import arg_scope from tensorflow.contrib.layers.python.layers import initializers @@ -223,12 +222,7 @@ def stack_blocks_dense(net, return net -@deprecated_args( - '2017-08-01', - 'Pass is_training directly to the network instead of the arg_scope.', - 'is_training') -def resnet_arg_scope(is_training=True, - weight_decay=0.0001, +def resnet_arg_scope(weight_decay=0.0001, batch_norm_decay=0.997, batch_norm_epsilon=1e-5, batch_norm_scale=True): @@ -240,8 +234,6 @@ def resnet_arg_scope(is_training=True, training ResNets from scratch, they might need to be tuned. Args: - is_training: Whether or not we are training the parameters in the batch - normalization layers of the model. (deprecated) weight_decay: The weight decay to use for regularizing the model. batch_norm_decay: The moving average decay when estimating layer activation statistics in batch normalization. @@ -254,7 +246,6 @@ def resnet_arg_scope(is_training=True, An `arg_scope` to use for the resnet models. """ batch_norm_params = { - 'is_training': is_training, 'decay': batch_norm_decay, 'epsilon': batch_norm_epsilon, 'scale': batch_norm_scale, @@ -266,7 +257,8 @@ def resnet_arg_scope(is_training=True, weights_regularizer=regularizers.l2_regularizer(weight_decay), weights_initializer=initializers.variance_scaling_initializer(), activation_fn=nn_ops.relu, - normalizer_fn=layers.batch_norm): + normalizer_fn=layers.batch_norm, + normalizer_params=batch_norm_params): with arg_scope([layers.batch_norm], **batch_norm_params): # The following implies padding='SAME' for pool1, which makes feature # alignment easier for dense prediction tasks. This is also used in diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py index 90f93d46e34b7554353d74529360d8e9a8ff5d06..235a595de49f956e1df740fd821936c80eefaa55 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v1.py @@ -128,7 +128,7 @@ def bottleneck(inputs, def resnet_v1(inputs, blocks, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, include_root_block=True, @@ -163,8 +163,7 @@ def resnet_v1(inputs, is a resnet_utils.Block object describing the units in the block. num_classes: Number of predicted classes for classification tasks. If None we return the features before the logit layer. - is_training: whether is training or not. If None, the value inherited from - the resnet_arg_scope is used. Specifying value None is deprecated. + is_training: whether batch_norm layers are in training mode. global_pool: If True, we perform global average pooling before computing the logits. Set to True for image classification, False for dense prediction. output_stride: If None, then the output will be computed at the nominal @@ -196,11 +195,7 @@ def resnet_v1(inputs, with arg_scope( [layers.conv2d, bottleneck, resnet_utils.stack_blocks_dense], outputs_collections=end_points_collection): - if is_training is not None: - bn_scope = arg_scope([layers.batch_norm], is_training=is_training) - else: - bn_scope = arg_scope([]) - with bn_scope: + with arg_scope([layers.batch_norm], is_training=is_training): net = inputs if include_root_block: if output_stride is not None: @@ -255,7 +250,7 @@ def resnet_v1_block(scope, base_depth, num_units, stride): def resnet_v1_50(inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, reuse=None, @@ -281,7 +276,7 @@ def resnet_v1_50(inputs, def resnet_v1_101(inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, reuse=None, @@ -307,7 +302,7 @@ def resnet_v1_101(inputs, def resnet_v1_152(inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, reuse=None, @@ -333,7 +328,7 @@ def resnet_v1_152(inputs, def resnet_v1_200(inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, reuse=None, diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py index d510337fef0762e086aee7341d4739393ee165f8..576444214d5edb772addef64d5def84e3915c29b 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v1_test.py @@ -250,7 +250,7 @@ class ResnetCompleteNetworkTest(test.TestCase): def _resnet_small(self, inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, include_root_block=True, @@ -386,7 +386,7 @@ class ResnetCompleteNetworkTest(test.TestCase): inputs, None, is_training=False, global_pool=False) sess.run(variables.global_variables_initializer()) self.assertAllClose( - output.eval(), expected.eval(), atol=1e-4, rtol=1e-4) + output.eval(), expected.eval(), atol=2e-4, rtol=1e-4) def testUnknownBatchSize(self): batch = 2 diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v2.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v2.py index 63e8f1ff356dfcf0427d5170a03faa47ee06298c..61665c9c8ba7817377a16bf3f2673447cab0518e 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_v2.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v2.py @@ -130,7 +130,7 @@ def bottleneck(inputs, def resnet_v2(inputs, blocks, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, include_root_block=True, @@ -165,8 +165,7 @@ def resnet_v2(inputs, is a resnet_utils.Block object describing the units in the block. num_classes: Number of predicted classes for classification tasks. If None we return the features before the logit layer. - is_training: whether is training or not. If None, the value inherited from - the resnet_arg_scope is used. Specifying value None is deprecated. + is_training: whether batch_norm layers are in training mode. global_pool: If True, we perform global average pooling before computing the logits. Set to True for image classification, False for dense prediction. output_stride: If None, then the output will be computed at the nominal @@ -200,11 +199,7 @@ def resnet_v2(inputs, with arg_scope( [layers_lib.conv2d, bottleneck, resnet_utils.stack_blocks_dense], outputs_collections=end_points_collection): - if is_training is not None: - bn_scope = arg_scope([layers.batch_norm], is_training=is_training) - else: - bn_scope = arg_scope([]) - with bn_scope: + with arg_scope([layers.batch_norm], is_training=is_training): net = inputs if include_root_block: if output_stride is not None: @@ -268,7 +263,7 @@ def resnet_v2_block(scope, base_depth, num_units, stride): def resnet_v2_50(inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, reuse=None, @@ -294,8 +289,8 @@ def resnet_v2_50(inputs, def resnet_v2_101(inputs, num_classes=None, + is_training=True, global_pool=True, - is_training=None, output_stride=None, reuse=None, scope='resnet_v2_101'): @@ -320,7 +315,7 @@ def resnet_v2_101(inputs, def resnet_v2_152(inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, reuse=None, @@ -346,7 +341,7 @@ def resnet_v2_152(inputs, def resnet_v2_200(inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, reuse=None, diff --git a/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py b/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py index c4f3b071fd940d2c3d7c80fa3041b0426e336ab0..6bdda18c5ba8fe0c9d3374010266c3391044a206 100644 --- a/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py +++ b/tensorflow/contrib/slim/python/slim/nets/resnet_v2_test.py @@ -254,7 +254,7 @@ class ResnetCompleteNetworkTest(test.TestCase): def _resnet_small(self, inputs, num_classes=None, - is_training=None, + is_training=True, global_pool=True, output_stride=None, include_root_block=True, diff --git a/tensorflow/contrib/stateless/BUILD b/tensorflow/contrib/stateless/BUILD index 865fb72a55b9a83b8354a100af843abaefc79980..6e259e1d32be64f3b593faf73e8af4f704d72349 100644 --- a/tensorflow/contrib/stateless/BUILD +++ b/tensorflow/contrib/stateless/BUILD @@ -21,7 +21,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":stateless_random_ops", - "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py index 9a36bdc2f9558220fa6cc47d5bb95d6e49a480f7..cd4d46aa07bfa92b8243f2f168fd1e4682ad70e2 100644 --- a/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py +++ b/tensorflow/contrib/stateless/python/kernel_tests/stateless_random_ops_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib import stateless +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import random_seed from tensorflow.python.ops import array_ops @@ -79,6 +80,21 @@ class StatelessOpsTest(test.TestCase): for s1, v1 in values: self.assertEqual(s0 == s1, np.all(v0 == v1)) + def testShapeType(self): + with self.test_session(use_gpu=True): + for shape_dtype in [dtypes.int32, dtypes.int64]: + seed_t = array_ops.placeholder(dtypes.int64, shape=[2]) + seeds = [(x, y) for x in range(5) for y in range(5)] * 3 + for stateless_op, _ in CASES: + for shape in (), (3,), (2, 5): + pure = stateless_op(constant_op.constant(shape, dtype=shape_dtype), + seed=seed_t) + values = [(seed, pure.eval(feed_dict={seed_t: seed})) + for seed in seeds] + for s0, v0 in values: + for s1, v1 in values: + self.assertEqual(s0 == s1, np.all(v0 == v1)) + if __name__ == '__main__': test.main() diff --git a/tensorflow/contrib/summary/BUILD b/tensorflow/contrib/summary/BUILD index 527deab86a6ba1e5ccfe6aceb6d73d20aee3ebc2..da23f1c3806be73d43e44bf4b4079d81b2d61c8f 100644 --- a/tensorflow/contrib/summary/BUILD +++ b/tensorflow/contrib/summary/BUILD @@ -13,7 +13,10 @@ load( tf_gen_op_wrapper_py( name = "gen_summary_ops", out = "gen_summary_ops.py", - deps = ["//tensorflow/core:summary_ops_op_lib"], + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:summary_ops_op_lib", + ], ) py_test( @@ -22,9 +25,9 @@ py_test( srcs_version = "PY2AND3", deps = [ ":summary_ops", - "//tensorflow/core:protos_all_py", + ":summary_test_util", + "//tensorflow/python:errors", "//tensorflow/python:framework_test_lib", - "//tensorflow/python:lib", "//tensorflow/python:platform", "//tensorflow/python:training", "//tensorflow/python/eager:function", @@ -39,16 +42,29 @@ py_library( visibility = ["//tensorflow:internal"], deps = [ ":gen_summary_ops", + "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", + "//tensorflow/python:layers_base", + "//tensorflow/python:resource_variable_ops", "//tensorflow/python:summary_op_util", "//tensorflow/python:training", + "//tensorflow/python:util", "//tensorflow/python/eager:context", ], ) +py_library( + name = "summary", + srcs = ["summary.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + ":summary_ops", + ], +) + filegroup( name = "all_files", srcs = glob( @@ -60,3 +76,17 @@ filegroup( ), visibility = ["//tensorflow:__subpackages__"], ) + +# NOTE: target cannot be testonly because it needs to be in the pip +# package. Sigh. +py_library( + name = "summary_test_util", + srcs = ["summary_test_util.py"], + srcs_version = "PY2AND3", + visibility = ["//tensorflow:internal"], + deps = [ + "//tensorflow/core:protos_all_py", + "//tensorflow/python:lib", + "//tensorflow/python:platform", + ], +) diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py new file mode 100644 index 0000000000000000000000000000000000000000..ca82ea094c41c15f376e6f6f448b770c5cf291d7 --- /dev/null +++ b/tensorflow/contrib/summary/summary.py @@ -0,0 +1,40 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Contrib summary package. + +The operations in this package are safe to use with eager execution turned or on +off. + +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# pylint: disable=unused-import +from tensorflow.contrib.summary.summary_ops import all_summary_ops +from tensorflow.contrib.summary.summary_ops import always_record_summaries +from tensorflow.contrib.summary.summary_ops import audio +from tensorflow.contrib.summary.summary_ops import create_summary_file_writer +from tensorflow.contrib.summary.summary_ops import eval_dir +from tensorflow.contrib.summary.summary_ops import generic +from tensorflow.contrib.summary.summary_ops import histogram +from tensorflow.contrib.summary.summary_ops import image +from tensorflow.contrib.summary.summary_ops import never_record_summaries +from tensorflow.contrib.summary.summary_ops import record_summaries_every_n_global_steps +from tensorflow.contrib.summary.summary_ops import scalar +from tensorflow.contrib.summary.summary_ops import should_record_summaries +from tensorflow.contrib.summary.summary_ops import summary_writer_initializer_op diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py index ceaf83b70a76e8a1195b4c177f4764dc7ab792f2..56e31985936c22d9b5d6c85fff067118152e220d 100644 --- a/tensorflow/contrib/summary/summary_ops.py +++ b/tensorflow/contrib/summary/summary_ops.py @@ -19,26 +19,35 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + from tensorflow.contrib.summary import gen_summary_ops from tensorflow.python.eager import context from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.layers import utils +from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import summary_op_util from tensorflow.python.training import training_util - +from tensorflow.python.util import tf_contextlib # Name for a collection which is expected to have at most a single boolean # Tensor. If this tensor is True the summary ops will record summaries. _SHOULD_RECORD_SUMMARIES_NAME = "ShouldRecordSummaries" +_SUMMARY_COLLECTION_NAME = "_SUMMARY_V2" +_SUMMARY_WRITER_INIT_COLLECTION_NAME = "_SUMMARY_WRITER_V2" + def should_record_summaries(): """Returns boolean Tensor which is true if summaries should be recorded.""" should_record_collection = ops.get_collection(_SHOULD_RECORD_SUMMARIES_NAME) if not should_record_collection: - return constant_op.constant(False) + return False if len(should_record_collection) != 1: raise ValueError( "More than one tensor specified for whether summaries " @@ -47,45 +56,125 @@ def should_record_summaries(): # TODO(apassos) consider how to handle local step here. +@tf_contextlib.contextmanager def record_summaries_every_n_global_steps(n): """Sets the should_record_summaries Tensor to true if global_step % n == 0.""" collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) - collection_ref[:] = [training_util.get_global_step() % n == 0] + old = collection_ref[:] + with ops.device("cpu:0"): + collection_ref[:] = [math_ops.equal(training_util.get_global_step() % n, 0)] + yield + collection_ref[:] = old +@tf_contextlib.contextmanager def always_record_summaries(): """Sets the should_record_summaries Tensor to always true.""" collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) - collection_ref[:] = [constant_op.constant(True)] + old = collection_ref[:] + collection_ref[:] = [True] + yield + collection_ref[:] = old +@tf_contextlib.contextmanager def never_record_summaries(): """Sets the should_record_summaries Tensor to always false.""" collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME) - collection_ref[:] = [constant_op.constant(False)] + old = collection_ref[:] + collection_ref[:] = [False] + yield + collection_ref[:] = old + + +class SummaryWriter(object): + """Encapsulates a summary writer.""" + + def __init__(self, resource): + self._resource = resource + if context.in_eager_mode(): + self._resource_deleter = resource_variable_ops.EagerResourceDeleter( + handle=self._resource, handle_device="cpu:0") + + def set_as_default(self): + context.context().summary_writer_resource = self._resource + + @tf_contextlib.contextmanager + def as_default(self): + if self._resource is None: + yield + else: + old = context.context().summary_writer_resource + context.context().summary_writer_resource = self._resource + yield + # Flushes the summary writer in eager mode or in graph functions, but not + # in legacy graph mode (you're on your own there). + with ops.device("cpu:0"): + gen_summary_ops.flush_summary_writer(self._resource) + context.context().summary_writer_resource = old def create_summary_file_writer(logdir, max_queue=None, - flush_secs=None, + flush_millis=None, filename_suffix=None, name=None): - """Creates a summary file writer in the current context.""" - if max_queue is None: - max_queue = constant_op.constant(10) - if flush_secs is None: - flush_secs = constant_op.constant(120) - if filename_suffix is None: - filename_suffix = constant_op.constant("") - resource = gen_summary_ops.summary_writer(shared_name=name) - gen_summary_ops.create_summary_file_writer(resource, logdir, max_queue, - flush_secs, filename_suffix) - context.context().summary_writer_resource = resource + """Creates a summary file writer in the current context. + + Args: + logdir: a string, or None. If a string, creates a summary file writer + which writes to the directory named by the string. If None, returns + a mock object which acts like a summary writer but does nothing, + useful to use as a context manager. + max_queue: the largest number of summaries to keep in a queue; will + flush once the queue gets bigger than this. + flush_millis: the largest interval between flushes. + filename_suffix: optional suffix for the event file name. + name: name for the summary writer. + + Returns: + Either a summary writer or an empty object which can be used as a + summary writer. + """ + if logdir is None: + return SummaryWriter(None) + with ops.device("cpu:0"): + if max_queue is None: + max_queue = constant_op.constant(10) + if flush_millis is None: + flush_millis = constant_op.constant(2 * 60 * 1000) + if filename_suffix is None: + filename_suffix = constant_op.constant("") + resource = gen_summary_ops.summary_writer(shared_name=name) + # TODO(apassos) ensure the initialization op runs when in graph mode; + # consider calling session.run here. + ops.add_to_collection( + _SUMMARY_WRITER_INIT_COLLECTION_NAME, + gen_summary_ops.create_summary_file_writer( + resource, logdir, max_queue, flush_millis, filename_suffix)) + return SummaryWriter(resource) def _nothing(): """Convenient else branch for when summaries do not record.""" - return False + return constant_op.constant(False) + + +def all_summary_ops(): + """Graph-mode only. Returns all summary ops.""" + if context.in_eager_mode(): + raise RuntimeError( + "tf.contrib.summary.all_summary_ops is only supported in graph mode.") + return ops.get_collection(_SUMMARY_COLLECTION_NAME) + + +def summary_writer_initializer_op(): + """Graph-mode only. Returns the list of ops to create all summary writers.""" + if context.in_eager_mode(): + raise RuntimeError( + "tf.contrib.summary.summary_writer_initializer_op is only " + "supported in graph mode.") + return ops.get_collection(_SUMMARY_WRITER_INIT_COLLECTION_NAME) def summary_writer_function(name, tensor, function, family=None): @@ -103,19 +192,27 @@ def summary_writer_function(name, tensor, function, family=None): def record(): with summary_op_util.summary_scope( name, family, values=[tensor]) as (tag, scope): - function(tag, scope) - return True + with ops.control_dependencies([function(tag, scope)]): + return constant_op.constant(True) - return control_flow_ops.cond(should_record_summaries(), record, _nothing) + if context.context().summary_writer_resource is None: + return control_flow_ops.no_op() + with ops.device("cpu:0"): + op = utils.smart_cond( + should_record_summaries(), record, _nothing, name="") + ops.add_to_collection(_SUMMARY_COLLECTION_NAME, op) + return op def generic(name, tensor, metadata, family=None): """Writes a tensor summary if possible.""" def function(tag, scope): - gen_summary_ops.write_summary(context.context().summary_writer_resource, - training_util.get_global_step(), tensor, - tag, metadata, name=scope) + # Note the identity to move the tensor to the CPU. + return gen_summary_ops.write_summary( + context.context().summary_writer_resource, + training_util.get_global_step(), array_ops.identity(tensor), + tag, metadata, name=scope) return summary_writer_function(name, tensor, function, family=family) @@ -123,9 +220,11 @@ def scalar(name, tensor, family=None): """Writes a scalar summary if possible.""" def function(tag, scope): - gen_summary_ops.write_scalar_summary( + # Note the identity to move the tensor to the CPU. + return gen_summary_ops.write_scalar_summary( context.context().summary_writer_resource, - training_util.get_global_step(), tag, tensor, name=scope) + training_util.get_global_step(), tag, array_ops.identity(tensor), + name=scope) return summary_writer_function(name, tensor, function, family=family) @@ -134,9 +233,11 @@ def histogram(name, tensor, family=None): """Writes a histogram summary if possible.""" def function(tag, scope): - gen_summary_ops.write_histogram_summary( + # Note the identity to move the tensor to the CPU. + return gen_summary_ops.write_histogram_summary( context.context().summary_writer_resource, - training_util.get_global_step(), tag, tensor, name=scope) + training_util.get_global_step(), tag, array_ops.identity(tensor), + name=scope) return summary_writer_function(name, tensor, function, family=family) @@ -145,12 +246,14 @@ def image(name, tensor, bad_color=None, max_images=3, family=None): """Writes an image summary if possible.""" def function(tag, scope): - if bad_color is None: - bad_color_ = constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) - gen_summary_ops.write_image_summary( + bad_color_ = (constant_op.constant([255, 0, 0, 255], dtype=dtypes.uint8) + if bad_color is None else bad_color) + # Note the identity to move the tensor to the CPU. + return gen_summary_ops.write_image_summary( context.context().summary_writer_resource, - training_util.get_global_step(), tag, tensor, bad_color_, max_images, - name=scope) + training_util.get_global_step(), tag, array_ops.identity(tensor), + bad_color_, + max_images, name=scope) return summary_writer_function(name, tensor, function, family=family) @@ -159,13 +262,19 @@ def audio(name, tensor, sample_rate, max_outputs, family=None): """Writes an audio summary if possible.""" def function(tag, scope): - gen_summary_ops.write_audio_summary( + # Note the identity to move the tensor to the CPU. + return gen_summary_ops.write_audio_summary( context.context().summary_writer_resource, training_util.get_global_step(), tag, - tensor, + array_ops.identity(tensor), sample_rate=sample_rate, max_outputs=max_outputs, name=scope) return summary_writer_function(name, tensor, function, family=family) + + +def eval_dir(model_dir, name=None): + """Construct a logdir for an eval summary writer.""" + return os.path.join(model_dir, "eval" if not name else "eval_" + name) diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py index c9a9bb3d5b17c309e136f902505bf1fc9e5295aa..de7ae6ec277a97235617882a7cc7e469eaebe26c 100644 --- a/tensorflow/contrib/summary/summary_ops_test.py +++ b/tensorflow/contrib/summary/summary_ops_test.py @@ -17,16 +17,14 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import os import tempfile from tensorflow.contrib.summary import summary_ops -from tensorflow.core.util import event_pb2 +from tensorflow.contrib.summary import summary_test_util from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.framework import errors from tensorflow.python.framework import test_util -from tensorflow.python.lib.io import tf_record from tensorflow.python.platform import gfile from tensorflow.python.training import training_util @@ -40,44 +38,53 @@ class TargetTest(test_util.TensorFlowTestCase): summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0') def testShouldRecordSummary(self): - self.assertFalse(summary_ops.should_record_summaries().numpy()) - summary_ops.always_record_summaries() - self.assertTrue(summary_ops.should_record_summaries().numpy()) + self.assertFalse(summary_ops.should_record_summaries()) + with summary_ops.always_record_summaries(): + self.assertTrue(summary_ops.should_record_summaries()) def testSummaryOps(self): training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() - summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t0') - summary_ops.always_record_summaries() - summary_ops.generic('tensor', 1, '') - summary_ops.scalar('scalar', 2.0) - summary_ops.histogram('histogram', [1.0]) - summary_ops.image('image', [[[[1.0]]]]) - summary_ops.audio('audio', [[1.0]], 1.0, 1) - # The working condition of the ops is tested in the C++ test so we just - # test here that we're calling them correctly. - self.assertTrue(gfile.Exists(logdir)) + with summary_ops.create_summary_file_writer( + logdir, max_queue=0, + name='t0').as_default(), summary_ops.always_record_summaries(): + summary_ops.generic('tensor', 1, '') + summary_ops.scalar('scalar', 2.0) + summary_ops.histogram('histogram', [1.0]) + summary_ops.image('image', [[[[1.0]]]]) + summary_ops.audio('audio', [[1.0]], 1.0, 1) + # The working condition of the ops is tested in the C++ test so we just + # test here that we're calling them correctly. + self.assertTrue(gfile.Exists(logdir)) def testDefunSummarys(self): training_util.get_or_create_global_step() logdir = tempfile.mkdtemp() - summary_ops.create_summary_file_writer(logdir, max_queue=0, name='t1') - summary_ops.always_record_summaries() + with summary_ops.create_summary_file_writer( + logdir, max_queue=0, + name='t1').as_default(), summary_ops.always_record_summaries(): - @function.defun - def write(): - summary_ops.scalar('scalar', 2.0) + @function.defun + def write(): + summary_ops.scalar('scalar', 2.0) - write() + write() + events = summary_test_util.events_from_file(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].simple_value, 2.0) + + def testSummaryName(self): + training_util.get_or_create_global_step() + logdir = tempfile.mkdtemp() + with summary_ops.create_summary_file_writer( + logdir, max_queue=0, + name='t2').as_default(), summary_ops.always_record_summaries(): + + summary_ops.scalar('scalar', 2.0) - self.assertTrue(gfile.Exists(logdir)) - files = gfile.ListDirectory(logdir) - self.assertEqual(len(files), 1) - records = list(tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) - self.assertEqual(len(records), 2) - event = event_pb2.Event() - event.ParseFromString(records[1]) - self.assertEqual(event.summary.value[0].simple_value, 2.0) + events = summary_test_util.events_from_file(logdir) + self.assertEqual(len(events), 2) + self.assertEqual(events[1].summary.value[0].tag, 'scalar') if __name__ == '__main__': diff --git a/tensorflow/contrib/summary/summary_test_util.py b/tensorflow/contrib/summary/summary_test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..37b546d3ab3220f934ea3bf7ef8f5fe6ab29f683 --- /dev/null +++ b/tensorflow/contrib/summary/summary_test_util.py @@ -0,0 +1,41 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Utilities to test summaries.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from tensorflow.core.util import event_pb2 +from tensorflow.python.lib.io import tf_record +from tensorflow.python.platform import gfile + + +def events_from_file(logdir): + """Returns all events in the single eventfile in logdir.""" + assert gfile.Exists(logdir) + files = gfile.ListDirectory(logdir) + assert len(files) == 1, "Found more than one file in logdir: %s" % files + records = list( + tf_record.tf_record_iterator(os.path.join(logdir, files[0]))) + result = [] + for r in records: + event = event_pb2.Event() + event.ParseFromString(r) + result.append(event) + return result diff --git a/tensorflow/contrib/tensor_forest/BUILD b/tensorflow/contrib/tensor_forest/BUILD index bff7d022740ed8fe0c763865fe20d7cb0efd60d5..878415604e7e2f14a146939e7645932d56d999d0 100644 --- a/tensorflow/contrib/tensor_forest/BUILD +++ b/tensorflow/contrib/tensor_forest/BUILD @@ -269,9 +269,11 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", deps = [ ":gen_model_ops_py", - ":stats_ops_py", - "//tensorflow/core:protos_all_py", + "//tensorflow/contrib/util:util_py", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:resources", + "//tensorflow/python:training", ], ) @@ -286,12 +288,10 @@ tf_cc_test( ":forest_proto_impl", ":model_ops_lib", "//tensorflow/contrib/tensor_forest/kernels/v4:decision-tree-resource_impl", - "//tensorflow/core", "//tensorflow/core:framework_headers_lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", - "//third_party/eigen3", ], ) @@ -364,8 +364,12 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", deps = [ ":gen_stats_ops_py", + "//tensorflow/contrib/util:util_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:platform", + "//tensorflow/python:resources", + "//tensorflow/python:training", ], ) @@ -382,6 +386,7 @@ tf_cc_test( "//tensorflow/contrib/tensor_forest/kernels/v4:decision-tree-resource_impl", "//tensorflow/core", "//tensorflow/core:framework_headers_lib", + "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core:testlib", @@ -495,9 +500,13 @@ py_library( "//tensorflow/contrib/decision_trees/proto:generic_tree_model_py", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/tensor_forest/proto:tensor_forest_params_proto_py", + "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:random_ops", + "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "@six_archive//:six", ], @@ -524,13 +533,17 @@ py_library( deps = [ ":client_lib", "//tensorflow/contrib/framework:framework_py", + "//tensorflow/contrib/layers:layers_py", "//tensorflow/contrib/learn", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:resource_variable_ops", + "//tensorflow/python:sparse_tensor", "//tensorflow/python:state_ops", + "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:variable_scope", ], diff --git a/tensorflow/contrib/tensor_forest/hybrid/BUILD b/tensorflow/contrib/tensor_forest/hybrid/BUILD index 13b9749756d60e2a8ecc5e4cbfd3d3a60c496552..a2a3b485f6aa0ae827bbaa7812823730bd8db3b8 100644 --- a/tensorflow/contrib/tensor_forest/hybrid/BUILD +++ b/tensorflow/contrib/tensor_forest/hybrid/BUILD @@ -105,8 +105,8 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", deps = [ ":training_ops", + "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", - "//tensorflow/python:framework", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform", @@ -180,7 +180,6 @@ py_test( deps = [ ":ops_lib", ":training_ops", - "//tensorflow:tensorflow_py", "//tensorflow/python:framework_test_lib", "//tensorflow/python:platform_test", ], diff --git a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc index 3d9de006b4778a1446179139d68eadac6704e0c9..b9aad36f3d25b9fb7b8b525be54fb7a39394b373 100644 --- a/tensorflow/contrib/tensor_forest/kernels/model_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/model_ops.cc @@ -169,10 +169,6 @@ class TreePredictionsV4Op : public OpKernel { string serialized_proto; OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto)); input_spec_.ParseFromString(serialized_proto); - - data_set_ = - std::unique_ptr(new TensorDataSet(input_spec_, 0)); - model_op_ = LeafModelOperatorFactory::CreateLeafModelOperator(param_proto_); } @@ -182,8 +178,9 @@ class TreePredictionsV4Op : public OpKernel { const Tensor& sparse_input_values = context->input(3); const Tensor& sparse_input_shape = context->input(4); - data_set_->set_input_tensors(input_data, sparse_input_indices, - sparse_input_values, sparse_input_shape); + std::unique_ptr data_set(new TensorDataSet(input_spec_, 0)); + data_set->set_input_tensors(input_data, sparse_input_indices, + sparse_input_values, sparse_input_shape); DecisionTreeResource* decision_tree_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), @@ -191,7 +188,7 @@ class TreePredictionsV4Op : public OpKernel { mutex_lock l(*decision_tree_resource->get_mutex()); core::ScopedUnref unref_me(decision_tree_resource); - const int num_data = data_set_->NumItems(); + const int num_data = data_set->NumItems(); const int32 num_outputs = param_proto_.num_outputs(); Tensor* output_predictions = nullptr; @@ -208,11 +205,11 @@ class TreePredictionsV4Op : public OpKernel { auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); int num_threads = worker_threads->num_threads; const int64 costPerTraverse = 500; - auto traverse = [this, &out, decision_tree_resource, num_data, &tree_paths]( - int64 start, int64 end) { + auto traverse = [this, &out, &data_set, decision_tree_resource, num_data, + &tree_paths](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_data); - TraverseTree(decision_tree_resource, data_set_, static_cast(start), + TraverseTree(decision_tree_resource, data_set, static_cast(start), static_cast(end), std::bind(&TreePredictionsV4Op::set_output_value, this, std::placeholders::_1, std::placeholders::_2, @@ -259,7 +256,6 @@ class TreePredictionsV4Op : public OpKernel { private: tensorforest::TensorForestDataSpec input_spec_; - std::unique_ptr data_set_; std::unique_ptr model_op_; TensorForestParams param_proto_; }; @@ -275,9 +271,6 @@ class TraverseTreeV4Op : public OpKernel { string serialized_proto; OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto)); input_spec_.ParseFromString(serialized_proto); - - data_set_ = - std::unique_ptr(new TensorDataSet(input_spec_, 0)); } void Compute(OpKernelContext* context) override { @@ -286,8 +279,9 @@ class TraverseTreeV4Op : public OpKernel { const Tensor& sparse_input_values = context->input(3); const Tensor& sparse_input_shape = context->input(4); - data_set_->set_input_tensors(input_data, sparse_input_indices, - sparse_input_values, sparse_input_shape); + std::unique_ptr data_set(new TensorDataSet(input_spec_, 0)); + data_set->set_input_tensors(input_data, sparse_input_indices, + sparse_input_values, sparse_input_shape); DecisionTreeResource* decision_tree_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), @@ -295,7 +289,7 @@ class TraverseTreeV4Op : public OpKernel { mutex_lock l(*decision_tree_resource->get_mutex()); core::ScopedUnref unref_me(decision_tree_resource); - const int num_data = data_set_->NumItems(); + const int num_data = data_set->NumItems(); Tensor* output_predictions = nullptr; TensorShape output_shape; @@ -310,11 +304,11 @@ class TraverseTreeV4Op : public OpKernel { auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); int num_threads = worker_threads->num_threads; const int64 costPerTraverse = 500; - auto traverse = [this, &set_leaf_ids, decision_tree_resource, num_data]( - int64 start, int64 end) { + auto traverse = [this, &set_leaf_ids, &data_set, decision_tree_resource, + num_data](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_data); - TraverseTree(decision_tree_resource, data_set_, static_cast(start), + TraverseTree(decision_tree_resource, data_set, static_cast(start), static_cast(end), set_leaf_ids, nullptr); }; Shard(num_threads, worker_threads->workers, num_data, costPerTraverse, @@ -323,7 +317,6 @@ class TraverseTreeV4Op : public OpKernel { private: tensorforest::TensorForestDataSpec input_spec_; - std::unique_ptr data_set_; TensorForestParams param_proto_; }; diff --git a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc index b6d57ef952777bc204f9534e60f2ce7de3687615..f80a34ece662d1e0b0ea1cb7616fa1b5b84731fa 100644 --- a/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc +++ b/tensorflow/contrib/tensor_forest/kernels/stats_ops.cc @@ -235,9 +235,6 @@ class ProcessInputOp : public OpKernel { string serialized_proto; OP_REQUIRES_OK(context, context->GetAttr("input_spec", &serialized_proto)); input_spec_.ParseFromString(serialized_proto); - - data_set_ = std::unique_ptr( - new TensorDataSet(input_spec_, random_seed_)); } void Compute(OpKernelContext* context) override { @@ -249,8 +246,9 @@ class ProcessInputOp : public OpKernel { const Tensor& input_weights = context->input(7); const Tensor& leaf_ids_tensor = context->input(8); - data_set_->set_input_tensors(input_data, sparse_input_indices, - sparse_input_values, sparse_input_shape); + std::unique_ptr data_set(new TensorDataSet(input_spec_, 0)); + data_set->set_input_tensors(input_data, sparse_input_indices, + sparse_input_values, sparse_input_shape); FertileStatsResource* fertile_stats_resource; OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 1), @@ -264,7 +262,7 @@ class ProcessInputOp : public OpKernel { core::ScopedUnref unref_stats(fertile_stats_resource); core::ScopedUnref unref_tree(tree_resource); - const int32 num_data = data_set_->NumItems(); + const int32 num_data = data_set->NumItems(); auto worker_threads = context->device()->tensorflow_cpu_worker_threads(); int num_threads = worker_threads->num_threads; @@ -308,23 +306,23 @@ class ProcessInputOp : public OpKernel { // from a digits run on local desktop. Heuristics might be necessary // if it really matters that much. const int64 costPerUpdate = 1000; - auto update = [this, &target, &leaf_ids_tensor, &num_targets, + auto update = [this, &target, &leaf_ids_tensor, &num_targets, &data_set, fertile_stats_resource, &locks, &set_lock, &ready_to_split, num_data](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_data); - UpdateStats(fertile_stats_resource, data_set_, target, num_targets, + UpdateStats(fertile_stats_resource, data_set, target, num_targets, leaf_ids_tensor, &locks, &set_lock, static_cast(start), static_cast(end), &ready_to_split); }; auto update_collated = [this, &target, &num_targets, fertile_stats_resource, tree_resource, &leaf_examples, &set_lock, - &ready_to_split, + &ready_to_split, &data_set, num_leaves](int64 start, int64 end) { CHECK(start <= end); CHECK(end <= num_leaves); - UpdateStatsCollated(fertile_stats_resource, tree_resource, data_set_, + UpdateStatsCollated(fertile_stats_resource, tree_resource, data_set, target, num_targets, leaf_examples, &set_lock, static_cast(start), static_cast(end), &ready_to_split); @@ -350,7 +348,6 @@ class ProcessInputOp : public OpKernel { private: int32 random_seed_; tensorforest::TensorForestDataSpec input_spec_; - std::unique_ptr data_set_; TensorForestParams param_proto_; }; diff --git a/tensorflow/contrib/tensor_forest/python/tensor_forest.py b/tensorflow/contrib/tensor_forest/python/tensor_forest.py index 756533250a62d9eb01ae9d2c80125272aeabca4c..eb938763f12efd9281bec4321384acd4617cdfcf 100644 --- a/tensorflow/contrib/tensor_forest/python/tensor_forest.py +++ b/tensorflow/contrib/tensor_forest/python/tensor_forest.py @@ -470,7 +470,11 @@ class RandomForestGraphs(object): """Constructs a TF graph for evaluating a random forest. Args: - input_data: A tensor or dict of string->Tensor for input data. + input_data: A tensor or dict of string->Tensor for the input data. + This input_data must generate the same spec as the + input_data used in training_graph: the dict must have + the same keys, for example, and all tensors must have + the same size in their first dimension. **inference_args: Keyword arguments to pass through to each tree. Returns: diff --git a/tensorflow/contrib/tensorboard/db/BUILD b/tensorflow/contrib/tensorboard/db/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..d8bbf87d2cecaec9b612e45e82295cebd3ac4c7f --- /dev/null +++ b/tensorflow/contrib/tensorboard/db/BUILD @@ -0,0 +1,62 @@ +# Description: +# TensorBoard database code. + +package(default_visibility = ["//tensorflow:internal"]) + +licenses(["notice"]) # Apache 2.0 + +load("//tensorflow:tensorflow.bzl", "tf_cc_test") + +cc_library( + name = "schema", + srcs = ["schema.cc"], + hdrs = ["schema.h"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core/lib/db:sqlite", + ], +) + +tf_cc_test( + name = "schema_test", + srcs = ["schema_test.cc"], + deps = [ + ":schema", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/lib/db:sqlite", + ], +) + +cc_library( + name = "summary_db_writer", + srcs = ["summary_db_writer.cc"], + hdrs = ["summary_db_writer.h"], + deps = [ + ":schema", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/kernels:summary_interface", + "//tensorflow/core/lib/db:sqlite", + ], +) + +tf_cc_test( + name = "summary_db_writer_test", + srcs = ["summary_db_writer_test.cc"], + deps = [ + ":summary_db_writer", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/lib/db:sqlite", + ], +) + +filegroup( + name = "all_files", + srcs = glob(["*"]), + visibility = ["//tensorflow:__pkg__"], +) diff --git a/tensorflow/contrib/tensorboard/db/schema.cc b/tensorflow/contrib/tensorboard/db/schema.cc new file mode 100644 index 0000000000000000000000000000000000000000..98fff9e0ae45279f5734ed2eaac8bf46e8ae4b22 --- /dev/null +++ b/tensorflow/contrib/tensorboard/db/schema.cc @@ -0,0 +1,409 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/tensorboard/db/schema.h" + +namespace tensorflow { +namespace { + +class SqliteSchema { + public: + explicit SqliteSchema(std::shared_ptr db) : db_(std::move(db)) {} + + /// \brief Creates Tensors table. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// tag_id: ID of associated Tag. + /// computed_time: Float UNIX timestamp with microsecond precision. + /// In the old summaries system that uses FileWriter, this is the + /// wall time around when tf.Session.run finished. In the new + /// summaries system, it is the wall time of when the tensor was + /// computed. On systems with monotonic clocks, it is calculated + /// by adding the monotonic run duration to Run.started_time. + /// This field is not indexed because, in practice, it should be + /// ordered the same or nearly the same as TensorIndex, so local + /// insertion sort might be more suitable. + /// step: User-supplied number, ordering this tensor in Tag. + /// If NULL then the Tag must have only one Tensor. + /// tensor: Can be an INTEGER (DT_INT64), FLOAT (DT_DOUBLE), or + /// BLOB. The structure of a BLOB is currently undefined, but in + /// essence it is a Snappy tf.TensorProto that spills over into + /// TensorChunks. + Status CreateTensorsTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS Tensors ( + rowid INTEGER PRIMARY KEY, + tag_id INTEGER NOT NULL, + computed_time REAL, + step INTEGER, + tensor BLOB + ) + )sql"); + } + + /// \brief Creates TensorChunks table. + /// + /// This table can be used to split up a tensor across many rows, + /// which has the advantage of not slowing down table scans on the + /// main table, allowing asynchronous fetching, minimizing copying, + /// and preventing large buffers from being allocated. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// tag_id: ID of associated Tag. + /// step: Same as corresponding Tensors.step. + /// sequence: 1-indexed sequence number for ordering chunks. Please + /// note that the 0th index is Tensors.tensor. + /// chunk: Bytes of next chunk in tensor. + Status CreateTensorChunksTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS TensorChunks ( + rowid INTEGER PRIMARY KEY, + tag_id INTEGER NOT NULL, + step INTEGER, + sequence INTEGER, + chunk BLOB + ) + )sql"); + } + + /// \brief Creates Tags table. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// tag_id: Permanent >0 unique ID. + /// run_id: Optional ID of associated Run. + /// tag_name: The tag field in summary.proto, unique across Run. + /// inserted_time: Float UNIX timestamp with µs precision. This is + /// always the wall time of when the row was inserted into the + /// DB. It may be used as a hint for an archival job. + /// metadata: Optional BLOB of SummaryMetadata proto. + /// display_name: Optional for GUI and defaults to tag_name. + /// summary_description: Optional markdown information. + Status CreateTagsTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS Tags ( + rowid INTEGER PRIMARY KEY, + run_id INTEGER, + tag_id INTEGER NOT NULL, + tag_name TEXT, + inserted_time DOUBLE, + metadata BLOB, + display_name TEXT, + description TEXT + ) + )sql"); + } + + /// \brief Creates Runs table. + /// + /// This table stores information about runs. Each row usually + /// represents a single attempt at training or testing a TensorFlow + /// model, with a given set of hyper-parameters, whose summaries are + /// written out to a single event logs directory with a monotonic step + /// counter. + /// + /// When a run is deleted from this table, TensorBoard should treat all + /// information associated with it as deleted, even if those rows in + /// different tables still exist. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// run_id: Permanent >0 unique ID. + /// experiment_id: Optional ID of associated Experiment. + /// run_name: User-supplied string, unique across Experiment. + /// inserted_time: Float UNIX timestamp with µs precision. This is + /// always the time the row was inserted into the database. It + /// does not change. + /// started_time: Float UNIX timestamp with µs precision. In the + /// old summaries system that uses FileWriter, this is + /// approximated as the first tf.Event.wall_time. In the new + /// summaries system, it is the wall time of when summary writing + /// started, from the perspective of whichever machine talks to + /// the database. This field will be mutated if the run is + /// restarted. + /// description: Optional markdown information. + /// graph: Snappy tf.GraphDef proto with node field cleared. That + /// field can be recreated using GraphNodes and NodeDefs. + Status CreateRunsTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS Runs ( + rowid INTEGER PRIMARY KEY, + experiment_id INTEGER, + run_id INTEGER NOT NULL, + run_name TEXT, + inserted_time REAL, + started_time REAL, + description TEXT, + graph BLOB + ) + )sql"); + } + + /// \brief Creates Experiments table. + /// + /// This table stores information about experiments, which are sets of + /// runs. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// user_id: Optional ID of associated User. + /// experiment_id: Permanent >0 unique ID. + /// experiment_name: User-supplied string, unique across User. + /// inserted_time: Float UNIX timestamp with µs precision. This is + /// always the time the row was inserted into the database. It + /// does not change. + /// started_time: Float UNIX timestamp with µs precision. This is + /// the MIN(experiment.started_time, run.started_time) of each + /// Run added to the database. + /// description: Optional markdown information. + Status CreateExperimentsTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS Experiments ( + rowid INTEGER PRIMARY KEY, + user_id INTEGER, + experiment_id INTEGER NOT NULL, + experiment_name TEXT, + inserted_time REAL, + started_time REAL, + description TEXT + ) + )sql"); + } + + /// \brief Creates Users table. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// user_id: Permanent >0 unique ID. + /// user_name: Unique user name. + /// email: Optional unique email address. + /// inserted_time: Float UNIX timestamp with µs precision. This is + /// always the time the row was inserted into the database. It + /// does not change. + Status CreateUsersTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS Users ( + rowid INTEGER PRIMARY KEY, + user_id INTEGER NOT NULL, + user_name TEXT, + email TEXT, + inserted_time REAL + ) + )sql"); + } + + /// \brief Creates NodeDefs table. + /// + /// This table stores NodeDef protos which define the GraphDef for a + /// Run. This functions like a hash table so rows can be shared by + /// multiple Runs in an Experiment. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// experiment_id: Optional int64 for grouping rows. + /// node_def_id: Permanent >0 unique ID. + /// fingerprint: Optional farmhash::Fingerprint64() of uncompressed + /// node_def bytes, coerced to int64. + /// node_def: BLOB containing a Snappy tf.NodeDef proto. + Status CreateNodeDefsTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS NodeDefs ( + rowid INTEGER PRIMARY KEY, + experiment_id INTEGER, + node_def_id INTEGER NOT NULL, + fingerprint INTEGER, + node_def TEXT + ) + )sql"); + } + + /// \brief Creates RunNodeDefs table. + /// + /// Table mapping Runs to NodeDefs. This is used to recreate the node + /// field of the GraphDef proto. + /// + /// Fields: + /// rowid: Ephemeral b-tree ID dictating locality. + /// run_id: Mandatory ID of associated Run. + /// node_def_id: Mandatory ID of associated NodeDef. + Status CreateRunNodeDefsTable() { + return Run(R"sql( + CREATE TABLE IF NOT EXISTS RunNodeDefs ( + rowid INTEGER PRIMARY KEY, + run_id INTEGER NOT NULL, + node_def_id INTEGER NOT NULL + ) + )sql"); + } + + /// \brief Uniquely indexes (tag_id, step) on Tensors table. + Status CreateTensorIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS TensorIndex + ON Tensors (tag_id, step) + )sql"); + } + + /// \brief Uniquely indexes (tag_id, step, sequence) on TensorChunks table. + Status CreateTensorChunkIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS TensorChunkIndex + ON TensorChunks (tag_id, step, sequence) + )sql"); + } + + /// \brief Uniquely indexes tag_id on Tags table. + Status CreateTagIdIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS TagIdIndex + ON Tags (tag_id) + )sql"); + } + + /// \brief Uniquely indexes run_id on Runs table. + Status CreateRunIdIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS RunIdIndex + ON Runs (run_id) + )sql"); + } + + /// \brief Uniquely indexes experiment_id on Experiments table. + Status CreateExperimentIdIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS ExperimentIdIndex + ON Experiments (experiment_id) + )sql"); + } + + /// \brief Uniquely indexes user_id on Users table. + Status CreateUserIdIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS UserIdIndex + ON Users (user_id) + )sql"); + } + + /// \brief Uniquely indexes node_def_id on NodeDefs table. + Status CreateNodeDefIdIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS NodeDefIdIndex + ON NodeDefs (node_def_id) + )sql"); + } + + /// \brief Uniquely indexes (run_id, tag_name) on Tags table. + Status CreateTagNameIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS TagNameIndex + ON Tags (run_id, tag_name) + WHERE tag_name IS NOT NULL + )sql"); + } + + /// \brief Uniquely indexes (experiment_id, run_name) on Runs table. + Status CreateRunNameIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS RunNameIndex + ON Runs (experiment_id, run_name) + WHERE run_name IS NOT NULL + )sql"); + } + + /// \brief Uniquely indexes (user_id, experiment_name) on Experiments table. + Status CreateExperimentNameIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS ExperimentNameIndex + ON Experiments (user_id, experiment_name) + WHERE experiment_name IS NOT NULL + )sql"); + } + + /// \brief Uniquely indexes user_name on Users table. + Status CreateUserNameIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS UserNameIndex + ON Users (user_name) + WHERE user_name IS NOT NULL + )sql"); + } + + /// \brief Uniquely indexes email on Users table. + Status CreateUserEmailIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS UserEmailIndex + ON Users (email) + WHERE email IS NOT NULL + )sql"); + } + + /// \brief Indexes (experiment_id, fingerprint) on NodeDefs table. + Status CreateNodeDefFingerprintIndex() { + return Run(R"sql( + CREATE INDEX IF NOT EXISTS NodeDefFingerprintIndex + ON NodeDefs (experiment_id, fingerprint) + WHERE fingerprint IS NOT NULL + )sql"); + } + + /// \brief Uniquely indexes (run_id, node_def_id) on RunNodeDefs table. + Status CreateRunNodeDefIndex() { + return Run(R"sql( + CREATE UNIQUE INDEX IF NOT EXISTS RunNodeDefIndex + ON RunNodeDefs (run_id, node_def_id) + )sql"); + } + + Status Run(const char* sql) { + auto stmt = db_->Prepare(sql); + TF_RETURN_WITH_CONTEXT_IF_ERROR(stmt.StepAndReset(), sql); + return Status::OK(); + } + + private: + std::shared_ptr db_; +}; + +} // namespace + +Status SetupTensorboardSqliteDb(std::shared_ptr db) { + SqliteSchema s(std::move(db)); + TF_RETURN_IF_ERROR(s.CreateTensorsTable()); + TF_RETURN_IF_ERROR(s.CreateTensorChunksTable()); + TF_RETURN_IF_ERROR(s.CreateTagsTable()); + TF_RETURN_IF_ERROR(s.CreateRunsTable()); + TF_RETURN_IF_ERROR(s.CreateExperimentsTable()); + TF_RETURN_IF_ERROR(s.CreateUsersTable()); + TF_RETURN_IF_ERROR(s.CreateNodeDefsTable()); + TF_RETURN_IF_ERROR(s.CreateRunNodeDefsTable()); + TF_RETURN_IF_ERROR(s.CreateTensorIndex()); + TF_RETURN_IF_ERROR(s.CreateTensorChunkIndex()); + TF_RETURN_IF_ERROR(s.CreateTagIdIndex()); + TF_RETURN_IF_ERROR(s.CreateRunIdIndex()); + TF_RETURN_IF_ERROR(s.CreateExperimentIdIndex()); + TF_RETURN_IF_ERROR(s.CreateUserIdIndex()); + TF_RETURN_IF_ERROR(s.CreateNodeDefIdIndex()); + TF_RETURN_IF_ERROR(s.CreateTagNameIndex()); + TF_RETURN_IF_ERROR(s.CreateRunNameIndex()); + TF_RETURN_IF_ERROR(s.CreateExperimentNameIndex()); + TF_RETURN_IF_ERROR(s.CreateUserNameIndex()); + TF_RETURN_IF_ERROR(s.CreateUserEmailIndex()); + TF_RETURN_IF_ERROR(s.CreateNodeDefFingerprintIndex()); + TF_RETURN_IF_ERROR(s.CreateRunNodeDefIndex()); + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorboard/db/schema.h b/tensorflow/contrib/tensorboard/db/schema.h new file mode 100644 index 0000000000000000000000000000000000000000..900c10298ce0a69b92f7528db9742517243c3c51 --- /dev/null +++ b/tensorflow/contrib/tensorboard/db/schema.h @@ -0,0 +1,33 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_TENSORBOARD_DB_SCHEMA_H_ +#define TENSORFLOW_CONTRIB_TENSORBOARD_DB_SCHEMA_H_ + +#include + +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/db/sqlite.h" + +namespace tensorflow { + +/// \brief Creates TensorBoard SQLite tables and indexes. +/// +/// If they are already created, this has no effect. If schema +/// migrations are necessary, they will be performed with logging. +Status SetupTensorboardSqliteDb(std::shared_ptr db); + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORBOARD_DB_SCHEMA_H_ diff --git a/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java b/tensorflow/contrib/tensorboard/db/schema_test.cc similarity index 63% rename from tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java rename to tensorflow/contrib/tensorboard/db/schema_test.cc index 49e5d9f2f3a6627201dd9af67b5698f095a9c0f0..463c4e59e7e76e6460b7ddfbd92262ac249aa9ed 100644 --- a/tensorflow/java/src/main/java/org/tensorflow/types/TFDouble.java +++ b/tensorflow/contrib/tensorboard/db/schema_test.cc @@ -12,19 +12,20 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// GENERATED FILE. To update, edit tftypes.pl instead. +#include "tensorflow/contrib/tensorboard/db/schema.h" -package org.tensorflow.types; +#include -import org.tensorflow.DataType; +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/test.h" -/** Represents a 64-bit double precision floating point number. */ -public class TFDouble implements TFType { - private TFDouble() {} - static { - Types.typeCodes.put(TFDouble.class, DataType.DOUBLE); - } - static { - Types.scalars.put(TFDouble.class, 0.0); - } +namespace tensorflow { +namespace { + +TEST(SchemaTest, SmokeTestTensorboardSchema) { + auto db = Sqlite::Open(":memory:").ValueOrDie(); + TF_ASSERT_OK(SetupTensorboardSqliteDb(db)); } + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc new file mode 100644 index 0000000000000000000000000000000000000000..df64e36305529a67f9573e9d26cc0dfc506d324f --- /dev/null +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.cc @@ -0,0 +1,279 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" + +#include "tensorflow/contrib/tensorboard/db/schema.h" +#include "tensorflow/core/lib/db/sqlite.h" +#include "tensorflow/core/lib/random/random.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/snappy.h" + +namespace tensorflow { +namespace { + +int64 MakeRandomId() { + int64 id = static_cast(random::New64() & ((1ULL << 63) - 1)); + if (id == 0) { + ++id; + } + return id; +} + +class SummaryDbWriter : public SummaryWriterInterface { + public: + SummaryDbWriter(Env* env, std::shared_ptr db) + : SummaryWriterInterface(), env_(env), db_(std::move(db)), run_id_(-1) {} + ~SummaryDbWriter() override {} + + Status Initialize(const string& experiment_name, const string& run_name, + const string& user_name) { + mutex_lock ml(mu_); + insert_tensor_ = db_->Prepare(R"sql( + INSERT OR REPLACE INTO Tensors (tag_id, step, computed_time, tensor) + VALUES (?, ?, ?, ?) + )sql"); + update_metadata_ = db_->Prepare(R"sql( + UPDATE Tags SET metadata = ? WHERE tag_id = ? + )sql"); + experiment_name_ = experiment_name; + run_name_ = run_name; + user_name_ = user_name; + return Status::OK(); + } + + // TODO(@jart): Use transactions that COMMIT on Flush() + // TODO(@jart): Retry Commit() on SQLITE_BUSY with exponential back-off. + Status Flush() override { return Status::OK(); } + + Status WriteTensor(int64 global_step, Tensor t, const string& tag, + const string& serialized_metadata) override { + mutex_lock ml(mu_); + TF_RETURN_IF_ERROR(InitializeParents()); + // TODO(@jart): Memoize tag_id. + int64 tag_id; + TF_RETURN_IF_ERROR(GetTagId(run_id_, tag, &tag_id)); + if (!serialized_metadata.empty()) { + // TODO(@jart): Only update metadata for first tensor. + update_metadata_.BindBlobUnsafe(1, serialized_metadata); + update_metadata_.BindInt(2, tag_id); + TF_RETURN_IF_ERROR(update_metadata_.StepAndReset()); + } + // TODO(@jart): Lease blocks of rowids and *_ids to minimize fragmentation. + // TODO(@jart): Check for random ID collisions without needing txn retry. + insert_tensor_.BindInt(1, tag_id); + insert_tensor_.BindInt(2, global_step); + insert_tensor_.BindDouble(3, GetWallTime()); + switch (t.dtype()) { + case DT_INT64: + insert_tensor_.BindInt(4, t.scalar()()); + break; + case DT_DOUBLE: + insert_tensor_.BindDouble(4, t.scalar()()); + break; + default: + TF_RETURN_IF_ERROR(BindTensor(t)); + break; + } + TF_RETURN_IF_ERROR(insert_tensor_.StepAndReset()); + return Status::OK(); + } + + Status WriteEvent(std::unique_ptr e) override { + // TODO(@jart): This will be used to load event logs. + return errors::Unimplemented("WriteEvent"); + } + + Status WriteScalar(int64 global_step, Tensor t, const string& tag) override { + // TODO(@jart): Unlike WriteTensor, this method would be granted leniency + // to change the dtype if it saves storage space. For example, + // DT_UINT32 would be stored in the database as an INTEGER + // rather than a serialized BLOB. But when reading it back, + // the dtype would become DT_INT64. + return errors::Unimplemented("WriteScalar"); + } + + Status WriteHistogram(int64 global_step, Tensor t, + const string& tag) override { + return errors::Unimplemented( + "SummaryDbWriter::WriteHistogram not supported. Please use ", + "tensorboard.summary.histogram() instead."); + } + + Status WriteImage(int64 global_step, Tensor tensor, const string& tag, + int max_images, Tensor bad_color) override { + return errors::Unimplemented( + "SummaryDbWriter::WriteImage not supported. Please use ", + "tensorboard.summary.image() instead."); + } + + Status WriteAudio(int64 global_step, Tensor tensor, const string& tag, + int max_outputs, float sample_rate) override { + return errors::Unimplemented( + "SummaryDbWriter::WriteAudio not supported. Please use ", + "tensorboard.summary.audio() instead."); + } + + string DebugString() override { return "SummaryDbWriter"; } + + private: + double GetWallTime() { + // TODO(@jart): Follow precise definitions for time laid out in schema. + // TODO(@jart): Use monotonic clock from gRPC codebase. + return static_cast(env_->NowMicros()) / 1.0e6; + } + + Status BindTensor(const Tensor& t) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + // TODO(@jart): Make portable between little and big endian systems. + // TODO(@jart): Use TensorChunks with minimal copying for big tensors. + TensorProto p; + t.AsProtoTensorContent(&p); + string encoded; + if (!p.SerializeToString(&encoded)) { + return errors::DataLoss("SerializeToString failed"); + } + // TODO(@jart): Put byte at beginning of blob to indicate encoding. + // TODO(@jart): Allow crunch tool to re-compress with zlib instead. + string compressed; + if (!port::Snappy_Compress(encoded.data(), encoded.size(), &compressed)) { + return errors::FailedPrecondition("TensorBase needs Snappy"); + } + insert_tensor_.BindBlobUnsafe(4, compressed); + return Status::OK(); + } + + Status InitializeParents() EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (run_id_ >= 0) { + return Status::OK(); + } + int64 user_id; + TF_RETURN_IF_ERROR(GetUserId(user_name_, &user_id)); + int64 experiment_id; + TF_RETURN_IF_ERROR( + GetExperimentId(user_id, experiment_name_, &experiment_id)); + TF_RETURN_IF_ERROR(GetRunId(experiment_id, run_name_, &run_id_)); + return Status::OK(); + } + + Status GetUserId(const string& user_name, int64* user_id) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (user_name.empty()) { + *user_id = 0LL; + return Status::OK(); + } + SqliteStatement get_user_id = db_->Prepare(R"sql( + SELECT user_id FROM Users WHERE user_name = ? + )sql"); + get_user_id.BindText(1, user_name); + bool is_done; + TF_RETURN_IF_ERROR(get_user_id.Step(&is_done)); + if (!is_done) { + *user_id = get_user_id.ColumnInt(0); + } else { + *user_id = MakeRandomId(); + SqliteStatement insert_user = db_->Prepare(R"sql( + INSERT INTO Users (user_id, user_name, inserted_time) VALUES (?, ?, ?) + )sql"); + insert_user.BindInt(1, *user_id); + insert_user.BindText(2, user_name); + insert_user.BindDouble(3, GetWallTime()); + TF_RETURN_IF_ERROR(insert_user.StepAndReset()); + } + return Status::OK(); + } + + Status GetExperimentId(int64 user_id, const string& experiment_name, + int64* experiment_id) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + // TODO(@jart): Compute started_time. + return GetId("Experiments", "user_id", user_id, "experiment_name", + experiment_name, "experiment_id", experiment_id); + } + + Status GetRunId(int64 experiment_id, const string& run_name, int64* run_id) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + // TODO(@jart): Compute started_time. + return GetId("Runs", "experiment_id", experiment_id, "run_name", run_name, + "run_id", run_id); + } + + Status GetTagId(int64 run_id, const string& tag_name, int64* tag_id) + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return GetId("Tags", "run_id", run_id, "tag_name", tag_name, "tag_id", + tag_id); + } + + Status GetId(const char* table, const char* parent_id_field, int64 parent_id, + const char* name_field, const string& name, const char* id_field, + int64* id) EXCLUSIVE_LOCKS_REQUIRED(mu_) { + if (name.empty()) { + *id = 0LL; + return Status::OK(); + } + SqliteStatement select = db_->Prepare( + strings::Printf("SELECT %s FROM %s WHERE %s = ? AND %s = ?", id_field, + table, parent_id_field, name_field)); + if (parent_id > 0) { + select.BindInt(1, parent_id); + } + select.BindText(2, name); + bool is_done; + TF_RETURN_IF_ERROR(select.Step(&is_done)); + if (!is_done) { + *id = select.ColumnInt(0); + } else { + *id = MakeRandomId(); + SqliteStatement insert = db_->Prepare(strings::Printf( + "INSERT INTO %s (%s, %s, %s, inserted_time) VALUES (?, ?, ?, ?)", + table, parent_id_field, id_field, name_field)); + if (parent_id > 0) { + insert.BindInt(1, parent_id); + } + insert.BindInt(2, *id); + insert.BindText(3, name); + insert.BindDouble(4, GetWallTime()); + TF_RETURN_IF_ERROR(insert.StepAndReset()); + } + return Status::OK(); + } + + mutex mu_; + Env* env_; + std::shared_ptr db_ GUARDED_BY(mu_); + SqliteStatement insert_tensor_ GUARDED_BY(mu_); + SqliteStatement update_metadata_ GUARDED_BY(mu_); + string user_name_ GUARDED_BY(mu_); + string experiment_name_ GUARDED_BY(mu_); + string run_name_ GUARDED_BY(mu_); + int64 run_id_ GUARDED_BY(mu_); +}; + +} // namespace + +Status CreateSummaryDbWriter(std::shared_ptr db, + const string& experiment_name, + const string& run_name, const string& user_name, + Env* env, SummaryWriterInterface** result) { + TF_RETURN_IF_ERROR(SetupTensorboardSqliteDb(db)); + SummaryDbWriter* w = new SummaryDbWriter(env, std::move(db)); + const Status s = w->Initialize(experiment_name, run_name, user_name); + if (!s.ok()) { + w->Unref(); + *result = nullptr; + return s; + } + *result = w; + return Status::OK(); +} + +} // namespace tensorflow diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer.h b/tensorflow/contrib/tensorboard/db/summary_db_writer.h new file mode 100644 index 0000000000000000000000000000000000000000..74f61e50b7cdf4b4151162a2e1e5e0af0d468be2 --- /dev/null +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer.h @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CONTRIB_TENSORBOARD_DB_SUMMARY_DB_WRITER_H_ +#define TENSORFLOW_CONTRIB_TENSORBOARD_DB_SUMMARY_DB_WRITER_H_ + +#include "tensorflow/core/kernels/summary_interface.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/db/sqlite.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" + +namespace tensorflow { + +/// \brief Creates SQLite SummaryWriterInterface. +/// +/// This can be used to write tensors from the execution graph directly +/// to a database. The schema will be created automatically, but only +/// if necessary. Entries in the Users, Experiments, and Runs tables +/// will be created automatically if they don't already exist. +/// +/// Please note that the type signature of this function may change in +/// the future if support for other DBs is added to core. +Status CreateSummaryDbWriter(std::shared_ptr db, + const string& experiment_name, + const string& run_name, const string& user_name, + Env* env, SummaryWriterInterface** result); + +} // namespace tensorflow + +#endif // TENSORFLOW_CONTRIB_TENSORBOARD_DB_SUMMARY_DB_WRITER_H_ diff --git a/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d32904f97c4172ded51a00dc076630b598494716 --- /dev/null +++ b/tensorflow/contrib/tensorboard/db/summary_db_writer_test.cc @@ -0,0 +1,162 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/contrib/tensorboard/db/summary_db_writer.h" + +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/db/sqlite.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +Tensor MakeScalarInt64(int64 x) { + Tensor t(DT_INT64, TensorShape({})); + t.scalar()() = x; + return t; +} + +class FakeClockEnv : public EnvWrapper { + public: + FakeClockEnv() : EnvWrapper(Env::Default()), current_millis_(0) {} + void AdvanceByMillis(const uint64 millis) { current_millis_ += millis; } + uint64 NowMicros() override { return current_millis_ * 1000; } + uint64 NowSeconds() override { return current_millis_ * 1000; } + + private: + uint64 current_millis_; +}; + +class SummaryDbWriterTest : public ::testing::Test { + protected: + void SetUp() override { db_ = Sqlite::Open("file::memory:").ValueOrDie(); } + + void TearDown() override { + if (writer_ != nullptr) { + writer_->Unref(); + writer_ = nullptr; + } + } + + int64 QueryInt(const string& sql) { + SqliteStatement stmt = db_->Prepare(sql); + bool is_done; + Status s = stmt.Step(&is_done); + if (!s.ok() || is_done) { + LOG(ERROR) << s << " due to " << sql; + return -1; + } + return stmt.ColumnInt(0); + } + + double QueryDouble(const string& sql) { + SqliteStatement stmt = db_->Prepare(sql); + bool is_done; + Status s = stmt.Step(&is_done); + if (!s.ok() || is_done) { + LOG(ERROR) << s << " due to " << sql; + return -1; + } + return stmt.ColumnDouble(0); + } + + string QueryString(const string& sql) { + SqliteStatement stmt = db_->Prepare(sql); + bool is_done; + Status s = stmt.Step(&is_done); + if (!s.ok() || is_done) { + LOG(ERROR) << s << " due to " << sql; + return "MISSINGNO"; + } + return stmt.ColumnString(0); + } + + FakeClockEnv env_; + std::shared_ptr db_; + SummaryWriterInterface* writer_ = nullptr; +}; + +TEST_F(SummaryDbWriterTest, NothingWritten_NoRowsCreated) { + TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_, + &writer_)); + TF_ASSERT_OK(writer_->Flush()); + writer_->Unref(); + writer_ = nullptr; + EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Users")); + EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Experiments")); + EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Runs")); + EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Tags")); + EXPECT_EQ(0LL, QueryInt("SELECT COUNT(*) FROM Tensors")); +} + +TEST_F(SummaryDbWriterTest, TensorsWritten_RowsGetInitialized) { + TF_ASSERT_OK(CreateSummaryDbWriter(db_, "mad-science", "train", "jart", &env_, + &writer_)); + env_.AdvanceByMillis(23); + TF_ASSERT_OK(writer_->WriteTensor(1, MakeScalarInt64(123LL), "taggy", + "this-is-metaaa")); + env_.AdvanceByMillis(23); + TF_ASSERT_OK(writer_->WriteTensor(2, MakeScalarInt64(314LL), "taggy", "")); + TF_ASSERT_OK(writer_->Flush()); + + ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Users")); + ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Experiments")); + ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Runs")); + ASSERT_EQ(1LL, QueryInt("SELECT COUNT(*) FROM Tags")); + ASSERT_EQ(2LL, QueryInt("SELECT COUNT(*) FROM Tensors")); + + int64 user_id = QueryInt("SELECT user_id FROM Users"); + int64 experiment_id = QueryInt("SELECT experiment_id FROM Experiments"); + int64 run_id = QueryInt("SELECT run_id FROM Runs"); + int64 tag_id = QueryInt("SELECT tag_id FROM Tags"); + EXPECT_LT(0LL, user_id); + EXPECT_LT(0LL, experiment_id); + EXPECT_LT(0LL, run_id); + EXPECT_LT(0LL, tag_id); + + EXPECT_EQ("jart", QueryString("SELECT user_name FROM Users")); + EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Users")); + + EXPECT_EQ(user_id, QueryInt("SELECT user_id FROM Experiments")); + EXPECT_EQ("mad-science", + QueryString("SELECT experiment_name FROM Experiments")); + EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Experiments")); + + EXPECT_EQ(experiment_id, QueryInt("SELECT experiment_id FROM Runs")); + EXPECT_EQ("train", QueryString("SELECT run_name FROM Runs")); + EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Runs")); + + EXPECT_EQ(run_id, QueryInt("SELECT run_id FROM Tags")); + EXPECT_EQ("taggy", QueryString("SELECT tag_name FROM Tags")); + EXPECT_EQ(0.023, QueryDouble("SELECT inserted_time FROM Tags")); + EXPECT_EQ("this-is-metaaa", QueryString("SELECT metadata FROM Tags")); + + EXPECT_EQ(tag_id, QueryInt("SELECT tag_id FROM Tensors WHERE step = 1")); + EXPECT_EQ(0.023, + QueryDouble("SELECT computed_time FROM Tensors WHERE step = 1")); + EXPECT_EQ("this-is-metaaa", QueryString("SELECT metadata FROM Tags")); + EXPECT_FALSE( + QueryString("SELECT tensor FROM Tensors WHERE step = 1").empty()); + + EXPECT_EQ(tag_id, QueryInt("SELECT tag_id FROM Tensors WHERE step = 2")); + EXPECT_EQ(0.046, + QueryDouble("SELECT computed_time FROM Tensors WHERE step = 2")); + EXPECT_EQ("this-is-metaaa", QueryString("SELECT metadata FROM Tags")); + EXPECT_FALSE( + QueryString("SELECT tensor FROM Tensors WHERE step = 2").empty()); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/contrib/text/BUILD b/tensorflow/contrib/text/BUILD index 8a2cb28684fe5151176b00fbcfaa64626ec18c38..698fdd830f57eb64c3c4119371f545908bf726e5 100644 --- a/tensorflow/contrib/text/BUILD +++ b/tensorflow/contrib/text/BUILD @@ -36,15 +36,21 @@ tf_custom_op_py_library( srcs_version = "PY2AND3", deps = [ ":gen_skip_gram_ops", + "//tensorflow/contrib/lookup:lookup_py", "//tensorflow/contrib/util:util_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", "//tensorflow/python:framework", + "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:ops", "//tensorflow/python:platform", + "//tensorflow/python:random_ops", + "//tensorflow/python:random_seed", "//tensorflow/python:training", + "//tensorflow/python:util", ], ) diff --git a/tensorflow/contrib/timeseries/examples/BUILD b/tensorflow/contrib/timeseries/examples/BUILD index 015d0eba29f281d78ed6717271987cf3f2e121e9..755b0657e9fb29c167911407cee340ac7e3e9b7a 100644 --- a/tensorflow/contrib/timeseries/examples/BUILD +++ b/tensorflow/contrib/timeseries/examples/BUILD @@ -25,6 +25,7 @@ py_test( srcs = ["predict_test.py"], data = ["data/period_trend.csv"], srcs_version = "PY2AND3", + tags = ["notsan"], # b/67513579 deps = [ ":predict", "//tensorflow/python:client_testlib", @@ -87,6 +88,8 @@ py_binary( tags = ["no_pip"], deps = [ "//tensorflow:tensorflow_py", + "//tensorflow/contrib/timeseries/python/timeseries:estimators", + "//tensorflow/contrib/timeseries/python/timeseries:model", "//third_party/py/numpy", ], ) @@ -96,7 +99,11 @@ py_test( timeout = "long", # Moderate but for asan srcs = ["lstm_test.py"], srcs_version = "PY2AND3", - deps = [":lstm"], + tags = ["notsan"], + deps = [ + ":lstm", + "//tensorflow/python:client_testlib", + ], ) filegroup( diff --git a/tensorflow/contrib/timeseries/examples/lstm.py b/tensorflow/contrib/timeseries/examples/lstm.py index 6bab06f56c859705597027369147643a43ce01c0..3ba823f638da8f750981bc910d960706ff652fb7 100644 --- a/tensorflow/contrib/timeseries/examples/lstm.py +++ b/tensorflow/contrib/timeseries/examples/lstm.py @@ -106,16 +106,6 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): for state_element in self._lstm_cell.zero_state(batch_size=1, dtype=self.dtype)]) - def _transform(self, data): - """Normalize data based on input statistics to encourage stable training.""" - mean, variance = self._input_statistics.overall_feature_moments - return (data - mean) / variance - - def _de_transform(self, data): - """Transform data back to the input scale.""" - mean, variance = self._input_statistics.overall_feature_moments - return data * variance + mean - def _filtering_step(self, current_times, current_values, state, predictions): """Update model state based on observations. @@ -140,7 +130,10 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): state_from_time, prediction, lstm_state = state with tf.control_dependencies( [tf.assert_equal(current_times, state_from_time)]): - transformed_values = self._transform(current_values) + # Subtract the mean and divide by the variance of the series. Slightly + # more efficient if done for a whole window (using the normalize_features + # argument to SequentialTimeSeriesModel). + transformed_values = self._scale_data(current_values) # Use mean squared error across features for the loss. predictions["loss"] = tf.reduce_mean( (prediction - transformed_values) ** 2, axis=-1) @@ -156,7 +149,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel): inputs=previous_observation_or_prediction, state=lstm_state) next_prediction = self._predict_from_lstm_output(lstm_output) new_state_tuple = (current_times, next_prediction, new_lstm_state) - return new_state_tuple, {"mean": self._de_transform(next_prediction)} + return new_state_tuple, {"mean": self._scale_back_data(next_prediction)} def _imputation_step(self, current_times, state): """Advance model state across a gap.""" diff --git a/tensorflow/contrib/timeseries/python/timeseries/BUILD b/tensorflow/contrib/timeseries/python/timeseries/BUILD index 2c4bed5db199435fb28aa3023e4414492dc2d43a..5f04eb2f5a4af031ad19662b05a8a2396299925d 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/BUILD +++ b/tensorflow/contrib/timeseries/python/timeseries/BUILD @@ -42,6 +42,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":feature_keys", + ":head", ":input_pipeline", ":model_utils", "//tensorflow/python:util", @@ -78,8 +79,8 @@ py_library( deps = [ ":ar_model", ":feature_keys", + ":head", ":math_utils", - ":model_utils", ":state_management", "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:filtering_postprocessor", "//tensorflow/contrib/timeseries/python/timeseries/state_space_models:state_space_model", @@ -104,6 +105,7 @@ py_test( tags = [ "no_pip_gpu", # b/63391119 "nomsan", # Takes too long to run. + "notsan", # b/67865658 ], deps = [ ":ar_model", @@ -123,9 +125,9 @@ py_test( ) py_library( - name = "model_utils", + name = "head", srcs = [ - "model_utils.py", + "head.py", ], srcs_version = "PY2AND3", deps = [ @@ -136,22 +138,20 @@ py_library( "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:init_ops", "//tensorflow/python:math_ops", - "//tensorflow/python:nn_ops", "//tensorflow/python:state_ops", "//tensorflow/python:util", "//tensorflow/python:variable_scope", "//tensorflow/python/estimator:estimator_py", "//tensorflow/python/estimator:export", - "//third_party/py/numpy", + "//tensorflow/python/estimator:head", ], ) py_test( - name = "model_utils_test", + name = "head_test", srcs = [ - "model_utils_test.py", + "head_test.py", ], srcs_version = "PY2AND3", tags = [ @@ -159,8 +159,8 @@ py_test( ], deps = [ ":feature_keys", + ":head", ":model", - ":model_utils", ":state_management", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", @@ -174,6 +174,39 @@ py_test( ], ) +py_library( + name = "model_utils", + srcs = [ + "model_utils.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":feature_keys", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:nn_ops", + "//tensorflow/python:variable_scope", + "//third_party/py/numpy", + ], +) + +py_test( + name = "model_utils_test", + srcs = [ + "model_utils_test.py", + ], + srcs_version = "PY2AND3", + tags = [ + "no_pip_gpu", # b/63391119 + ], + deps = [ + ":model_utils", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + ], +) + py_library( name = "state_management", srcs = [ @@ -290,11 +323,11 @@ py_library( ":input_pipeline", ":state_management", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:random_seed", + "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:util", "//tensorflow/python:variables", @@ -335,6 +368,7 @@ py_test( "ar_model_test.py", ], srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":ar_model", ":estimators", @@ -342,10 +376,10 @@ py_test( ":input_pipeline", ":test_utils", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_ops", "//tensorflow/python:platform", + "//tensorflow/python:session", "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py index 7452dc7dc362b304ca3b3717bad039df17012e5c..ff140efd48104e386826eab7abbc94bec220f9df 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py @@ -89,8 +89,6 @@ class ARModel(model.TimeSeriesModel): self.hidden_layer_sizes = hidden_layer_sizes self.window_size = self.input_window_size + self.output_window_size self.loss = loss - self.stats_means = None - self.stats_sigmas = None super(ARModel, self).__init__( num_features=num_features) assert num_time_buckets > 0 @@ -106,32 +104,6 @@ class ARModel(model.TimeSeriesModel): assert len(self._periods) or self.input_window_size assert output_window_size > 0 - def scale_data(self, data): - """Scale data according to stats.""" - if self._input_statistics is not None: - return (data - self.stats_means) / self.stats_sigmas - else: - return data - - def scale_back_data(self, data): - if self._input_statistics is not None: - return (data * self.stats_sigmas) + self.stats_means - else: - return data - - def scale_back_variance(self, var): - if self._input_statistics is not None: - return var * self.stats_sigmas * self.stats_sigmas - else: - return var - - def initialize_graph(self, input_statistics=None): - super(ARModel, self).initialize_graph(input_statistics=input_statistics) - if self._input_statistics: - self.stats_means, variances = ( - self._input_statistics.overall_feature_moments) - self.stats_sigmas = math_ops.sqrt(variances) - def get_start_state(self): # State which matches the format we'll return later. Typically this will not # be used by the model directly, but the shapes and dtypes should match so @@ -388,8 +360,8 @@ class ARModel(model.TimeSeriesModel): predicted_covariance = array_ops.ones_like(predicted_mean) # Transform and scale the mean and covariance appropriately. - predicted_mean = self.scale_back_data(predicted_mean) - predicted_covariance = self.scale_back_variance(predicted_covariance) + predicted_mean = self._scale_back_data(predicted_mean) + predicted_covariance = self._scale_back_variance(predicted_covariance) return {"mean": predicted_mean, "covariance": predicted_covariance} @@ -402,7 +374,7 @@ class ARModel(model.TimeSeriesModel): original_values = values # Extra shape checking for the window size (above that in - # model_utils.make_model_fn). + # `head.create_estimator_spec`). expected_times_shape = [None, self.window_size] if not times.get_shape().is_compatible_with(expected_times_shape): raise ValueError( @@ -418,7 +390,7 @@ class ARModel(model.TimeSeriesModel): times_feature=TrainEvalFeatures.TIMES, window_size=self.window_size, times_shape=times.get_shape())) - values = self.scale_data(values) + values = self._scale_data(values) if self.input_window_size > 0: input_values = values[:, :self.input_window_size, :] else: @@ -435,14 +407,14 @@ class ARModel(model.TimeSeriesModel): # (observed - predicted) ** 2. # Note that this affects only evaluation; the training loss is unaffected. loss = self.loss_op( - self.scale_back_data(targets), - {"mean": self.scale_back_data(prediction_ops["mean"])}) + self._scale_back_data(targets), + {"mean": self._scale_back_data(prediction_ops["mean"])}) else: loss = self.loss_op(targets, prediction_ops) # Scale back the prediction. - prediction = self.scale_back_data(prediction) - covariance = self.scale_back_variance(covariance) + prediction = self._scale_back_data(prediction) + covariance = self._scale_back_variance(covariance) return model.ModelOutputs( loss=loss, @@ -565,7 +537,7 @@ class ARModel(model.TimeSeriesModel): new_state_times.set_shape((None, self.input_window_size)) new_state_values = array_ops.concat( [previous_state_values, - self.scale_data(values)], axis=1)[:, -self.input_window_size:, :] + self._scale_data(values)], axis=1)[:, -self.input_window_size:, :] new_state_values.set_shape((None, self.input_window_size, self.num_features)) else: diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py index 4025a8f0142b68c275122dac7ee384341d07163a..3738dfa154d4f39b9562446972443ed88f3fbe8b 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py +++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py @@ -20,8 +20,8 @@ from __future__ import print_function from tensorflow.contrib.timeseries.python.timeseries import ar_model from tensorflow.contrib.timeseries.python.timeseries import feature_keys +from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib from tensorflow.contrib.timeseries.python.timeseries import math_utils -from tensorflow.contrib.timeseries.python.timeseries import model_utils from tensorflow.contrib.timeseries.python.timeseries import state_management from tensorflow.contrib.timeseries.python.timeseries.state_space_models import state_space_model from tensorflow.contrib.timeseries.python.timeseries.state_space_models import structural_ensemble @@ -59,9 +59,10 @@ class TimeSeriesRegressor(estimator_lib.Estimator): if optimizer is None: optimizer = train.AdamOptimizer(0.02) self._model = model - model_fn = model_utils.make_model_fn( + ts_regression_head = ts_head_lib.time_series_regression_head( model, state_manager, optimizer, input_statistics_generator=input_statistics_generator) + model_fn = ts_regression_head.create_estimator_spec super(TimeSeriesRegressor, self).__init__( model_fn=model_fn, model_dir=model_dir, @@ -132,7 +133,7 @@ class TimeSeriesRegressor(estimator_lib.Estimator): with ops.Graph().as_default(): self._model.initialize_graph() model_start_state = self._model.get_start_state() - for prefixed_state_name, state_tensor in model_utils.state_to_dictionary( + for prefixed_state_name, state_tensor in ts_head_lib.state_to_dictionary( model_start_state).items(): state_shape_with_batch = tensor_shape.TensorShape( (default_batch_size,)).concatenate(state_tensor.get_shape()) diff --git a/tensorflow/contrib/timeseries/python/timeseries/head.py b/tensorflow/contrib/timeseries/python/timeseries/head.py new file mode 100644 index 0000000000000000000000000000000000000000..5896fc2a206bc747688b5b012e0f87465592dd8a --- /dev/null +++ b/tensorflow/contrib/timeseries/python/timeseries/head.py @@ -0,0 +1,375 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Timeseries head.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import re + +from tensorflow.contrib.framework.python.ops import variables +from tensorflow.contrib.layers.python.layers import optimizers + +from tensorflow.contrib.timeseries.python.timeseries import feature_keys + +from tensorflow.python.estimator import estimator_lib +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.export import export_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.util import nest + + +def time_series_regression_head(model, + state_manager, + optimizer, + input_statistics_generator=None): + """Creates a `_Head` for time series regression. + + Args: + model: A model for time series regression. + state_manager: A state manager. + optimizer: An optimizer. + input_statistics_generator: A input statistics generator. + + Returns: + An instance of `_Head` for time series regression. + """ + return _TimeSeriesRegressionHead(model, state_manager, optimizer, + input_statistics_generator) + + +class _TimeSeriesRegressionHead(head_lib._Head): # pylint:disable=protected-access + """See `time_series_regression_head`.""" + + def __init__(self, + model, + state_manager, + optimizer, + input_statistics_generator=None, + name=None): + self.model = model + self.state_manager = state_manager + self.optimizer = optimizer + self.input_statistics_generator = input_statistics_generator + self._name = name + + def _train_ops(self, features): + """Add training ops to the graph.""" + with variable_scope.variable_scope("model"): + model_outputs = self.state_manager.define_loss( + self.model, features, estimator_lib.ModeKeys.TRAIN) + + train_op = optimizers.optimize_loss( + model_outputs.loss, + global_step=variables.get_global_step(), + optimizer=self.optimizer, + # Learning rate is set in the Optimizer object + learning_rate=None) + return estimator_lib.EstimatorSpec( + loss=model_outputs.loss, + mode=estimator_lib.ModeKeys.TRAIN, + train_op=train_op) + + # TODO(terrytangyuan): suffix summary and metrics keys by `"/" + name` + @property + def name(self): + return self._name + + # TODO(terrytangyuan): unused for now. Need to decouple + # `state_manager.define_loss` to satisfy the extendable return signature of + # `_Head.create_loss`. + def create_loss(self, features, mode, logits, labels): + """See `_Head`.""" + return None + + # TODO(terrytangyuan): check label dimension + @property + def logits_dimension(self): + return None + + def _evaluate_ops(self, features): + """Add ops for evaluation (aka filtering) to the graph.""" + with variable_scope.variable_scope("model"): + model_outputs = self.state_manager.define_loss( + self.model, features, estimator_lib.ModeKeys.EVAL) + metrics = {} + # Just output in-sample predictions for the last chunk seen + for prediction_key, prediction_value in model_outputs.predictions.items(): + metrics[prediction_key] = _identity_metric_single(prediction_key, + prediction_value) + metrics[feature_keys.FilteringResults.TIMES] = _identity_metric_single( + feature_keys.FilteringResults.TIMES, model_outputs.prediction_times) + metrics[feature_keys.FilteringResults.STATE_TUPLE] = ( + _identity_metric_nested(feature_keys.FilteringResults.STATE_TUPLE, + model_outputs.end_state)) + return estimator_lib.EstimatorSpec( + loss=model_outputs.loss, + mode=estimator_lib.ModeKeys.EVAL, + eval_metric_ops=metrics, + predictions={}) + + def _predict_ops(self, features): + """Add ops for prediction to the graph.""" + with variable_scope.variable_scope("model"): + prediction = self.model.predict(features=features) + prediction[feature_keys.PredictionResults.TIMES] = features[ + feature_keys.PredictionFeatures.TIMES] + return estimator_lib.EstimatorSpec( + predictions=prediction, mode=estimator_lib.ModeKeys.PREDICT) + + def _serving_ops(self, features): + """Add ops for serving to the graph.""" + with variable_scope.variable_scope("model"): + prediction_outputs = self.model.predict(features=features) + with variable_scope.variable_scope("model", reuse=True): + filtering_outputs = self.state_manager.define_loss( + self.model, features, estimator_lib.ModeKeys.EVAL) + + return estimator_lib.EstimatorSpec( + mode=estimator_lib.ModeKeys.PREDICT, + export_outputs={ + feature_keys.SavedModelLabels.PREDICT: + export_lib.PredictOutput(prediction_outputs), + feature_keys.SavedModelLabels.FILTER: + export_lib.PredictOutput( + state_to_dictionary(filtering_outputs.end_state)) + }, + # Likely unused, but it is necessary to return `predictions` to satisfy + # the Estimator's error checking. + predictions={}) + + def _convert_feature_to_tensor(self, name, value): + """Casts features to the correct dtype based on their name.""" + if name in [ + feature_keys.TrainEvalFeatures.TIMES, + feature_keys.PredictionFeatures.TIMES + ]: + return math_ops.cast(value, dtypes.int64) + if name == feature_keys.TrainEvalFeatures.VALUES: + return math_ops.cast(value, self.model.dtype) + if name == feature_keys.PredictionFeatures.STATE_TUPLE: + return value # Correct dtypes are model-dependent + return ops.convert_to_tensor(value) + + def _gather_state(self, features): + """Returns `features` with state packed, indicates if packing was done.""" + prefixed_state_re = re.compile(r"^" + feature_keys.State.STATE_PREFIX + + r"_(\d+)$") + numbered_state = [] + for key, tensor in features.items(): + search_result = prefixed_state_re.search(key) + if search_result: + numbered_state.append((int(search_result.group(1)), key, tensor)) + if not numbered_state: + return features, False + features = features.copy() + for _, key, _ in numbered_state: + del features[key] + numbered_state.sort(key=lambda number, *_: number) + features[feature_keys.State.STATE_TUPLE] = nest.pack_sequence_as( + structure=self.model.get_start_state(), + flat_sequence=[tensor for _, _, tensor in numbered_state]) + return features, True + + def create_estimator_spec(self, features, mode, labels=None): + """Performs basic error checking and returns an EstimatorSpec.""" + with ops.name_scope("head"): + if labels: + raise ValueError( + "The model received a `labels` dictionary, which is " + "not supported. Pass '{}' and '{}' as " + "features.".format(feature_keys.TrainEvalFeatures.TIMES, + feature_keys.TrainEvalFeatures.VALUES)) + del labels + features = { + name: self._convert_feature_to_tensor(name=name, value=value) + for name, value in features.items() + } + if self.input_statistics_generator is not None: + input_statistics = self.input_statistics_generator.initialize_graph( + features, update_statistics=(mode == estimator_lib.ModeKeys.TRAIN)) + else: + input_statistics = None + self.model.initialize_graph(input_statistics=input_statistics) + + # _gather_state requires the model to have its graph initialized (so it + # has access to the structure of the model's state) + features, passed_flat_state = self._gather_state(features) + if (mode == estimator_lib.ModeKeys.TRAIN or + mode == estimator_lib.ModeKeys.EVAL): + _check_train_eval_features(features, self.model) + elif mode == estimator_lib.ModeKeys.PREDICT: + _check_predict_features(features) + else: + raise ValueError("Unknown mode '{}' passed to model_fn.".format(mode)) + + self.state_manager.initialize_graph( + model=self.model, input_statistics=input_statistics) + + if mode == estimator_lib.ModeKeys.TRAIN: + return self._train_ops(features) + elif mode == estimator_lib.ModeKeys.EVAL: + return self._evaluate_ops(features) + elif mode == estimator_lib.ModeKeys.PREDICT and not passed_flat_state: + return self._predict_ops(features) + elif mode == estimator_lib.ModeKeys.PREDICT and passed_flat_state: + # The mode is PREDICT, but we're actually in export_savedmodel for + # serving. We want to return two graphs: one for filtering (state + data + # -> state) and one for predicting (state -> prediction). + return self._serving_ops(features) + + +def _check_feature_shapes_compatible_with(features, + compatible_with_name, + compatible_with_value, + ignore=None): + """Checks all features are compatible with the given time-like feature.""" + if ignore is None: + ignore = set() + for name, value in features.items(): + if name in ignore: + continue + feature_shape = value.get_shape() + if feature_shape.ndims is None: + continue + if feature_shape.ndims < 2: + raise ValueError( + ("Features must have shape (batch dimension, window size, ...) " + "(got rank {} for feature '{}')").format(feature_shape.ndims, name)) + if not feature_shape[:2].is_compatible_with( + compatible_with_value.get_shape()): + raise ValueError( + ("Features must have shape (batch dimension, window size, ...) " + "where batch dimension and window size match the " + "'{times_feature}' feature (got shape {feature_shape} for " + "feature '{feature_name}' but shape {times_shape} for feature " + "'{times_feature}')").format( + times_feature=compatible_with_name, + feature_shape=feature_shape, + feature_name=name, + times_shape=compatible_with_value.get_shape())) + + +def _check_predict_features(features): + """Raises errors if features are not suitable for prediction.""" + if feature_keys.PredictionFeatures.TIMES not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.PredictionFeatures.TIMES)) + if feature_keys.PredictionFeatures.STATE_TUPLE not in features: + raise ValueError("Expected a '{}' feature for prediction.".format( + feature_keys.PredictionFeatures.STATE_TUPLE)) + times_feature = features[feature_keys.PredictionFeatures.TIMES] + if not times_feature.get_shape().is_compatible_with([None, None]): + raise ValueError( + ("Expected shape (batch dimension, window size) for feature '{}' " + "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, + times_feature.get_shape())) + _check_feature_shapes_compatible_with( + features=features, + compatible_with_name=feature_keys.PredictionFeatures.TIMES, + compatible_with_value=times_feature, + ignore=set([ + feature_keys.PredictionFeatures.STATE_TUPLE # Model-dependent shapes + ])) + + +def _check_train_eval_features(features, model): + """Raise errors if features are not suitable for training/evaluation.""" + if feature_keys.TrainEvalFeatures.TIMES not in features: + raise ValueError("Expected a '{}' feature for training/evaluation.".format( + feature_keys.TrainEvalFeatures.TIMES)) + if feature_keys.TrainEvalFeatures.VALUES not in features: + raise ValueError("Expected a '{}' feature for training/evaluation.".format( + feature_keys.TrainEvalFeatures.VALUES)) + times_feature = features[feature_keys.TrainEvalFeatures.TIMES] + if not times_feature.get_shape().is_compatible_with([None, None]): + raise ValueError( + ("Expected shape (batch dimension, window size) for feature '{}' " + "(got shape {})").format(feature_keys.TrainEvalFeatures.TIMES, + times_feature.get_shape())) + values_feature = features[feature_keys.TrainEvalFeatures.VALUES] + if not values_feature.get_shape().is_compatible_with( + [None, None, model.num_features]): + raise ValueError( + ("Expected shape (batch dimension, window size, {num_features}) " + "for feature '{feature_name}', since the model was configured " + "with num_features={num_features} (got shape {got_shape})").format( + num_features=model.num_features, + feature_name=feature_keys.TrainEvalFeatures.VALUES, + got_shape=times_feature.get_shape())) + _check_feature_shapes_compatible_with( + features=features, + compatible_with_name=feature_keys.TrainEvalFeatures.TIMES, + compatible_with_value=times_feature, + ignore=set([ + feature_keys.State.STATE_TUPLE # Model-dependent shapes + ])) + + +def _identity_metric_single(name, input_tensor): + """A metric which takes on its last updated value. + + This keeps evaluation metrics in sync with one another, since update ops are + run separately from their result Tensors. Simply returning (input_tensor, + no_op) as a metric with a value but no update means that a metric will come + from a different batch of data than metrics which cache values in a Variable + (e.g. the default loss metric). + + Args: + name: A name for the metric. + input_tensor: Any Tensor. + Returns: + A tuple of (value, update_op). + """ + metric_variable = variable_scope.variable( + name="{}_identity_metric".format(name), + initial_value=array_ops.zeros([], dtype=input_tensor.dtype), + collections=[ops.GraphKeys.LOCAL_VARIABLES], + validate_shape=False) + update_op = state_ops.assign( + metric_variable, input_tensor, validate_shape=False) + # This shape will be correct once the first update runs (but may be + # incomplete, so is not helpful for initializing the variable). + metric_variable.set_shape(input_tensor.get_shape()) + return (metric_variable.value(), update_op) + + +def _identity_metric_nested(name, input_tensors): + """Create identity metrics for a nested tuple of Tensors.""" + update_ops = [] + value_tensors = [] + for tensor_number, tensor in enumerate(nest.flatten(input_tensors)): + value_tensor, update_op = _identity_metric_single( + name="{}_{}".format(name, tensor_number), input_tensor=tensor) + update_ops.append(update_op) + value_tensors.append(value_tensor) + return (nest.pack_sequence_as(input_tensors, value_tensors), + control_flow_ops.group(*update_ops)) + + +def state_to_dictionary(state_tuple): + """Flatten model state into a dictionary with string keys.""" + flattened = {} + for state_number, state_value in enumerate(nest.flatten(state_tuple)): + prefixed_state_name = "{}_{:02d}".format(feature_keys.State.STATE_PREFIX, + state_number) + flattened[prefixed_state_name] = state_value + return flattened diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py new file mode 100644 index 0000000000000000000000000000000000000000..3415061cfd87358cccaf36dcb301fb36986bbde6 --- /dev/null +++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py @@ -0,0 +1,267 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for head.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.timeseries.python.timeseries import feature_keys +from tensorflow.contrib.timeseries.python.timeseries import head as ts_head_lib +from tensorflow.contrib.timeseries.python.timeseries import model +from tensorflow.contrib.timeseries.python.timeseries import state_management + +from tensorflow.python.estimator import estimator_lib +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import metrics +from tensorflow.python.ops import variables +from tensorflow.python.platform import test +from tensorflow.python.training import coordinator as coordinator_lib +from tensorflow.python.training import queue_runner_impl +from tensorflow.python.training import training as train + + +class HeadTest(test.TestCase): + + def test_labels_provided_error(self): + model_fn = _stub_model_fn() + for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL, + estimator_lib.ModeKeys.PREDICT]: + with self.assertRaisesRegexp(ValueError, "labels"): + model_fn(features={}, labels={"a": "b"}, mode=mode) + + def test_unknown_mode(self): + model_fn = _stub_model_fn() + with self.assertRaisesRegexp(ValueError, "Unknown mode 'Not a mode'"): + model_fn(features={}, labels={}, mode="Not a mode") + + +class _TickerModel(object): + num_features = 1 + dtype = dtypes.float32 + + def initialize_graph(self, input_statistics): + pass + + def define_loss(self, features, mode): + del mode # unused + return model.ModelOutputs( + loss=features["ticker"], + end_state=(features["ticker"], features["ticker"]), + prediction_times=array_ops.zeros(()), + predictions={"ticker": features["ticker"]}) + + +class EvaluationMetricsTests(test.TestCase): + + def test_metrics_consistent(self): + # Tests that the identity metrics used to report in-sample predictions match + # the behavior of standard metrics. + g = ops.Graph() + with g.as_default(): + features = { + feature_keys.TrainEvalFeatures.TIMES: + array_ops.zeros((1, 1)), + feature_keys.TrainEvalFeatures.VALUES: + array_ops.zeros((1, 1, 1)), + "ticker": + array_ops.reshape( + math_ops.cast( + variables.Variable( + name="ticker", + initial_value=0, + dtype=dtypes.int64, + collections=[ops.GraphKeys.LOCAL_VARIABLES]) + .count_up_to(10), + dtype=dtypes.float32), (1, 1, 1)) + } + model_fn = ts_head_lib.time_series_regression_head( + model=_TickerModel(), + state_manager=state_management.PassthroughStateManager(), + optimizer=train.GradientDescentOptimizer(0.001)).create_estimator_spec + outputs = model_fn( + features=features, labels=None, mode=estimator_lib.ModeKeys.EVAL) + metric_update_ops = [ + metric[1] for metric in outputs.eval_metric_ops.values()] + loss_mean, loss_update = metrics.mean(outputs.loss) + metric_update_ops.append(loss_update) + with self.test_session() as sess: + coordinator = coordinator_lib.Coordinator() + queue_runner_impl.start_queue_runners(sess, coord=coordinator) + variables.local_variables_initializer().run() + sess.run(metric_update_ops) + loss_evaled, metric_evaled, nested_metric_evaled = sess.run( + (loss_mean, outputs.eval_metric_ops["ticker"][0], + outputs.eval_metric_ops[feature_keys.FilteringResults.STATE_TUPLE][ + 0][0])) + # The custom model_utils metrics for in-sample predictions should be in + # sync with the Estimator's mean metric for model loss. + self.assertAllClose(0., loss_evaled) + self.assertAllClose((((0.,),),), metric_evaled) + self.assertAllClose((((0.,),),), nested_metric_evaled) + coordinator.request_stop() + coordinator.join() + + +class _StubModel(object): + num_features = 3 + dtype = dtypes.float64 + + def initialize_graph(self, input_statistics): + del input_statistics # unused + + +def _stub_model_fn(): + return ts_head_lib.time_series_regression_head( + model=_StubModel(), + state_manager=state_management.PassthroughStateManager(), + optimizer=train.AdamOptimizer(0.001)).create_estimator_spec + + +class TrainEvalFeatureCheckingTests(test.TestCase): + + def test_no_time_feature(self): + model_fn = _stub_model_fn() + for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]: + with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format( + feature_keys.TrainEvalFeatures.TIMES)): + model_fn( + features={feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]}, + labels=None, + mode=mode) + + def test_no_value_feature(self): + model_fn = _stub_model_fn() + for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]: + with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format( + feature_keys.TrainEvalFeatures.VALUES)): + model_fn( + features={feature_keys.TrainEvalFeatures.TIMES: [[1]]}, + labels=None, + mode=mode) + + def test_bad_time_rank(self): + model_fn = _stub_model_fn() + for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]: + with self.assertRaisesRegexp(ValueError, + "Expected shape.*for feature '{}'".format( + feature_keys.TrainEvalFeatures.TIMES)): + model_fn( + features={ + feature_keys.TrainEvalFeatures.TIMES: [[[1]]], + feature_keys.TrainEvalFeatures.VALUES: [[[1.]]] + }, + labels=None, + mode=mode) + + def test_bad_value_rank(self): + model_fn = _stub_model_fn() + for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]: + with self.assertRaisesRegexp(ValueError, + "Expected shape.*for feature '{}'".format( + feature_keys.TrainEvalFeatures.VALUES)): + model_fn( + features={ + feature_keys.TrainEvalFeatures.TIMES: [[1]], + feature_keys.TrainEvalFeatures.VALUES: [[1.]] + }, + labels=None, + mode=mode) + + def test_bad_value_num_features(self): + model_fn = _stub_model_fn() + for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]: + with self.assertRaisesRegexp( + ValueError, "Expected shape.*, 3.*for feature '{}'".format( + feature_keys.TrainEvalFeatures.VALUES)): + model_fn( + features={ + feature_keys.TrainEvalFeatures.TIMES: [[1]], + feature_keys.TrainEvalFeatures.VALUES: [[[1.]]] + }, + labels=None, + mode=mode) + + def test_bad_exogenous_shape(self): + model_fn = _stub_model_fn() + for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]: + with self.assertRaisesRegexp( + ValueError, + "Features must have shape.*for feature 'exogenous'"): + model_fn( + features={ + feature_keys.TrainEvalFeatures.TIMES: [[1]], + feature_keys.TrainEvalFeatures.VALUES: [[[1., 2., 3.]]], + "exogenous": [[1], [2]] + }, + labels=None, + mode=mode) + + +class PredictFeatureCheckingTests(test.TestCase): + + def test_no_time_feature(self): + model_fn = _stub_model_fn() + with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format( + feature_keys.PredictionFeatures.TIMES)): + model_fn( + features={ + feature_keys.PredictionFeatures.STATE_TUPLE: ([[[1.]]], 1.) + }, + labels=None, + mode=estimator_lib.ModeKeys.PREDICT) + + def test_no_start_state_feature(self): + model_fn = _stub_model_fn() + with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format( + feature_keys.PredictionFeatures.STATE_TUPLE)): + model_fn( + features={feature_keys.PredictionFeatures.TIMES: [[1]]}, + labels=None, + mode=estimator_lib.ModeKeys.PREDICT) + + def test_bad_time_rank(self): + model_fn = _stub_model_fn() + with self.assertRaisesRegexp(ValueError, + "Expected shape.*for feature '{}'".format( + feature_keys.PredictionFeatures.TIMES)): + model_fn( + features={ + feature_keys.PredictionFeatures.TIMES: 1, + feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.)) + }, + labels=None, + mode=estimator_lib.ModeKeys.PREDICT) + + def test_bad_exogenous_shape(self): + model_fn = _stub_model_fn() + with self.assertRaisesRegexp( + ValueError, + "Features must have shape.*for feature 'exogenous'"): + model_fn( + features={ + feature_keys.PredictionFeatures.TIMES: [[1]], + feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.)), + "exogenous": 1. + }, + labels=None, + mode=estimator_lib.ModeKeys.PREDICT) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py index c70da3e082245e76ab3225676c2d37c4ea95292d..23452a81c397da3516016d72b7bc9b80f7d6447f 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/math_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/math_utils.py @@ -936,8 +936,7 @@ class InputStatisticsFromMiniBatch(object): start_time = variable_scope.get_variable( name="start_time", dtype=dtypes.int64, - initializer=init_ops.zeros_initializer(), - shape=[], + initializer=dtypes.int64.max, trainable=False) total_observation_count = variable_scope.get_variable( name="total_observation_count", diff --git a/tensorflow/contrib/timeseries/python/timeseries/model.py b/tensorflow/contrib/timeseries/python/timeseries/model.py index f2ef8d22114be50a10d3b106be5e144cc70b4bfc..b32b5c5494ae14187954b900119678a5b53a3602 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/model.py @@ -80,6 +80,8 @@ class TimeSeriesModel(object): self.dtype = dtype self._input_statistics = None self._graph_initialized = False + self._stats_means = None + self._stats_sigmas = None # TODO(allenl): Move more of the generic machinery for generating and # predicting into TimeSeriesModel, and possibly share it between generate() @@ -120,6 +122,38 @@ class TimeSeriesModel(object): """ self._graph_initialized = True self._input_statistics = input_statistics + if self._input_statistics: + self._stats_means, variances = ( + self._input_statistics.overall_feature_moments) + self._stats_sigmas = math_ops.sqrt(variances) + + def _scale_data(self, data): + """Scale data according to stats (input scale -> model scale).""" + if self._input_statistics is not None: + return (data - self._stats_means) / self._stats_sigmas + else: + return data + + def _scale_variance(self, variance): + """Scale variances according to stats (input scale -> model scale).""" + if self._input_statistics is not None: + return variance / self._input_statistics.overall_feature_moments.variance + else: + return variance + + def _scale_back_data(self, data): + """Scale back data according to stats (model scale -> input scale).""" + if self._input_statistics is not None: + return (data * self._stats_sigmas) + self._stats_means + else: + return data + + def _scale_back_variance(self, variance): + """Scale back variances according to stats (model scale -> input scale).""" + if self._input_statistics is not None: + return variance * self._input_statistics.overall_feature_moments.variance + else: + return variance def _check_graph_initialized(self): if not self._graph_initialized: @@ -304,6 +338,7 @@ class SequentialTimeSeriesModel(TimeSeriesModel): train_output_names, predict_output_names, num_features, + normalize_features=False, dtype=dtypes.float32, exogenous_feature_columns=None, exogenous_update_condition=None, @@ -316,6 +351,12 @@ class SequentialTimeSeriesModel(TimeSeriesModel): predict_output_names: A list of products/predictions returned from _prediction_step. num_features: Number of features for the time series + normalize_features: Boolean. If True, `values` are passed normalized to + the model (via self._scale_data). Scaling is done for the whole window + as a batch, which is slightly more efficient than scaling inside the + window loop. The model must then define _scale_back_predictions, which + may use _scale_back_data or _scale_back_variance to return predictions + to the input scale. dtype: The floating point datatype to use. exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn objects. See `TimeSeriesModel`. @@ -344,9 +385,25 @@ class SequentialTimeSeriesModel(TimeSeriesModel): self._exogenous_update_condition = exogenous_update_condition self._train_output_names = train_output_names self._predict_output_names = predict_output_names + self._normalize_features = normalize_features self._static_unrolling_window_size_threshold = ( static_unrolling_window_size_threshold) + def _scale_back_predictions(self, predictions): + """Return a window of predictions to input scale. + + Args: + predictions: A dictionary mapping from prediction names to Tensors. + Returns: + A dictionary with values corrected for input normalization (e.g. with + self._scale_back_mean and possibly self._scale_back_variance). May be a + mutated version of the argument. + """ + raise NotImplementedError( + "SequentialTimeSeriesModel normalized input data" + " (normalize_features=True), but no method was provided to transform " + "the predictions back to the input scale.") + @abc.abstractmethod def _filtering_step(self, current_times, current_values, state, predictions): """Compute a single-step loss for a batch of data. @@ -524,6 +581,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel): self._check_graph_initialized() times = math_ops.cast(features[TrainEvalFeatures.TIMES], dtype=dtypes.int64) values = math_ops.cast(features[TrainEvalFeatures.VALUES], dtype=self.dtype) + if self._normalize_features: + values = self._scale_data(values) exogenous_regressors = self._process_exogenous_features( times=times, features={key: value for key, value in features.items() @@ -556,6 +615,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel): # Since we have window-level additions to the loss, its per-step value is # misleading, so we avoid returning it. del outputs["loss"] + if self._normalize_features: + outputs = self._scale_back_predictions(outputs) return per_observation_loss, state, outputs def predict(self, features): @@ -583,6 +644,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel): times=predict_times, state=start_state, state_update_fn=_call_prediction_step, outputs=self._predict_output_names) + if self._normalize_features: + predictions = self._scale_back_predictions(predictions) return predictions class _FakeTensorArray(object): diff --git a/tensorflow/contrib/timeseries/python/timeseries/model_utils.py b/tensorflow/contrib/timeseries/python/timeseries/model_utils.py index addcdb05754c6ccd736f5d21619015acfcfc906c..b5d7cb376b6337113e8653fcbc376aa1e228464a 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/model_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/model_utils.py @@ -18,334 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -import re - import numpy -from tensorflow.contrib.framework.python.ops import variables -from tensorflow.contrib.layers.python.layers import optimizers - from tensorflow.contrib.timeseries.python.timeseries import feature_keys -from tensorflow.python.estimator import estimator_lib -from tensorflow.python.estimator.export import export_lib from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import init_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops -from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope -from tensorflow.python.util import nest - - -def _check_feature_shapes_compatible_with( - features, compatible_with_name, compatible_with_value, ignore=None): - """Checks all features are compatible with the given time-like feature.""" - if ignore is None: - ignore = set() - for name, value in features.items(): - if name in ignore: - continue - feature_shape = value.get_shape() - if feature_shape.ndims is None: - continue - if feature_shape.ndims < 2: - raise ValueError( - ("Features must have shape (batch dimension, window size, ...) " - "(got rank {} for feature '{}')").format( - feature_shape.ndims, name)) - if not feature_shape[:2].is_compatible_with( - compatible_with_value.get_shape()): - raise ValueError( - ("Features must have shape (batch dimension, window size, ...) " - "where batch dimension and window size match the " - "'{times_feature}' feature (got shape {feature_shape} for " - "feature '{feature_name}' but shape {times_shape} for feature " - "'{times_feature}')").format( - times_feature=compatible_with_name, - feature_shape=feature_shape, - feature_name=name, - times_shape=compatible_with_value.get_shape())) - - -def _check_predict_features(features): - """Raises errors if features are not suitable for prediction.""" - if feature_keys.PredictionFeatures.TIMES not in features: - raise ValueError("Expected a '{}' feature for prediction.".format( - feature_keys.PredictionFeatures.TIMES)) - if feature_keys.PredictionFeatures.STATE_TUPLE not in features: - raise ValueError("Expected a '{}' feature for prediction.".format( - feature_keys.PredictionFeatures.STATE_TUPLE)) - times_feature = features[feature_keys.PredictionFeatures.TIMES] - if not times_feature.get_shape().is_compatible_with([None, None]): - raise ValueError( - ("Expected shape (batch dimension, window size) for feature '{}' " - "(got shape {})").format(feature_keys.PredictionFeatures.TIMES, - times_feature.get_shape())) - _check_feature_shapes_compatible_with( - features=features, - compatible_with_name=feature_keys.PredictionFeatures.TIMES, - compatible_with_value=times_feature, - ignore=set([ - feature_keys.PredictionFeatures.STATE_TUPLE # Model-dependent shapes - ])) - - -def _check_train_eval_features(features, model): - """Raise errors if features are not suitable for training/evaluation.""" - if feature_keys.TrainEvalFeatures.TIMES not in features: - raise ValueError("Expected a '{}' feature for training/evaluation.".format( - feature_keys.TrainEvalFeatures.TIMES)) - if feature_keys.TrainEvalFeatures.VALUES not in features: - raise ValueError("Expected a '{}' feature for training/evaluation.".format( - feature_keys.TrainEvalFeatures.VALUES)) - times_feature = features[feature_keys.TrainEvalFeatures.TIMES] - if not times_feature.get_shape().is_compatible_with([None, None]): - raise ValueError( - ("Expected shape (batch dimension, window size) for feature '{}' " - "(got shape {})").format(feature_keys.TrainEvalFeatures.TIMES, - times_feature.get_shape())) - values_feature = features[feature_keys.TrainEvalFeatures.VALUES] - if not values_feature.get_shape().is_compatible_with( - [None, None, model.num_features]): - raise ValueError( - ("Expected shape (batch dimension, window size, {num_features}) " - "for feature '{feature_name}', since the model was configured " - "with num_features={num_features} (got shape {got_shape})").format( - num_features=model.num_features, - feature_name=feature_keys.TrainEvalFeatures.VALUES, - got_shape=times_feature.get_shape())) - _check_feature_shapes_compatible_with( - features=features, - compatible_with_name=feature_keys.TrainEvalFeatures.TIMES, - compatible_with_value=times_feature, - ignore=set([ - feature_keys.State.STATE_TUPLE # Model-dependent shapes - ])) - - -def _identity_metric_single(name, input_tensor): - """A metric which takes on its last updated value. - - This keeps evaluation metrics in sync with one another, since update ops are - run separately from their result Tensors. Simply returning (input_tensor, - no_op) as a metric with a value but no update means that a metric will come - from a different batch of data than metrics which cache values in a Variable - (e.g. the default loss metric). - - Args: - name: A name for the metric. - input_tensor: Any Tensor. - Returns: - A tuple of (value, update_op). - """ - metric_variable = variable_scope.variable( - name="{}_identity_metric".format(name), - initial_value=array_ops.zeros([], dtype=input_tensor.dtype), - collections=[ops.GraphKeys.LOCAL_VARIABLES], - validate_shape=False) - update_op = state_ops.assign(metric_variable, input_tensor, - validate_shape=False) - # This shape will be correct once the first update runs (but may be - # incomplete, so is not helpful for initializing the variable). - metric_variable.set_shape(input_tensor.get_shape()) - return (metric_variable.value(), update_op) - - -def _identity_metric_nested(name, input_tensors): - """Create identity metrics for a nested tuple of Tensors.""" - update_ops = [] - value_tensors = [] - for tensor_number, tensor in enumerate(nest.flatten(input_tensors)): - value_tensor, update_op = _identity_metric_single( - name="{}_{}".format(name, tensor_number), - input_tensor=tensor) - update_ops.append(update_op) - value_tensors.append(value_tensor) - return (nest.pack_sequence_as(input_tensors, value_tensors), - control_flow_ops.group(*update_ops)) - - -def state_to_dictionary(state_tuple): - """Flatten model state into a dictionary with string keys.""" - flattened = {} - for state_number, state_value in enumerate(nest.flatten(state_tuple)): - prefixed_state_name = "{}_{:02d}".format(feature_keys.State.STATE_PREFIX, - state_number) - flattened[prefixed_state_name] = state_value - return flattened - - -def make_model_fn( - model, state_manager, optimizer, input_statistics_generator=None): - """Returns a model function suitable for use with a tf.estimator. - - Args: - model: The object (inheriting from Model) to create a function for. - state_manager: A state manager to wrap the model with (or - PassthroughStateManager if no state needs to be managed). - optimizer: An instance of `tf.train.Optimizer` to use for training. - input_statistics_generator: An InputStatisticsFromMiniBatch object from - math_utils.py, used for collecting statistics about input data during - training. - Returns: - The model function, suitable for passing to a tf.estimator.Estimator. - """ - - def _convert_feature_to_tensor(name, value): - """Casts features to the correct dtype based on their name.""" - if name in [ - feature_keys.TrainEvalFeatures.TIMES, - feature_keys.PredictionFeatures.TIMES - ]: - return math_ops.cast(value, dtypes.int64) - if name == feature_keys.TrainEvalFeatures.VALUES: - return math_ops.cast(value, model.dtype) - if name == feature_keys.PredictionFeatures.STATE_TUPLE: - return value # Correct dtypes are model-dependent - return ops.convert_to_tensor(value) - - def _gather_state(features): - """Returns `features` with state packed, indicates if packing was done.""" - prefixed_state_re = re.compile(r"^" + feature_keys.State.STATE_PREFIX + - r"_(\d+)$") - numbered_state = [] - for key, tensor in features.items(): - search_result = prefixed_state_re.search(key) - if search_result: - numbered_state.append((int(search_result.group(1)), key, tensor)) - if not numbered_state: - return features, False - features = features.copy() - for _, key, _ in numbered_state: - del features[key] - numbered_state.sort(key=lambda number, *_: number) - features[feature_keys.State.STATE_TUPLE] = nest.pack_sequence_as( - structure=model.get_start_state(), - flat_sequence=[tensor for _, _, tensor in numbered_state]) - return features, True - - def _train(features): - """Add training ops to the graph.""" - with variable_scope.variable_scope("model"): - model_outputs = state_manager.define_loss(model, features, - estimator_lib.ModeKeys.TRAIN) - train_op = optimizers.optimize_loss( - model_outputs.loss, - global_step=variables.get_global_step(), - optimizer=optimizer, - # Learning rate is set in the Optimizer object - learning_rate=None) - return estimator_lib.EstimatorSpec( - loss=model_outputs.loss, - mode=estimator_lib.ModeKeys.TRAIN, - train_op=train_op) - - def _evaluate(features): - """Add ops for evaluation (aka filtering) to the graph.""" - with variable_scope.variable_scope("model"): - model_outputs = state_manager.define_loss(model, features, - estimator_lib.ModeKeys.EVAL) - metrics = {} - # Just output in-sample predictions for the last chunk seen - for prediction_key, prediction_value in model_outputs.predictions.items(): - metrics[prediction_key] = _identity_metric_single(prediction_key, - prediction_value) - metrics[feature_keys.FilteringResults.TIMES] = _identity_metric_single( - feature_keys.FilteringResults.TIMES, model_outputs.prediction_times) - metrics[feature_keys.FilteringResults.STATE_TUPLE] = ( - _identity_metric_nested(feature_keys.FilteringResults.STATE_TUPLE, - model_outputs.end_state)) - return estimator_lib.EstimatorSpec( - loss=model_outputs.loss, - mode=estimator_lib.ModeKeys.EVAL, - eval_metric_ops=metrics, - predictions={}) - - def _predict(features): - """Add ops for prediction to the graph.""" - with variable_scope.variable_scope("model"): - prediction = model.predict(features=features) - prediction[feature_keys.PredictionResults.TIMES] = features[ - feature_keys.PredictionFeatures.TIMES] - return estimator_lib.EstimatorSpec( - predictions=prediction, mode=estimator_lib.ModeKeys.PREDICT) - - def _serving(features): - with variable_scope.variable_scope("model"): - prediction_outputs = model.predict(features=features) - with variable_scope.variable_scope("model", reuse=True): - filtering_outputs = state_manager.define_loss(model, features, - estimator_lib.ModeKeys.EVAL) - return estimator_lib.EstimatorSpec( - mode=estimator_lib.ModeKeys.PREDICT, - export_outputs={ - feature_keys.SavedModelLabels.PREDICT: - export_lib.PredictOutput(prediction_outputs), - feature_keys.SavedModelLabels.FILTER: - export_lib.PredictOutput( - state_to_dictionary(filtering_outputs.end_state)) - }, - # Likely unused, but it is necessary to return `predictions` to satisfy - # the Estimator's error checking. - predictions={}) - - def _model_fn(features, labels, mode): - """Given a time series in `features`, define a loss for `mode`. - - Args: - features: A dictionary, the output of a chunker (typically with keys - feature_keys.TrainEvalFeatures.TIMES and - feature_keys.TrainEvalFeatures.VALUES). - labels: Not used; included for compatibility with tf.learn. - mode: The tf.estimator.ModeKeys mode to use (TRAIN, EVAL, INFER). - Returns: - A tuple of predictions, a loss Tensor, and a train op. - Raises: - ValueError: If the model makes predictions which do not have static shape - information. - """ - if labels: - raise ValueError("The model received a `labels` dictionary, which is not" - " supported. Pass '{}' and '{}' as features.".format( - feature_keys.TrainEvalFeatures.TIMES, - feature_keys.TrainEvalFeatures.VALUES)) - del labels - features = {name: _convert_feature_to_tensor(name=name, value=value) - for name, value in features.items()} - if input_statistics_generator is not None: - input_statistics = input_statistics_generator.initialize_graph( - features, update_statistics=(mode == estimator_lib.ModeKeys.TRAIN)) - else: - input_statistics = None - model.initialize_graph(input_statistics=input_statistics) - # _gather_state requires the model to have its graph initialized (so it has - # access to the structure of the model's state) - features, passed_flat_state = _gather_state(features) - if (mode == estimator_lib.ModeKeys.TRAIN - or mode == estimator_lib.ModeKeys.EVAL): - _check_train_eval_features(features, model) - elif mode == estimator_lib.ModeKeys.PREDICT: - _check_predict_features(features) - else: - raise ValueError("Unknown mode '{}' passed to model_fn.".format(mode)) - state_manager.initialize_graph( - model=model, input_statistics=input_statistics) - if mode == estimator_lib.ModeKeys.TRAIN: - return _train(features) - elif mode == estimator_lib.ModeKeys.EVAL: - return _evaluate(features) - elif mode == estimator_lib.ModeKeys.PREDICT and not passed_flat_state: - return _predict(features) - elif mode == estimator_lib.ModeKeys.PREDICT and passed_flat_state: - # The mode is PREDICT, but we're actually in export_savedmodel for - # serving. We want to return two graphs: one for filtering (state + data - # -> state) and one for predicting (state -> prediction). - return _serving(features) - return _model_fn # TODO(agarwal): Remove and replace with functionality from tf.slim diff --git a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py index 29986895549d16f37c7ff929a30f9a63a56be135..cfd31cc70d8165b2880293060656209c057f5028 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/model_utils_test.py @@ -18,22 +18,10 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.contrib.timeseries.python.timeseries import feature_keys -from tensorflow.contrib.timeseries.python.timeseries import model from tensorflow.contrib.timeseries.python.timeseries import model_utils -from tensorflow.contrib.timeseries.python.timeseries import state_management -from tensorflow.python.estimator import estimator_lib -from tensorflow.python.framework import dtypes -from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops -from tensorflow.python.ops import metrics -from tensorflow.python.ops import variables from tensorflow.python.platform import test -from tensorflow.python.training import coordinator as coordinator_lib -from tensorflow.python.training import queue_runner_impl -from tensorflow.python.training import training as train class ModelUtilsTest(test.TestCase): @@ -46,230 +34,6 @@ class ModelUtilsTest(test.TestCase): self.assertEqual(5, getter(parameter)) self.assertEqual(4, getter(overridden_parameter)) - def test_labels_provided_error(self): - model_fn = _stub_model_fn() - for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL, - estimator_lib.ModeKeys.PREDICT]: - with self.assertRaisesRegexp(ValueError, "labels"): - model_fn(features={}, labels={"a": "b"}, mode=mode) - - def test_unknown_mode(self): - model_fn = _stub_model_fn() - with self.assertRaisesRegexp(ValueError, "Unknown mode 'Not a mode'"): - model_fn(features={}, labels={}, mode="Not a mode") - - -class _TickerModel(object): - num_features = 1 - dtype = dtypes.float32 - - def initialize_graph(self, input_statistics): - pass - - def define_loss(self, features, mode): - del mode # unused - return model.ModelOutputs( - loss=features["ticker"], - end_state=(features["ticker"], features["ticker"]), - prediction_times=array_ops.zeros(()), - predictions={"ticker": features["ticker"]}) - - -class EvaluationMetricsTests(test.TestCase): - - def test_metrics_consistent(self): - # Tests that the identity metrics used to report in-sample predictions match - # the behavior of standard metrics. - g = ops.Graph() - with g.as_default(): - features = { - feature_keys.TrainEvalFeatures.TIMES: - array_ops.zeros((1, 1)), - feature_keys.TrainEvalFeatures.VALUES: - array_ops.zeros((1, 1, 1)), - "ticker": - array_ops.reshape( - math_ops.cast( - variables.Variable( - name="ticker", - initial_value=0, - dtype=dtypes.int64, - collections=[ops.GraphKeys.LOCAL_VARIABLES]) - .count_up_to(10), - dtype=dtypes.float32), (1, 1, 1)) - } - model_fn = model_utils.make_model_fn( - model=_TickerModel(), - state_manager=state_management.PassthroughStateManager(), - optimizer=train.GradientDescentOptimizer(0.001)) - outputs = model_fn( - features=features, labels=None, mode=estimator_lib.ModeKeys.EVAL) - metric_update_ops = [ - metric[1] for metric in outputs.eval_metric_ops.values()] - loss_mean, loss_update = metrics.mean(outputs.loss) - metric_update_ops.append(loss_update) - with self.test_session() as sess: - coordinator = coordinator_lib.Coordinator() - queue_runner_impl.start_queue_runners(sess, coord=coordinator) - variables.local_variables_initializer().run() - sess.run(metric_update_ops) - loss_evaled, metric_evaled, nested_metric_evaled = sess.run( - (loss_mean, outputs.eval_metric_ops["ticker"][0], - outputs.eval_metric_ops[feature_keys.FilteringResults.STATE_TUPLE][ - 0][0])) - # The custom model_utils metrics for in-sample predictions should be in - # sync with the Estimator's mean metric for model loss. - self.assertAllClose(0., loss_evaled) - self.assertAllClose((((0.,),),), metric_evaled) - self.assertAllClose((((0.,),),), nested_metric_evaled) - coordinator.request_stop() - coordinator.join() - - -class _StubModel(object): - num_features = 3 - dtype = dtypes.float64 - - def initialize_graph(self, input_statistics): - del input_statistics # unused - - -def _stub_model_fn(): - return model_utils.make_model_fn( - model=_StubModel(), - state_manager=state_management.PassthroughStateManager(), - optimizer=train.AdamOptimizer(0.001)) - - -class TrainEvalFeatureCheckingTests(test.TestCase): - - def test_no_time_feature(self): - model_fn = _stub_model_fn() - for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]: - with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format( - feature_keys.TrainEvalFeatures.TIMES)): - model_fn( - features={feature_keys.TrainEvalFeatures.VALUES: [[[1.]]]}, - labels=None, - mode=mode) - - def test_no_value_feature(self): - model_fn = _stub_model_fn() - for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]: - with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format( - feature_keys.TrainEvalFeatures.VALUES)): - model_fn( - features={feature_keys.TrainEvalFeatures.TIMES: [[1]]}, - labels=None, - mode=mode) - - def test_bad_time_rank(self): - model_fn = _stub_model_fn() - for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]: - with self.assertRaisesRegexp(ValueError, - "Expected shape.*for feature '{}'".format( - feature_keys.TrainEvalFeatures.TIMES)): - model_fn( - features={ - feature_keys.TrainEvalFeatures.TIMES: [[[1]]], - feature_keys.TrainEvalFeatures.VALUES: [[[1.]]] - }, - labels=None, - mode=mode) - - def test_bad_value_rank(self): - model_fn = _stub_model_fn() - for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]: - with self.assertRaisesRegexp(ValueError, - "Expected shape.*for feature '{}'".format( - feature_keys.TrainEvalFeatures.VALUES)): - model_fn( - features={ - feature_keys.TrainEvalFeatures.TIMES: [[1]], - feature_keys.TrainEvalFeatures.VALUES: [[1.]] - }, - labels=None, - mode=mode) - - def test_bad_value_num_features(self): - model_fn = _stub_model_fn() - for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]: - with self.assertRaisesRegexp( - ValueError, "Expected shape.*, 3.*for feature '{}'".format( - feature_keys.TrainEvalFeatures.VALUES)): - model_fn( - features={ - feature_keys.TrainEvalFeatures.TIMES: [[1]], - feature_keys.TrainEvalFeatures.VALUES: [[[1.]]] - }, - labels=None, - mode=mode) - - def test_bad_exogenous_shape(self): - model_fn = _stub_model_fn() - for mode in [estimator_lib.ModeKeys.TRAIN, estimator_lib.ModeKeys.EVAL]: - with self.assertRaisesRegexp( - ValueError, - "Features must have shape.*for feature 'exogenous'"): - model_fn( - features={ - feature_keys.TrainEvalFeatures.TIMES: [[1]], - feature_keys.TrainEvalFeatures.VALUES: [[[1., 2., 3.]]], - "exogenous": [[1], [2]] - }, - labels=None, - mode=mode) - - -class PredictFeatureCheckingTests(test.TestCase): - - def test_no_time_feature(self): - model_fn = _stub_model_fn() - with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format( - feature_keys.PredictionFeatures.TIMES)): - model_fn( - features={ - feature_keys.PredictionFeatures.STATE_TUPLE: ([[[1.]]], 1.) - }, - labels=None, - mode=estimator_lib.ModeKeys.PREDICT) - - def test_no_start_state_feature(self): - model_fn = _stub_model_fn() - with self.assertRaisesRegexp(ValueError, "Expected a '{}' feature".format( - feature_keys.PredictionFeatures.STATE_TUPLE)): - model_fn( - features={feature_keys.PredictionFeatures.TIMES: [[1]]}, - labels=None, - mode=estimator_lib.ModeKeys.PREDICT) - - def test_bad_time_rank(self): - model_fn = _stub_model_fn() - with self.assertRaisesRegexp(ValueError, - "Expected shape.*for feature '{}'".format( - feature_keys.PredictionFeatures.TIMES)): - model_fn( - features={ - feature_keys.PredictionFeatures.TIMES: 1, - feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.)) - }, - labels=None, - mode=estimator_lib.ModeKeys.PREDICT) - - def test_bad_exogenous_shape(self): - model_fn = _stub_model_fn() - with self.assertRaisesRegexp( - ValueError, - "Features must have shape.*for feature 'exogenous'"): - model_fn( - features={ - feature_keys.PredictionFeatures.TIMES: [[1]], - feature_keys.PredictionFeatures.STATE_TUPLE: (1, (2, 3.)), - "exogenous": 1. - }, - labels=None, - mode=estimator_lib.ModeKeys.PREDICT) - if __name__ == "__main__": test.main() diff --git a/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py b/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py index 16e29f5e68e4c7c0bbb0b5cd0c547ac57e2faa9f..97f6d36a879532c12684ffdd700ef40b72750567 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py +++ b/tensorflow/contrib/timeseries/python/timeseries/saved_model_utils.py @@ -23,6 +23,7 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.timeseries.python.timeseries import feature_keys as _feature_keys +from tensorflow.contrib.timeseries.python.timeseries import head as _head from tensorflow.contrib.timeseries.python.timeseries import input_pipeline as _input_pipeline from tensorflow.contrib.timeseries.python.timeseries import model_utils as _model_utils @@ -34,7 +35,7 @@ def _colate_features_to_feeds_and_fetches(continue_from, signature, features, """Uses a saved model signature to construct feed and fetch dictionaries.""" if _feature_keys.FilteringResults.STATE_TUPLE in continue_from: # We're continuing from an evaluation, so we need to unpack/flatten state. - state_values = _model_utils.state_to_dictionary( + state_values = _head.state_to_dictionary( continue_from[_feature_keys.FilteringResults.STATE_TUPLE]) else: state_values = continue_from diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/level_trend.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/level_trend.py index b9d3f55c39d32bb9f14829842fcad85571de6855..56167c4f012b42a4e7d56c5e6eac7862d50bd59b 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/level_trend.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/level_trend.py @@ -57,7 +57,9 @@ class AdderStateSpaceModel(state_space_model.StateSpaceModel): # TODO(allenl): Better support for multivariate series here. initial_value = array_ops.stack([ math_ops.reduce_mean( - self._input_statistics.series_start_moments.mean), 0. + self._scale_data( + self._input_statistics.series_start_moments.mean)), + 0. ]) return initial_value + variable_scope.get_variable( name="prior_state_mean", diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py index 6a9660b400d08a0397103676344ea1969fbc1f7a..6257002647ed53bbde3ace11a6b45e4e2cdeb57d 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model.py @@ -232,6 +232,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): + filtering_postprocessor_names), predict_output_names=["mean", "covariance"], num_features=configuration.num_features, + normalize_features=True, dtype=configuration.dtype, exogenous_feature_columns=configuration.exogenous_feature_columns, exogenous_update_condition=configuration.exogenous_update_condition, @@ -309,15 +310,10 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): _, _, priors_from_time = state times = ops.convert_to_tensor(times) priors_from_time = ops.convert_to_tensor(priors_from_time) - with ops.control_dependencies([ - control_flow_ops.Assert( - math_ops.reduce_all(priors_from_time <= times[:, 0]), - [priors_from_time, times[:, 0]], - summarize=100) - ]): - times = array_ops.identity(times) intra_batch_gaps = array_ops.reshape(times[:, 1:] - times[:, :-1], [-1]) - starting_gaps = times[:, 0] - priors_from_time + # Ignore negative starting gaps, since there will be transient start times + # as inputs statistics are computed. + starting_gaps = math_ops.maximum(times[:, 0] - priors_from_time, 0) # Pre-define transition matrices raised to powers (and their sums) for every # gap in this window. This avoids duplicate computation (for example many # steps will use the transition matrix raised to the first power) and @@ -369,20 +365,15 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): Imputed model state corresponding to the `state` argument. """ estimated_state, estimated_state_var, previous_times = state - catchup_times = current_times - previous_times - non_negative_assertion = control_flow_ops.Assert( - math_ops.reduce_all(catchup_times >= 0), [ - "Negative imputation interval", catchup_times, current_times, - previous_times - ], - summarize=100) - with ops.control_dependencies([non_negative_assertion]): - transition_matrices, transition_noise_sums = ( # pylint: disable=unbalanced-tuple-unpacking - self._cached_transition_powers_and_sums(catchup_times)) - estimated_state = self._kalman_filter.predict_state_mean( - estimated_state, transition_matrices) - estimated_state_var = self._kalman_filter.predict_state_var( - estimated_state_var, transition_matrices, transition_noise_sums) + # Ignore negative imputation intervals due to transient start time + # estimates. + catchup_times = math_ops.maximum(current_times - previous_times, 0) + transition_matrices, transition_noise_sums = ( # pylint: disable=unbalanced-tuple-unpacking + self._cached_transition_powers_and_sums(catchup_times)) + estimated_state = self._kalman_filter.predict_state_mean( + estimated_state, transition_matrices) + estimated_state_var = self._kalman_filter.predict_state_var( + estimated_state_var, transition_matrices, transition_noise_sums) return (estimated_state, estimated_state_var, previous_times + catchup_times) @@ -437,6 +428,13 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): outputs=predictions) return (filtered_state, predictions) + def _scale_back_predictions(self, predictions): + """Return a window of predictions to input scale.""" + predictions["mean"] = self._scale_back_data(predictions["mean"]) + predictions["covariance"] = self._scale_back_variance( + predictions["covariance"]) + return predictions + def _prediction_step(self, current_times, state): """Make a prediction based on `state`. @@ -458,7 +456,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): """ estimated_state, estimated_state_var, previous_times = state advanced_to_current_assert = control_flow_ops.Assert( - math_ops.reduce_all(math_ops.equal(current_times, previous_times)), + math_ops.reduce_all(math_ops.less_equal(current_times, previous_times)), ["Attempted to predict without imputation"]) with ops.control_dependencies([advanced_to_current_assert]): observation_model = self.get_broadcasted_observation_model(current_times) @@ -475,6 +473,9 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): (self.num_features,))) predicted_obs_var.set_shape(current_times.get_shape().concatenate( (self.num_features, self.num_features))) + # Not scaled back to input-scale, since this also feeds into the + # loss. Instead, predictions are scaled back before being returned to the + # user in _scale_back_predictions. predictions = { "mean": predicted_obs, "covariance": predicted_obs_var} @@ -722,7 +723,8 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): # Make sure initial latent value uncertainty is at least on the same # scale as noise in the data. covariance_multiplier = math_ops.reduce_max( - self._input_statistics.series_start_moments.variance) + self._scale_variance( + self._input_statistics.series_start_moments.variance)) return base_covariance * gen_math_ops.maximum( covariance_multiplier, 1.0) else: @@ -920,7 +922,8 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): self.get_noise_transform(), dtype=self.dtype) state_noise_dimension = state_noise_transform.get_shape()[1].value if self._input_statistics is not None: - feature_variance = self._input_statistics.series_start_moments.variance + feature_variance = self._scale_variance( + self._input_statistics.series_start_moments.variance) initial_transition_noise_scale = math_ops.log( gen_math_ops.maximum( math_ops.reduce_mean(feature_variance) / math_ops.cast( @@ -945,7 +948,8 @@ class StateSpaceModel(model.SequentialTimeSeriesModel): if self._input_statistics is not None: # Get variance across the first few values in each batch for each # feature, for an initial observation noise (over-)estimate. - feature_variance = self._input_statistics.series_start_moments.variance + feature_variance = self._scale_variance( + self._input_statistics.series_start_moments.variance) else: feature_variance = None if feature_variance is not None: diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py index 7c8f81ec5165b8ba7e8a1089953e5755b5a90915..ca57715e2b2e6bbadd276d641703c0a3b842652e 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/state_space_model_test.py @@ -605,6 +605,7 @@ class TimeDependentStateSpaceModel(state_space_model.StateSpaceModel): super(TimeDependentStateSpaceModel, self).__init__( configuration=state_space_model.StateSpaceModelConfiguration( use_observation_noise=False, + transition_covariance_initial_log_scale_bias=5., static_unrolling_window_size_threshold= static_unrolling_window_size_threshold)) diff --git a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py index 110ba9738f8c28109282b927fd07ade071bb3e4a..1afc58cfb240c52a9f001da787addfb7fbb46789 100644 --- a/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py +++ b/tensorflow/contrib/timeseries/python/timeseries/state_space_models/varma.py @@ -182,7 +182,8 @@ class VARMA(state_space_model.StateSpaceModel): # modeled as transition noise in VARMA, we set its initial value based on a # slight over-estimate empirical observation noise. if self._input_statistics is not None: - feature_variance = self._input_statistics.series_start_moments.variance + feature_variance = self._scale_variance( + self._input_statistics.series_start_moments.variance) initial_transition_noise_scale = math_ops.log( math_ops.maximum( math_ops.reduce_mean(feature_variance), minimum_initial_variance)) diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD index e753fe7a5140028f238c2ff3754b1d7335ae8eb2..e14c36ae43f2544db4ed1e855097a7658120b892 100644 --- a/tensorflow/contrib/tpu/BUILD +++ b/tensorflow/contrib/tpu/BUILD @@ -30,27 +30,49 @@ cc_library( ], ) +py_library( + name = "tpu_test_util", + srcs = ["python/tpu/test_util.py"], + srcs_version = "PY2AND3", + deps = [ + ":tpu_lib", + ":tpu_py", + "//tensorflow/python:errors", + "//tensorflow/python:framework_ops", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:session", + "//tensorflow/python:variables", + ], +) + py_library( name = "tpu_estimator", srcs = [ "python/tpu/tpu_config.py", "python/tpu/tpu_estimator.py", + "python/tpu/util.py", ], srcs_version = "PY2AND3", deps = [ ":tpu_lib", ":tpu_py", + "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:init_ops", + "//tensorflow/python:math_ops", "//tensorflow/python:platform", + "//tensorflow/python:state_ops", + "//tensorflow/python:summary", "//tensorflow/python:training", "//tensorflow/python:variable_scope", "//tensorflow/python:variables", "//tensorflow/python/estimator", "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:run_config", "//tensorflow/python/estimator:util", + "@six_archive//:six", ], ) @@ -95,6 +117,7 @@ py_library( srcs_version = "PY2AND3", deps = [ "//tensorflow/contrib/tpu/profiler:trace_events_proto_py", + "//tensorflow/python:util", ], ) @@ -111,21 +134,15 @@ tf_custom_op_py_library( ":tpu_ops", "//tensorflow/contrib/util:util_py", "//tensorflow/python:client_testlib", - "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:math_ops", "//tensorflow/python:platform", - "//tensorflow/python:state_ops", - "//tensorflow/python:variable_scope", - "//tensorflow/python:variables", + "//tensorflow/python:util", ], ) py_library( name = "tpu", - srcs = [ - "python/tpu/__init__.py", - ], + srcs = ["python/tpu/__init__.py"], srcs_version = "PY2AND3", deps = [ ":tpu_estimator", @@ -198,7 +215,9 @@ tf_py_test( filegroup( name = "all_files", srcs = glob( - ["**/*"], + include = [ + "**/*", + ], exclude = [ "**/METADATA", "**/OWNERS", diff --git a/tensorflow/contrib/tpu/ops/replication_ops.cc b/tensorflow/contrib/tpu/ops/replication_ops.cc index a40e2a7898a304c21a60929b30719f3132aec0f0..b40dac471708793d5a033279e2d2f4b4a0dac480 100644 --- a/tensorflow/contrib/tpu/ops/replication_ops.cc +++ b/tensorflow/contrib/tpu/ops/replication_ops.cc @@ -22,6 +22,11 @@ namespace tensorflow { using shape_inference::InferenceContext; using shape_inference::ShapeHandle; +REGISTER_OP("TPUReplicateMetadata") + .Attr("num_replicas: int >= 0") + .Attr("global_tpu_id: list(int) = []") + .SetShapeFn(shape_inference::UnknownShape); + REGISTER_OP("TPUReplicatedInput") .Input("inputs: N * T") .Output("output: T") diff --git a/tensorflow/contrib/tpu/profiler/BUILD b/tensorflow/contrib/tpu/profiler/BUILD index f6309e2e72f75a4ba5b323b4d7348c49555d522e..0e1fca3d3c8b6f3a19b3e989dbee1863475796c5 100644 --- a/tensorflow/contrib/tpu/profiler/BUILD +++ b/tensorflow/contrib/tpu/profiler/BUILD @@ -95,3 +95,10 @@ tf_proto_library_cc( cc_api_version = 2, visibility = ["//visibility:public"], ) + +tf_proto_library_cc( + name = "tf_op_stats_proto", + srcs = ["tf_op_stats.proto"], + cc_api_version = 2, + visibility = ["//visibility:public"], +) diff --git a/tensorflow/contrib/tpu/profiler/tf_op_stats.proto b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto new file mode 100644 index 0000000000000000000000000000000000000000..5b2dbb31243d401fbab31bab5bc86133896693fe --- /dev/null +++ b/tensorflow/contrib/tpu/profiler/tf_op_stats.proto @@ -0,0 +1,127 @@ +// This proto describes the format of tensorflow operation level stats for +// profiling (in tensorboard) purpose. + +syntax = "proto2"; + +package tensorflow.tpu; + +// Result proto for OpMetrics. +message OpMetricsResult { + // True if this OP is executed on the device; False if it is executed on the + // host. + optional bool on_device = 1; + reserved 2; // was uint32 id. + // Name of this OP. + optional string name = 3; + // Rank of this OP. + optional uint64 rank = 4; + // The starting time in cycles of the last instance of this OP executed. + optional double last_starttime_in_cycles = 5; + // The ending time in cycles of the last instance of this OP executed. + optional double last_endtime_in_cycles = 6; + // If this OP (say A), is an immediate child of another OP (say B), this field + // stores the sum of duration in microseconds of A inside B. If A appears more + // than once in B, the duration of all A's appearances will be added together. + // This sum will be reset after the self-time of B is calculated so that it + // can be reused for a new parent OP. + optional double sum_of_duration_in_us_as_children = 7; + // Number of instances that this OP occurred. + optional uint64 occurrences = 8; + // Total time in microseconds spent in this OP (accumulated + // over all of its occurrences). + optional double total_time_in_us = 9; + // Total self time in microseconds spent in this OP + // (accumulated over all of its occurrences). + optional double total_self_time_in_us = 10; + // The total self time as a fraction of sum of all OP's + // total self time on the host. + optional double host_total_self_time_as_fraction_of_all_op_time = 11; + // Cumulative total self time in fraction on the host. + optional double host_cumulative_total_self_time_as_fraction_of_all_op_time = + 12; + // The total self time as a fraction of sum of all OP's + // total self time on the device. + optional double device_total_self_time_as_fraction_of_all_op_time = 13; + // Cumulative total self time in fraction on the device. + optional double device_cumulative_total_self_time_as_fraction_of_all_op_time = + 14; + // Total number of FLOPs incurred by this OP. + optional double total_flops = 15; + // Total time in microseconds that the MXU is occupied by this OP. + optional double total_bytes_accessed = 16; + // Total time in microseconds that the MXU is occupied by this OP. + optional double mxu_occupancy_in_us = 17; + // Total time in microseconds that the XU is occupied by this OP. + optional double xu_occupancy_in_us = 18; + // Total DMA access stall time in microseconds. + optional double total_dma_stall_in_us = 19; +} + +// Result proto for OpMetricsDb. +message OpMetricsDbResult { + // A bunch of OpMetricsResults. + repeated OpMetricsResult metrics_db = 1; +} + +// Result proto for StepInfo. +message StepInfoResult { + // The (micro) step number. + optional uint32 step_num = 1; + // The step duration in picoseconds. + optional uint64 duration_ps = 2; + // The infeed duration in picoseconds. + // Can turn into a map if we want a variable number of ops. + optional uint64 infeed_duration_ps = 3; +} + +// Result proto for a sequence of steps. +message StepSequenceResult { + // A sequence of StepInfoResults. + repeated StepInfoResult step_sequence = 1; +} + +// Result proto for a StepDatabase. +message StepDatabaseResult { + // A map from core_id to StepSequenceResult. + map step_sequence_per_core = 1; +} + +// Result proto for Dashboard data. +message DashboardResult { + // The total iteration time in nanoseconds. + optional double iteration_time_ns = 1; + // The total number of iterations. + optional int32 num_iterations = 2; + // The total computation time in nanoseconds. + optional double computation_time_ns = 3; + // The total number of computations. + optional int32 num_computations = 4; +} + +// Result proto for HloExtraInfo. +message HloExtraInfoResult { + // Category of the HLO op given by the compiler. + optional string category = 1; + // The long name of the HLO that includes the dimensions. + optional string long_name = 2; +} + +// Result proto for HloExtraInfoMap. +message HloExtraInfoMapResult { + // A map from HLO name to HloExtraInfo. + map hlo_extrainfo_map = 1; +} + +// Result proto for TfStatsHelper. +message TfOpStats { + // The result for the TF-metric database. + optional OpMetricsDbResult tf_metrics_db = 1; + // The result for the HLO-metric database. + optional OpMetricsDbResult hlo_metrics_db = 2; + // The result for the step database. + optional StepDatabaseResult step_db = 3; + // The result for the TPU dashboard. + optional DashboardResult dashboard = 4; + // The result for the HloExtraInfoMap. + optional HloExtraInfoMapResult hlo_extrainfo_map = 5; +} diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py index 8d3344fac36be24a692f141eee140312d988a932..33e47f674d798f622fb08121dabb67d7f45af15b 100644 --- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py +++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py @@ -21,9 +21,11 @@ from __future__ import print_function import platform +from tensorflow.python.framework import ops if platform.system() != "Windows": # pylint: disable=wildcard-import,unused-import,g-import-not-at-top + from tensorflow.contrib.tpu.ops import gen_tpu_ops from tensorflow.contrib.tpu.ops.gen_tpu_ops import * from tensorflow.contrib.util import loader @@ -32,6 +34,12 @@ if platform.system() != "Windows": _tpu_ops = loader.load_op_library( resource_loader.get_path_to_datafile("_tpu_ops.so")) + + @ops.RegisterGradient("CrossReplicaSum") + def _cross_replica_sum_grad(op, grad): + del op # Unused + # The gradient of a cross replica sum is also a cross-replica sum. + return gen_tpu_ops.cross_replica_sum(grad) else: # We have already built the appropriate libraries into the binary via CMake # if we have built contrib, so we don't need this diff --git a/tensorflow/contrib/tpu/python/tpu/test_util.py b/tensorflow/contrib/tpu/python/tpu/test_util.py new file mode 100644 index 0000000000000000000000000000000000000000..f30c27f1298e2389fe0daefdd4eece5a03a6976c --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/test_util.py @@ -0,0 +1,153 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =================================================================== +"""Utilities to ease testing on TPU devices.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.tpu.python.tpu import tpu + +from tensorflow.python.client import session +from tensorflow.python.framework import errors +from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util +from tensorflow.python.ops import gen_array_ops +from tensorflow.python.ops import variables + + +def has_tpu(): + """Check if a TPU device is available. + + Device enumeration via `device_lib` currently fails for TPU systems. + (http://b/68333779). To work around this, we determine the existence of a + TPU by a successful call to `initialize_system`. + + Returns: + boolean, True if a TPU device is available, otherwise False. + """ + def _check(): + with session.Session() as sess: + sess.run(tpu.initialize_system()) + sess.run(tpu.shutdown_system()) + + try: + _check() + return True + except errors.OpError as _: + return False + + +def _available_devices(): + devices = ["cpu"] + if not test_util.gpu_device_name(): + devices.append("gpu") + + if has_tpu(): + devices.append("tpu") + + return tuple(devices) + + +class TPUTestCase(test_util.TensorFlowTestCase): + """Adds helpers for testing on TPU devices to `TensorFlowTestCase`. + + Example usage: + + ``` + def model_fn(features): + return tf.reduce_sum(features * 2) + + class ModelTests(test_util.TPUTestCase): + def test_sum(self): + v = np.random.randn(10, 10).astype("float32") + self.assert_device_output(model_fn, [v], (v*2).sum(), + devices=("cpu", "tpu")) + ``` + """ + + def __init__(self, methodName="runTest"): # pylint: disable=invalid-name + super(TPUTestCase, self).__init__(methodName) + self._available_devices = _available_devices() + + def run_on_device(self, model_fn, model_inputs, device): + """Runs `model_fn` on the given device. + + Raises an exception if no such device is available. `model_fn` should + return one or more tensors as a list or tuple. + + Args: + model_fn: Function returning one or more tensors. + model_inputs: An iterable of Numpy arrays or scalars. + These will be passed as arguments to `model_fn`. + device: Device to run on. One of ("tpu", "gpu", "cpu"). + + Returns: + Output from the model function. + """ + def _make_placeholders(): + return dict( + [(gen_array_ops.placeholder_with_default(v, v.shape), v) + for v in model_inputs]) + + if device == "tpu": + with self.test_session(graph=ops.Graph()) as sess: + placeholders = _make_placeholders() + tpu_computation = tpu.rewrite(model_fn, placeholders.keys()) + sess.run(tpu.initialize_system()) + sess.run(variables.global_variables_initializer()) + result = sess.run(tpu_computation, placeholders) + sess.run(tpu.shutdown_system()) + # TODO(b/36891278): supports non-flat returns lists in tpu.rewrite(). + if len(result) == 1: + return result[0] + return result + elif device == "gpu": + with self.test_session(graph=ops.Graph(), use_gpu=True) as sess: + placeholders = _make_placeholders() + sess.run(variables.global_variables_initializer()) + return sess.run(model_fn(placeholders.keys()), placeholders) + elif device == "cpu": + # TODO(power) -- will this interact poorly with cached GPU sessions? + with self.test_session(graph=ops.Graph(), use_gpu=False) as sess: + placeholders = _make_placeholders() + sess.run(variables.global_variables_initializer()) + return sess.run(model_fn(placeholders.keys()), placeholders) + + def _compare_values(self, actual_outputs, expected_outputs): + if isinstance(expected_outputs, (list, tuple)): + for a, b in zip(actual_outputs, expected_outputs): + self.assertAllCloseAccordingToType(a, b) + else: + self.assertAllCloseAccordingToType(actual_outputs, expected_outputs) + + def assert_device_output(self, model_fn, model_inputs, expected_outputs, + devices=("cpu", "gpu", "tpu")): + """Run `model_fn` on the given devices. + + Results are compared via `assertAllCloseAccordingToType`. + + Args: + model_fn: Function returning one or more tensors + model_inputs: Numpy arrays or scalars passed as arguments to model_fn + expected_outputs: Numpy arrays or scalars to compare against. + devices: Set of devices to run on. If a device is not available, tests + will be skipped for that device. + """ + devices = set(devices).intersection(self._available_devices) + + for device in devices: + device_out = self.run_on_device(model_fn, model_inputs, device=device) + self._compare_values(device_out, expected_outputs) diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py index f6800e3e246dc5f6242a7bf127f6397fedf92b9f..d521297d9947c2a9a37a7283e332591669e102ce 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu.py @@ -105,9 +105,8 @@ class TPUReplicateContext(control_flow_ops.ControlFlowContext): """A ControlFlowContext for nodes inside a TPU computation. The primary role of TPUReplicateContext is to mark operators inside a - tpu.replicate() computation with attributes: - * _tpu_replicate=XYZ, where XYZ is a unique name, and - * _tpu_num_replicas=k, where k is the number of replicas. + tpu.replicate() computation with the attribute "_tpu_replicate=XYZ", where XYZ + is a unique name. We use a ControlFlowContext to perform the annotation since it integrates with Tensorflow constructs like ResourceVariables. For example, @@ -116,11 +115,9 @@ class TPUReplicateContext(control_flow_ops.ControlFlowContext): to build the variable's definition outside the replicated computation. """ - def __init__(self, name, num_replicas, global_tpu_id=None): + def __init__(self, name): control_flow_ops.ControlFlowContext.__init__(self) self._name = name - self._num_replicas = num_replicas - self._global_tpu_id = [] if global_tpu_id is None else global_tpu_id def AddOp(self, op): self._AddOpInternal(op) @@ -135,8 +132,6 @@ class TPUReplicateContext(control_flow_ops.ControlFlowContext): if "_tpu_replicate" in op.node_def.attr: raise ValueError("TPU computations cannot be nested") op.node_def.attr["_tpu_replicate"].s = self._name - op.node_def.attr["_tpu_num_replicas"].i = self._num_replicas - op.node_def.attr["_tpu_global_id"].list.i.extend(self._global_tpu_id) op.graph.prevent_feeding(op) op.graph.prevent_fetching(op) @@ -151,6 +146,14 @@ class TPUReplicateContext(control_flow_ops.ControlFlowContext): if self._outer_context: self._outer_context.AddInnerOp(op) + @property + def grad_state(self): + # Define the gradient loop state associated with the TPUReplicateContext to + # be None as the TPUReplicateContext does not get nested nor does the + # grad_state outside the TPUReplicateContext affect the graph inside so the + # grad_state should be as if this is the top-level gradient state. + return None + def replicate(computation, inputs=None, @@ -243,14 +246,15 @@ def replicate(computation, computation_inputs.append( tpu_ops.tpu_replicated_input(replicas, name="input{}".format(i))) - context = TPUReplicateContext( - name=graph.unique_name("cluster"), - num_replicas=num_replicas, - global_tpu_id=global_tpu_id) + context = TPUReplicateContext(name=graph.unique_name("cluster")) try: context.Enter() - with tpu_function.tpu_shard_context(num_replicas): + metadata = tpu_ops.tpu_replicate_metadata( + num_replicas=num_replicas, global_tpu_id=global_tpu_id) + + with tpu_function.tpu_shard_context( + num_replicas), ops.control_dependencies([metadata]): # The EncapsulateTPUComputations rewrite needs to identify the # replicated arguments inside each computation. Adds identity operators @@ -315,8 +319,11 @@ def replicate(computation, # because the TPUReplicatedInput/TPUReplicatedOutput operator would not # be rewritten away, leading to a runtime error. # TODO(phawkins): extend the rewrite to elide these nodes instead. - with ops.device(core(0)): - output_tensors = [array_ops.identity(x) for x in output_tensors] + new_output_tensors = [] + for t in output_tensors: + with ops.device(t.device if t.device else core(0)): + new_output_tensors.append(array_ops.identity(t)) + output_tensors = new_output_tensors finally: context.Exit() diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_config.py b/tensorflow/contrib/tpu/python/tpu/tpu_config.py index 02135bfe40e72860d474f441b6cc57430d4e0fca..3965c087a18dc18298703fad9b1dda9c85c56271 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_config.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_config.py @@ -21,12 +21,16 @@ from __future__ import print_function import collections +from tensorflow.contrib.tpu.python.tpu import util as util_lib from tensorflow.python.estimator import run_config as run_config_lib class TPUConfig( collections.namedtuple('TPUConfig', [ - 'iterations_per_loop', 'num_shards', 'per_host_input_for_training' + 'iterations_per_loop', + 'num_shards', + 'per_host_input_for_training', + 'tpu_job_name', ])): """TPU related configuration required by `TPUEstimator`. @@ -36,31 +40,62 @@ class TPUConfig( global step is increased `iterations_per_loop` times in one `Session.run`. It is recommended to be set as number of global steps for next checkpoint. num_shards: The number of TPU shards in the system. - per_host_input_for_training: If `True`, `input_fn` is invoked per host - rather than per shard. Note: This behavior is going to be default as - `True` soon, so this flag will be removed after that. Also note that this - only works for single-host TPU training now. + per_host_input_for_training: If `True`, `input_fn` is invoked Per-Host + rather than Per-Core. With Per-Host input pipeline deployment, `input_fn` + is invoked once on each host. To be precise, with a global batch size + `train_batch_size` in `TPUEstimator` constructor, the batch size for each + shard is `train_batch_size` // #hosts. With Per-Core input pipeline + deployment, the shard batch size is `train_batch_size` // #cores. Note + that this only works for single-host TPU training now (tracked in + b/67051042). For multi-host, please use Per-Core, i.e., `False` for + `per_host_input_for_training`. + tpu_job_name: The name of the TPU job. Typically, this name is auto-inferred + within TPUEstimator, however when using ClusterSpec propagation in more + esoteric cluster configurations, you may need to specify the job name as a + string. """ def __new__(cls, iterations_per_loop=2, num_shards=2, - per_host_input_for_training=False): + per_host_input_for_training=True, + tpu_job_name=None): + + # Check iterations_per_loop. + util_lib.check_positive_integer(iterations_per_loop, + 'TPUConfig iterations_per_loop') + + # Check num_shards. + util_lib.check_positive_integer(num_shards, 'TPUConfig num_shards') return super(TPUConfig, cls).__new__( cls, iterations_per_loop=iterations_per_loop, num_shards=num_shards, - per_host_input_for_training=per_host_input_for_training) + per_host_input_for_training=per_host_input_for_training, + tpu_job_name=tpu_job_name) class RunConfig(run_config_lib.RunConfig): """RunConfig with TPU support.""" - def __init__(self, tpu_config=None, evaluation_master='', master='', + def __init__(self, tpu_config=None, evaluation_master=None, master='', **kwargs): + """Constructs a RunConfig. + + Args: + tpu_config: the TPUConfig that specifies TPU-specific configuration. + evaluation_master: a string. The address of the master to use for eval. + Defaults to master if not set. + master: a string. The address of the master to use for training. + tf_random_seed: an int. Sets the TensorFlow random seed. Defaults to None, + which initializes it randomly based on the environment. + """ super(RunConfig, self).__init__(**kwargs) self._tpu_config = tpu_config or TPUConfig() - self._evaluation_master = evaluation_master + if evaluation_master is None: + self._evaluation_master = master + else: + self._evaluation_master = evaluation_master self._master = master @property diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py index cc9f27782a1a9ffdfb7f384ba85e10d80d3520f8..060b3f912926fbaa56bc1150e50434a7ad22c847 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import collections +from contextlib import contextmanager import copy import threading import six @@ -31,12 +32,14 @@ from tensorflow.contrib.tpu.python.tpu import tpu_config from tensorflow.contrib.tpu.python.tpu import tpu_feed from tensorflow.contrib.tpu.python.tpu import tpu_function from tensorflow.contrib.tpu.python.tpu import training_loop +from tensorflow.contrib.tpu.python.tpu import util as util_lib from tensorflow.core.protobuf import config_pb2 from tensorflow.python.estimator import estimator as estimator_lib from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator import util +from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops @@ -56,12 +59,15 @@ from tensorflow.python.training import training_util _INITIAL_LOSS = 1e7 _ZERO_LOSS = 0. -_DEFAULT_NAME_SCOPE = 'tpu_estimator' +_TPU_ESTIMATOR = 'tpu_estimator' _ITERATIONS_PER_LOOP_VAR = 'iterations_per_loop' _BATCH_SIZE_KEY = 'batch_size' _CROSS_REPLICA_SUM_OP = 'CrossReplicaSum' _RESERVED_PARAMS_KEYS = [_BATCH_SIZE_KEY] +# TODO(b/65703635): Flip the value and remove all dead code. +_WRAP_INPUT_FN_INTO_WHILE_LOOP = True + def _create_global_step(graph): graph = graph or ops.get_default_graph() @@ -80,17 +86,25 @@ def _create_global_step(graph): ops.GraphKeys.GLOBAL_STEP]) -def _create_iterations_per_loop(): - with variable_scope.variable_scope(_DEFAULT_NAME_SCOPE, - reuse=variable_scope.AUTO_REUSE): - return variable_scope.get_variable( - _ITERATIONS_PER_LOOP_VAR, - initializer=init_ops.zeros_initializer(), - shape=[], - dtype=dtypes.int32, - trainable=False, - collections=[], - use_resource=True) +def _create_or_get_iterations_per_loop(): + graph = ops.get_default_graph() + iter_vars = graph.get_collection(_TPU_ESTIMATOR) + if len(iter_vars) == 1: + return iter_vars[0] + elif len(iter_vars) > 1: + raise RuntimeError('Multiple iterations_per_loop_var in collection.') + + with ops.colocate_with(training_util.get_global_step()): + with variable_scope.variable_scope(_TPU_ESTIMATOR, + reuse=variable_scope.AUTO_REUSE): + return variable_scope.get_variable( + _ITERATIONS_PER_LOOP_VAR, + initializer=init_ops.zeros_initializer(), + shape=[], + dtype=dtypes.int32, + trainable=False, + collections=[_TPU_ESTIMATOR], + use_resource=True) def _sync_variables_ops(): @@ -121,26 +135,214 @@ def _increase_eval_step_op(iterations_per_loop): use_locking=True) -def _tpu_job(run_config, mode): - # The tpu job is determined by the run_config. Right now, this method is - # required as tpu_config is not part of the RunConfig. - master = (run_config.evaluation_master if mode == model_fn_lib.ModeKeys.EVAL - else run_config.master) - return None if master in ['', 'local'] else 'tpu_worker' +_DEFAULT_JOB_NAME = 'tpu_worker' +_DEFAULT_COORDINATOR_JOB_NAME = 'coordinator' +_LOCAL_MASTERS = ('', 'local') + + +class _TPUContext(object): + """A context holds immutable states of TPU computation. + + This immutable object holds TPUEstimator config, train/eval batch size, and + `TPUEstimator.use_tpu`, which is expected to be passed around. It also + provides utility functions, basded on the current state, to determine other + information commonly required by TPU computation, such as TPU device names, + TPU hosts, shard batch size, etc. + + N.B. As `mode` is not immutable state in Estimator, but essential to + distinguish between TPU training and evaluation, a common usage for + _TPUContext with `mode` is as follows: + ``` + with _ctx.with_mode(mode) as ctx: + if ctx.is_running_on_cpu(): + ... + ``` + """ + + def __init__(self, config, train_batch_size, eval_batch_size, use_tpu): + self._config = config + self._train_batch_size = train_batch_size + self._eval_batch_size = eval_batch_size + self._use_tpu = use_tpu + self._num_shards_or_none = self._config.tpu_config.num_shards + self._mode = None + + def _assert_mode(self): + if self._mode is None: + raise RuntimeError( + '`mode` needs to be set via contextmanager `with_mode`.') + return self._mode + + @property + def num_of_cores_per_host(self): + num_cores = self.num_cores + return min(num_cores, 8) + + @contextmanager + def with_mode(self, mode): + new_ctx = copy.copy(self) # Shallow copy is enough. + new_ctx._mode = mode # pylint: disable=protected-access + yield new_ctx + + @property + def mode(self): + return self._assert_mode() + + @property + def num_cores(self): + # TODO(xiejw): Adds lazy num_shards initialization. + return self._num_shards_or_none + + @property + def num_hosts(self): + return self.num_cores // self.num_of_cores_per_host + + @property + def config(self): + return self._config + + def is_input_sharded_per_core(self): + """Return true if input_fn is invoked per-core (other than per-host).""" + self._assert_mode() + return (self._mode == model_fn_lib.ModeKeys.TRAIN and + not self._config.tpu_config.per_host_input_for_training) + + def is_running_on_cpu(self): + """Determines whether the input_fn and model_fn should be invoked on CPU.""" + mode = self._assert_mode() + return ((not self._use_tpu) or mode == model_fn_lib.ModeKeys.PREDICT or + (mode == model_fn_lib.ModeKeys.EVAL and + self._eval_batch_size is None)) + + @property + def batch_size_for_input_fn(self): + """Returns the shard batch size for `input_fn`.""" + mode = self._assert_mode() + # Special case for eval. + if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None: + return None + if self.is_running_on_cpu(): + if mode == model_fn_lib.ModeKeys.TRAIN: + return self._train_batch_size + if mode == model_fn_lib.ModeKeys.EVAL: + return self._eval_batch_size + return None + + global_batch_size = (self._train_batch_size if + mode == model_fn_lib.ModeKeys.TRAIN + else self._eval_batch_size) + # On TPU + return (global_batch_size // self.num_cores + if self.is_input_sharded_per_core() else global_batch_size) + + @property + def batch_size_for_model_fn(self): + """Returns the shard batch size for `model_fn`.""" + mode = self._assert_mode() + # Special case for eval. + if mode == model_fn_lib.ModeKeys.EVAL and self._eval_batch_size is None: + return None + if self.is_running_on_cpu(): + if mode == model_fn_lib.ModeKeys.TRAIN: + return self._train_batch_size + if mode == model_fn_lib.ModeKeys.EVAL: + return self._eval_batch_size + return None + + # On TPU. always sharded per core. + if mode == model_fn_lib.ModeKeys.TRAIN: + return self._train_batch_size // self.num_cores + else: + return self._eval_batch_size // self.num_cores + + @property + def master_job(self): + """Returns the job name to use to place TPU computations on. + + Returns: + A string containing the job name, or None if no job should be specified. + + Raises: + ValueError: If the user needs to specify a tpu_job_name, because we are + unable to infer the job name automatically, or if the user-specified job + names are inappropriate. + """ + run_config = self._config + # If the user specifies the tpu_job_name, use that. + if run_config.tpu_config.tpu_job_name: + return run_config.tpu_config.tpu_job_name + + # The tpu job is determined by the run_config. Right now, this method is + # required as tpu_config is not part of the RunConfig. + mode = self._assert_mode() + master = (run_config.evaluation_master if mode == model_fn_lib.ModeKeys.EVAL + else run_config.master) + if master in _LOCAL_MASTERS: + return None + + if (not run_config.session_config or + not run_config.session_config.cluster_def.job): + return _DEFAULT_JOB_NAME + cluster_def = run_config.session_config.cluster_def + job_names = set([job.name for job in cluster_def.job]) + if _DEFAULT_JOB_NAME in job_names: + # b/37868888 tracks allowing ClusterSpec propagation to reuse job names. + raise ValueError('Currently, tpu_worker is not an allowed job name.') + if len(job_names) == 1: + return cluster_def.job[0].name + if len(job_names) == 2: + if _DEFAULT_COORDINATOR_JOB_NAME in job_names: + job_names.remove(_DEFAULT_COORDINATOR_JOB_NAME) + return job_names.pop() + # TODO(b/67716447): Include more sophisticated heuristics. + raise ValueError( + 'Could not infer TPU job name. Please specify a tpu_job_name as part ' + 'of your TPUConfig.') + + @property + def tpu_host_placement_function(self): + """Returns the TPU host place function.""" + master = self.master_job + def _placement_function(_sentinal=None, core_id=None, host_id=None): # pylint: disable=invalid-name + assert _sentinal is None + if core_id is not None and host_id is not None: + raise RuntimeError( + 'core_id and host_id can have only one non-None value.') + + if master is None: + return '/replica:0/task:0/device:CPU:0' + else: + # This assumes that if using more than 8 shards, + # the job configuration varies 'task'. + if core_id is not None: + host_id = core_id / 8 + return '/job:%s/task:%d/device:CPU:0' % (master, host_id) + return _placement_function + + @property + def tpu_device_placement_function(self): + master = self.master_job + job_device = '' if master is None else ('/job:%s' % master) + def _placement_function(i): + return '%s/task:%d/device:TPU:%d' % (job_device, i / 8, i % 8) + return _placement_function + @property + def tpu_ordinal_function(self): + """Returns the TPU ordinal fn.""" + def _tpu_ordinal_function(index): + """Return the TPU ordinal associated with a shard. -def _is_running_on_cpu(use_tpu, mode, eval_batch_size): - """Determines whether the input_fn and model_fn should be invoked on CPU.""" - return ((not use_tpu) or mode == model_fn_lib.ModeKeys.PREDICT or - (mode == model_fn_lib.ModeKeys.EVAL and eval_batch_size is None)) + Required because the enqueue ops are placed on CPU. + Args: + index: the shard index -def _per_shard_batch_size(global_batch_size, run_config, use_tpu): - """Returns the batch size for each shard.""" - if use_tpu: - return global_batch_size // run_config.tpu_config.num_shards - else: - return global_batch_size + Returns: + The ordinal of the TPU device the shard's infeed should be placed on. + """ + return index % 8 + return _tpu_ordinal_function class _SIGNAL(object): @@ -268,17 +470,30 @@ class _InfeedThreadController(_InfeedOutfeedThreadBaseController): def _input_thread_fn_for_loading(self, session, enqueue_ops): count = 0 - while True: - signal = self._signal_queue.get() - if signal == _SIGNAL.STOP: - logging.info('Stop Infeed input thread.') - return - - iterations = signal - for i in range(iterations): - logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) - session.run(enqueue_ops) - count += 1 + try: + while True: + signal = self._signal_queue.get() + if signal == _SIGNAL.STOP: + logging.info('Stop Infeed input thread.') + return + + if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + # Enqueue batches for next loop. + session.run(enqueue_ops) + else: + iterations = signal + for i in range(iterations): + logging.debug('Infeed enqueue for iteration (%d, %d)', count, i) + session.run(enqueue_ops) + count += 1 + + except Exception: # pylint: disable=broad-except + logging.error( + 'Failed running infeed, closing session.\n' + 'You may see an exception from your main session after this.', + exc_info=1 + ) + session.close() def join(self): logging.info('Waiting for Infeed Thread to exit.') @@ -294,17 +509,16 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): dequeue. """ - def __init__(self, run_config, mode, enqueue_fn, dequeue_ops=None): - self._tpu_job = _tpu_job(run_config, mode) - self._enqueue_fn = enqueue_fn + def __init__(self, ctx, enqueue_ops, dequeue_ops=None): + self._master_job = ctx.master_job + self._enqueue_ops = enqueue_ops self._dequeue_ops = dequeue_ops def begin(self): - self._enqueue_ops = self._enqueue_fn() - self._iterations_per_loop_var = _create_iterations_per_loop() - logging.info('TPU job name %s', self._tpu_job) - self._init_op = [tpu.initialize_system(job=self._tpu_job)] - self._finalize_op = [tpu.shutdown_system(job=self._tpu_job)] + logging.info('TPU job name %s', self._master_job) + self._iterations_per_loop_var = _create_or_get_iterations_per_loop() + self._init_op = [tpu.initialize_system(job=self._master_job)] + self._finalize_op = [tpu.shutdown_system(job=self._master_job)] def after_create_session(self, session, coord): logging.info('Init TPU system') @@ -326,6 +540,7 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook): iterations = run_context.session.run(self._iterations_per_loop_var) self._infeed_thd_controller.send_next_batch_signal(iterations) if self._dequeue_ops is not None: + # TODO(xiejw): Refactor the outfeed dequeue into tf.while_loop. logging.info('Dequeue next batch of data from outfeed.') self._outfeed_thd_controller.send_next_batch_signal(iterations) @@ -387,7 +602,7 @@ class _TPUStopAtStepHook(session_run_hook.SessionRunHook): if self._global_step_tensor is None: raise RuntimeError('Global step should be created.') - self._iterations_per_loop_var = _create_iterations_per_loop() + self._iterations_per_loop_var = _create_or_get_iterations_per_loop() def after_create_session(self, session, coord): global_step = session.run(self._global_step_tensor) @@ -422,360 +637,288 @@ class _SetEvalIterationsHook(session_run_hook.SessionRunHook): self._num_steps = num_steps def begin(self): - self._iterations_per_loop_var = _create_iterations_per_loop() + self._iterations_per_loop_var = _create_or_get_iterations_per_loop() def after_create_session(self, session, coord): self._iterations_per_loop_var.load(self._num_steps, session=session) -class _PerShardOutput(object): - """Wraps input_fn's outputs into per-shard outputs. - - Used so that the model_fn can distinguish between sharded input and unsharded - inputs (e.g., for export_savedmodel()). - """ - - def __init__(self, output): - self.output = output - - def as_list(self): - return self.output - +def generate_per_core_enqueue_ops_fn_for_host( + ctx, input_fn, inputs_structure_recorder): + """Generates infeed enqueue ops for per-core input_fn on a single host.""" + infeed_queue_holder = {'instance': None} + + def enqueue_ops_fn(): + """A fn returns enqueue_ops.""" + num_cores_per_host = ctx.num_of_cores_per_host + per_host_sharded_inputs = [] + for core_ordinal in range(num_cores_per_host): + with ops.name_scope('ordinal_%d' % (core_ordinal)): + inputs = input_fn() + if isinstance(inputs, tuple): + features, labels = inputs + else: + features, labels = inputs, None -class _InputsHolder(object): - """A inputs holder holds the `features` and `labels' for TPU system. + inputs_structure_recorder.validate_and_record_structure( + features, labels) + flattened_inputs = ( + inputs_structure_recorder.flatten_features_and_labels( + features, labels)) + per_host_sharded_inputs.append(flattened_inputs) - Model inputs returned by the `input_fn` can have one of the following forms: + infeed_queue = tpu_feed.InfeedQueue( + number_of_tuple_elements=len(per_host_sharded_inputs[0])) + infeed_queue_holder['instance'] = infeed_queue + infeed_queue.set_configuration_from_sharded_input_tensors( + per_host_sharded_inputs) + + per_host_enqueue_ops = infeed_queue.generate_enqueue_ops( + per_host_sharded_inputs, + tpu_ordinal_function=ctx.tpu_ordinal_function) + return per_host_enqueue_ops + return enqueue_ops_fn, (lambda: infeed_queue_holder['instance']) + + +class _InputPipeline(object): + """`_InputPipeline` handles invoking `input_fn` and piping to infeed queue. + + `_InputPipeline` abstracts the per-core/per-host `input_fn` invocation from + call site. To be precise, based on the configuration in `_TPUContext`, it + invokes `input_fn` for all cores (usually multi-host TPU training) or for one + host (usually for single-host TPU evaluation), and sends all `features` and + `labels` returned by `input_fn` to TPU infeed. For per-core invocation, + `features` and `labels` are piped to infeed directly, one tuple for each + core. For per-host invocation, `features` and `labels` are split at host + (with respect to `batch_axis`) and piped to all cores accordingly. + + In addition, flatten/unflatten are handled by `_InputPipeline` also. Model + inputs returned by the `input_fn` can have one of the following forms: 1. features 2. (features, labels) Internally, form 1 is reformed to `(features, None)` as features and labels are passed separatedly to underlying methods. For TPU training, TPUEstimator - expects multiple `features` and `labels` tuples one for each shard. - - In addition, TPUEstimator allows various different structures for inputs - (namely `features` and `labels`). `features` can be `Tensor` or dict of - string name to `Tensor`, and `labels` could be `None`, `Tensor`, or dict of - string name to `Tensor`. TPU infeed/outfeed library expects flattened tensor - list. So, `features` and `labels` need to be flattened, before infeed enqueue, - and the structure of them needs to be recorded, in order to restore them after - infeed dequeue. - - `_InputsHolder` could hold the `features` and `labels` tuple for all shards - (usually multi-host TPU training) or for one host (usually for single-host TPU - evaluation), records the structure details (including presence, dict or single - tensor, dict names), validates the structure consistency cross all shards, and - encapsulates the flatten/unflatten logic. + may expect multiple `features` and `labels` tuples one for each core. + + TPUEstimator allows various different structures for inputs (namely `features` + and `labels`). `features` can be `Tensor` or dict of string name to `Tensor`, + and `labels` could be `None`, `Tensor`, or dict of string name to `Tensor`. + TPU infeed/outfeed library expects flattened tensor list. So, `features` and + `labels` need to be flattened, before infeed enqueue, and the structure of + them needs to be recorded, in order to restore them after infeed dequeue. """ - def __init__(self, features=None, labels=None, num_shards=None): - """Constructor. - - Args: - features: features for one host or a list of features one for each shard - (must be type `_PerShardOutput`). Once provided, the corresponding - `labels` should be set also and this `_InputsHolder` is frozen to - prevent from future modification. If `None`, it is expected to add - features and labels for each shard by calling `append_tuple` later. - labels: labels for one host or a list of labels one for each shard - (must be type `_PerShardOutput`). - num_shards: Number of shards in the TPU system. Must be provided unless it - can be deduced from `features`. - - Raises: - ValueError: If both `sharded_features` and `num_shards` are `None`. - """ - # Holds the features and labels for all shards. - self._feature_list = [] - self._label_list = [] - - # Holds the structure of inputs - self._feature_names = [] - self._label_names = [] - self._has_labels = False - - # Internal state. - self._initialized = False - self._frozen = False - self._sharded = False - - if features is None: - if num_shards is None: - raise ValueError( - '`features` and `num_shards` cannot be both None') - self._num_shards = num_shards - elif isinstance(features, _PerShardOutput): - self._from_sharded_inputs(features, labels, num_shards) - else: - if num_shards is None: - raise ValueError( - '`num_shards` cannot be None for unsharded features.') - self._from_unsharded_inputs(features, labels, num_shards) - - def _from_unsharded_inputs(self, features, labels, num_shards): - """Initializes the inputs with unsharded features and labels.""" - self._num_shards = num_shards - if labels is not None: - self._has_labels = True - self.append_tuple((features, labels)) - else: - self.append_tuple(features) - - self._sharded = False - self._frozen = True - - def _from_sharded_inputs(self, sharded_features, sharded_labels, num_shards): - """Initializes the inputs with sharded features and labels.""" - if not isinstance(sharded_features, _PerShardOutput): - raise ValueError('`sharded_features` must have type `_PerShardOutput`.') - features = sharded_features.as_list() - - if num_shards is not None and num_shards != len(features): - raise ValueError( - '`num_shards` should be same as the length of sharded_features.') + class InputsStructureRecorder(object): + """The recorder to record inputs structure.""" + + def __init__(self): + # Holds the structure of inputs + self._feature_names = [] + self._label_names = [] + self._has_labels = False + + # Internal state. + self._initialized = False + + def has_labels(self): + return self._has_labels + + def validate_and_record_structure(self, features, labels): + """Validates and records the structure of features` and `labels`.""" + def _extract_key_names(tensor_or_dict): + if tensor_or_dict is None: + return [] + return tensor_or_dict.keys() if isinstance(tensor_or_dict, dict) else [] + + # Extract structure. + has_labels = labels is not None + feature_names = _extract_key_names(features) + label_names = _extract_key_names(labels) + + if self._initialized: + # Verify the structure is same. The following should never happen. + assert feature_names == self._feature_names, 'feature keys mismatched' + assert label_names == self._label_names, 'label keys mismatched' + assert has_labels == self._has_labels, 'label presence mismatched' + else: + # Record structure. + self._initialized = True + self._feature_names = feature_names + self._label_names = label_names + self._has_labels = has_labels + + def flatten_features_and_labels(self, features, labels): + """Flattens the `features` and `labels` to a single tensor list.""" + flattened_inputs = [] + if self._feature_names: + # We need a fixed ordering for enqueueing and dequeueing. + flattened_inputs.extend([features[name] + for name in self._feature_names]) + else: + flattened_inputs.append(features) - self._num_shards = len(features) - if not self._num_shards: - raise ValueError('`sharded_features` should not be empty.') + if labels is not None: + if self._label_names: + # We need a fixed ordering for enqueueing and dequeueing. + flattened_inputs.extend([labels[name] for name in self._label_names]) + else: + flattened_inputs.append(labels) + return flattened_inputs + + def unflatten_features_and_labels(self, flattened_inputs): + """Restores the flattened inputs to original features and labels form. + + Args: + flattened_inputs: Flattened inputs for each shard. + + Returns: + A tuple of (`features`, `labels`), where `labels` could be None. + Each one, if present, should have identical structure (single tensor vs + dict) as the one returned by input_fn. + + Raises: + ValueError: If the number of expected tensors from `flattened_inputs` + mismatches the recorded structure. + """ + expected_num_features = (len(self._feature_names) if self._feature_names + else 1) + if self._has_labels: + expected_num_labels = (len(self._label_names) if self._label_names + else 1) + else: + expected_num_labels = 0 - if sharded_labels is not None: - if not isinstance(sharded_labels, _PerShardOutput): - raise ValueError('sharded_labels` must have type `_PerShardOutput`.') + expected_num_tensors = expected_num_features + expected_num_labels - self._has_labels = True - labels = sharded_labels.as_list() - if self._num_shards != len(labels): + if expected_num_tensors != len(flattened_inputs): raise ValueError( - 'Length of `sharded_features` and `sharded_labels` mismatch.') - - if self._has_labels: - for (f, l) in zip(features, labels): - self.append_tuple((f, l)) - else: - for f in features: - self.append_tuple(f) - - self._sharded = True - self._frozen = True - - def _extract_key_names(self, tensor_or_dict): - if tensor_or_dict is None: - return [] - - return tensor_or_dict.keys() if isinstance(tensor_or_dict, dict) else [] - - def _validate(self, features, labels): - has_labels = labels is not None - feature_names = self._extract_key_names(features) - label_names = self._extract_key_names(labels) - - if self._initialized: - self._sharded = True - # The following should never happen. - assert feature_names == self._feature_names, 'feature keys mismatched' - assert label_names == self._label_names, 'label keys mismatched' - assert has_labels == self._has_labels, 'label presence mismatched' - else: - self._initialized = True - self._feature_names = feature_names - self._label_names = label_names - self._has_labels = has_labels - - @property - def sharded(self): - if not self._frozen: - raise RuntimeError('_InputsHolder has not been frozen yet.') - return self._sharded - - @property - def num_shards(self): - if not self._frozen: - raise RuntimeError('_InputsHolder has not been frozen yet.') - return self._num_shards - - def append_tuple(self, inputs): - """Appends `inputs` for one shard into holder. - - Args: - inputs: The return from `input_fn`, which could be features or tuple of - (features, labels). After the first `inputs` appended into - `_InputsHolder`, the structure of `features` and `labels is recorded. - Any future invocation should provide the `inputs` with same structure. - - Raises: - RuntimeError: If the internal data has been frozen already. - """ - if self._frozen: - raise RuntimeError('InputsHolder has frozen, which cannot be mutated.') - - # input_fn may return either features or (features, labels) - if isinstance(inputs, tuple): - features, labels = inputs - else: - features, labels = inputs, None - - self._validate(features, labels) - - self._feature_list.append(features) - if labels is not None: - self._label_list.append(labels) - - def as_features_and_labels_tuple(self): - """Returns features and labels as grouped tuple. - - This is intended to be used to pass features and labels for all shards from - input_fn to model_fn as the parent class `Estimator` does not have the - concept of shards. So, grouped tuple is required. - - Once called, the internal data is frozen and `append_tuple` cannot be - invoked anymore. - - Returns: - A tuple of features and labels. Both have type `_PerShardOutput`, holding - the inputs for all shards. `labels` could be `None`. - - Raises: - RuntimeError: If the internal data has not been initialized. - """ - self._frozen = True - if not self._initialized: - raise RuntimeError('InputsHolder has not been initialized.') - - assert len(self._feature_list) == self._num_shards - if not self._label_list or all(l is None for l in self._label_list): - return _PerShardOutput(self._feature_list), None - - assert len(self._label_list) == self._num_shards - return (_PerShardOutput(self._feature_list), - _PerShardOutput(self._label_list)) - - def as_sharded_flattened_inputs(self): - """Flatten the features and label as tensor lists for all shards. - - Flattened tensor list contains all tensors in `features` (dict) and `labels` - (dict). Conceptually, it has the predicated structure like: - - ```python - flatten_list = [] - for name in features: - flatten_list.append(features[name]) - for name in labels: - flatten_list.append(labels[name]) - ``` - - This method handles the label is None case and single tensor case nicely. - - Once called, the internal data is frozen and `append_tuple` cannot be - invokded anymore. - - Returns: - A list of flattened inputs one for each shard. - - Raises: - RuntimeError: If the internal data has not been initialized. - ValueError: If the inputs are sharded. - """ - self._frozen = True - if not self._initialized: - raise RuntimeError('InputsHolder has not been initialized.') - if not self._sharded: - raise ValueError('Inputs are not sharded.') - - sharded_inputs = [] - - for shard in range(self._num_shards): - flattened_inputs = self._as_flattened_inputs( - self._feature_list[shard], - self._label_list[shard] if self._has_labels else None) - sharded_inputs.append(flattened_inputs) - - return sharded_inputs - - def as_flattened_inputs(self): - """Flatten the features and label as a single tensor list for one host.""" - self._frozen = True - if not self._initialized: - raise RuntimeError('InputsHolder has not been initialized.') - if self._sharded: - raise ValueError('Inputs are sharded.') - - return self._as_flattened_inputs( - self._feature_list[0], - self._label_list[0] if self._has_labels else None) - - def _as_flattened_inputs(self, features, labels): - """Flattens the `features` and `labels` to a single tensor list.""" - flattened_inputs = [] - if self._feature_names: - # We need a fixed ordering for enqueueing and dequeueing. - flattened_inputs.extend([features[name] for name in self._feature_names]) - else: - flattened_inputs.append(features) - - if labels is not None: - if self._label_names: - # We need a fixed ordering for enqueueing and dequeueing. - flattened_inputs.extend([labels[name] for name in self._label_names]) + 'The number of flattened tensors mismatches expected num. ' + 'Expected {}, got {}'.format(expected_num_tensors, + len(flattened_inputs))) + if self._feature_names: + unflattened_features = dict( + zip(self._feature_names, flattened_inputs[:expected_num_features])) else: - flattened_inputs.append(labels) - return flattened_inputs + # Single tensor case + unflattened_features = flattened_inputs[0] + + if expected_num_labels == 0: + unflattened_label = None + elif self._label_names: + unflattened_label = dict(zip(self._label_names, + flattened_inputs[expected_num_features:])) + else: + # Single tensor case. + unflattened_label = flattened_inputs[expected_num_features] - def unflatten_features_and_labels(self, flattened_inputs): - """Restores the flattened inputs to original features and labels form. + return unflattened_features, unflattened_label - Once called, the internal data is frozen and `append_tuple` cannot be - invokded anymore. + def __init__(self, input_fn, batch_axis, ctx): + """Constructor. Args: - flattened_inputs: Flattened inputs for one each, which should be created - by the `as_sharded_flattened_inputs` API. - - Returns: - A tuple of (`features`, `labels`), where `labels` could be None. - Each one, if present, should have identical structure (single tensor vs - dict) as the one returned by input_fn. + input_fn: input fn for train or eval. + batch_axis: A python tuple of int values describing how each tensor + produced by the Estimator `input_fn` should be split across the TPU + compute shards. + ctx: A `_TPUContext` instance with mode. Raises: - RuntimeError: If the internal data has not been initialized. - ValueError: If the number of expected tensors from `flattened_inputs` - mismatches the recorded structure. + ValueError: If both `sharded_features` and `num_cores` are `None`. """ - self._frozen = True - if not self._initialized: - raise RuntimeError('InputsHolder has not been initialized.') - - expected_num_features = (len(self._feature_names) if self._feature_names - else 1) - if self._has_labels: - expected_num_labels = (len(self._label_names) if self._label_names - else 1) - else: - expected_num_labels = 0 + self._inputs_structure_recorder = _InputPipeline.InputsStructureRecorder() + + self._sharded_per_core = ctx.is_input_sharded_per_core() + self._input_fn = input_fn + self._infeed_queue = None + self._ctx = ctx + self._batch_axis = batch_axis + + def generate_infeed_enqueue_ops_and_dequeue_fn(self): + """Generates infeed enqueue ops and dequeue_fn.""" + # While tf.while_loop is called, the body function, which invokes + # `enqueue_fn` passed in, is called to construct the graph. So, input_fn + # structure is recorded. + enqueue_ops = self._invoke_input_fn_and_record_structure() + + def dequeue_fn(): + """dequeue_fn is used by TPU to retrieve the tensors.""" + values = self._infeed_queue.generate_dequeue_op() + # The unflatten process uses the structure information recorded above. + return self._inputs_structure_recorder.unflatten_features_and_labels( + values) + + return (enqueue_ops, dequeue_fn) + + def _invoke_input_fn_and_record_structure(self): + if self._sharded_per_core: + # Per-Core input pipeline deployment. + tpu_host_placement_fn = self._ctx.tpu_host_placement_function + enqueue_ops = [] + infeed_queues = [] + + # Invoke input pipeline for each core and placed on the corresponding + # host. + num_hosts = self._ctx.num_hosts + for host_id in range(num_hosts): + host_device = tpu_host_placement_fn(host_id=host_id) + with ops.device(host_device): + with ops.name_scope('input_pipeline_task%d' % (host_id)): + enqueue_ops_fn, infeed_queue_getter = ( + generate_per_core_enqueue_ops_fn_for_host( + self._ctx, self._input_fn, self._inputs_structure_recorder)) + + if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + enqueue_ops.append(_wrap_computation_in_while_loop( + device=host_device, op_fn=enqueue_ops_fn)) + else: + enqueue_ops.append(enqueue_ops_fn()) + # Infeed_queue_getter must be called after enqueue_ops_fn is called. + infeed_queues.append(infeed_queue_getter()) + + # infeed_queue is used to generate dequeue ops. The only thing it uses for + # dequeue is dtypes and types. So, any one can be used. Here, grab the + # first one. + self._infeed_queue = infeed_queues[0] + return enqueue_ops - expected_num_tensors = expected_num_features + expected_num_labels - - if expected_num_tensors != len(flattened_inputs): - raise ValueError( - 'The number of flattened tensors mismatches expected num. ' - 'Expected {}, got {}'.format(expected_num_tensors, - len(flattened_inputs))) - if self._feature_names: - unflattened_features = dict(zip(self._feature_names, - flattened_inputs[:expected_num_features])) else: - # Single tensor case - unflattened_features = flattened_inputs[0] - - if expected_num_labels == 0: - unflattened_label = None - elif self._label_names: - unflattened_label = dict(zip(self._label_names, - flattened_inputs[expected_num_features:])) - else: - # Single tensor case. - unflattened_label = flattened_inputs[expected_num_features] - - return unflattened_features, unflattened_label + # TODO(b/67051042): Extend this to multi-host support. + host_id = 0 + host_device = self._ctx.tpu_host_placement_function(host_id=host_id) + def enqueue_fn(): + with ops.device(host_device): + with ops.name_scope('input_pipeline_task%d' % (host_id)): + inputs = self._input_fn() + if isinstance(inputs, tuple): + features, labels = inputs + else: + features, labels = inputs, None + self._inputs_structure_recorder.validate_and_record_structure( + features, labels) + unsharded_tensor_list = ( + self._inputs_structure_recorder.flatten_features_and_labels( + features, labels)) + + self._infeed_queue = tpu_feed.InfeedQueue( + tuple_types=[t.dtype for t in unsharded_tensor_list], + tuple_shapes=[t.shape for t in unsharded_tensor_list], + shard_dimensions=self._batch_axis) + self._infeed_queue.set_number_of_shards(self._ctx.num_cores) + + def placement_fn(core_id): + return self._ctx.tpu_host_placement_function(core_id=core_id) + return ( + self._infeed_queue.split_inputs_and_generate_enqueue_ops( + unsharded_tensor_list, + placement_function=placement_fn)) + + if _WRAP_INPUT_FN_INTO_WHILE_LOOP: + return _wrap_computation_in_while_loop(device=host_device, + op_fn=enqueue_fn) + else: + return enqueue_fn() class _ModelFnWrapper(object): @@ -788,20 +931,17 @@ class _ModelFnWrapper(object): train and eval step. """ - def __init__(self, model_fn, config, params, mode, train_batch_size, - eval_batch_size): + def __init__(self, model_fn, config, params, ctx): self._model_fn = model_fn self._config = config self._params = params - self._mode = mode - self._train_batch_size = train_batch_size - self._eval_batch_size = eval_batch_size + self._ctx = ctx def call_without_tpu(self, features, labels): # Let CrossShardOptimizer be called without TPU in model_fn, since it's # common to set the train_op even when running evaluate() or predict(). with tpu_function.tpu_shard_context(1): - return self._call_model_fn(features, labels, use_tpu=False) + return self._call_model_fn(features, labels) def convert_to_single_tpu_train_step(self, dequeue_fn): """Converts user provided model_fn` as a single train step on TPU. @@ -831,7 +971,7 @@ class _ModelFnWrapper(object): features, labels = dequeue_fn() estimator_spec = self._verify_estimator_spec( - self._call_model_fn(features, labels, use_tpu=True)) + self._call_model_fn(features, labels)) loss, train_op = estimator_spec.loss, estimator_spec.train_op with ops.control_dependencies([train_op]): return array_ops.identity(loss) @@ -863,13 +1003,13 @@ class _ModelFnWrapper(object): A tuple of eval_fn and eval_metrics. The eval_fn representing the eval step for TPU. and eval_metrics is an `_EvalMetrics` instance. """ - eval_metrics = _EvalMetrics() + eval_metrics = _EvalMetrics(self._ctx) def eval_step(total_loss): """Evaluation step function for use inside a while loop.""" features, labels = dequeue_fn() - tpu_estimator_spec = self._call_model_fn(features, labels, use_tpu=True) + tpu_estimator_spec = self._call_model_fn(features, labels) if not isinstance(tpu_estimator_spec, TPUEstimatorSpec): raise RuntimeError( 'estimator_spec used by TPU evaluation must have type' @@ -883,11 +1023,7 @@ class _ModelFnWrapper(object): return math_ops.add(total_loss, loss) return eval_step, eval_metrics - @property - def config(self): - return self._config - - def _call_model_fn(self, features, labels, use_tpu): + def _call_model_fn(self, features, labels): """Calls the model_fn with required parameters.""" model_fn_args = util.fn_args(self._model_fn) kwargs = {} @@ -898,12 +1034,11 @@ class _ModelFnWrapper(object): if 'labels' in model_fn_args: kwargs['labels'] = labels - else: - if labels is not None: - raise ValueError( - 'model_fn does not take labels, but input_fn returns labels.') + elif labels is not None: + raise ValueError( + 'model_fn does not take labels, but input_fn returns labels.') if 'mode' in model_fn_args: - kwargs['mode'] = self._mode + kwargs['mode'] = self._ctx.mode if 'config' in model_fn_args: kwargs['config'] = config if 'params' in model_fn_args: @@ -914,16 +1049,16 @@ class _ModelFnWrapper(object): 'model_fn ({}) does not include params argument, ' 'required by TPUEstimator to pass batch size as ' 'params[\'batch_size\']'.format(self._model_fn)) - if self._mode == model_fn_lib.ModeKeys.TRAIN: - params[_BATCH_SIZE_KEY] = _per_shard_batch_size( - self._train_batch_size, config, use_tpu) - elif (self._mode == model_fn_lib.ModeKeys.EVAL and - self._eval_batch_size is not None): - params[_BATCH_SIZE_KEY] = _per_shard_batch_size( - self._eval_batch_size, config, use_tpu) + + batch_size_for_model_fn = self._ctx.batch_size_for_model_fn + if batch_size_for_model_fn is not None: + params[_BATCH_SIZE_KEY] = batch_size_for_model_fn estimator_spec = self._model_fn(features=features, **kwargs) - if (not use_tpu) and isinstance(estimator_spec, TPUEstimatorSpec): + if (self._ctx.is_running_on_cpu() and + isinstance(estimator_spec, TPUEstimatorSpec)): + # The estimator_spec will be passed to `Estimator` directly, which expects + # type `EstimatorSpec`. return estimator_spec.as_estimator_spec() else: return estimator_spec @@ -946,7 +1081,8 @@ class _ModelFnWrapper(object): class _EvalMetrics(object): """Class wraps TPUEstimator.eval_metrics.""" - def __init__(self): + def __init__(self, ctx): + self._ctx = ctx self._metric_fn = None self._is_dict = False self._tensor_keys = [] @@ -970,8 +1106,6 @@ class _EvalMetrics(object): if isinstance(eval_metrics[1], (tuple, list)): fn_args = util.fn_args(eval_metrics[0]) - if 'self' in fn_args: - fn_args = tuple([arg for arg in fn_args if arg != 'self']) if len(eval_metrics[1]) != len(fn_args): raise RuntimeError( 'In TPUEstimatorSpec.eval_metrics, length of tensors does not ' @@ -1029,7 +1163,7 @@ class _EvalMetrics(object): raise RuntimeError('Eval metrics have not been recorded yet') return self._tensors - def to_metric_metric_ops_for_tpu(self, run_config, dummy_update_op): + def to_metric_metric_ops_for_tpu(self, dummy_update_op): """Creates the eval_metric_ops now based on the TPU outfeed. `eval_metric_ops` is defined in `EstimatorSpec`. From all shards, tensors @@ -1038,7 +1172,6 @@ class _EvalMetrics(object): metric fn. Args: - run_config: A `RunConfig` instance. dummy_update_op: A dummy update op. Returns: @@ -1050,9 +1183,7 @@ class _EvalMetrics(object): RuntimeError: If outfeed tensor is scalar. """ - num_shards = run_config.tpu_config.num_shards - job = _tpu_job(run_config, model_fn_lib.ModeKeys.EVAL) - job_device = '' if job is None else ('/job:%s' % job) + num_cores = self._ctx.num_cores # For each i, dequeue_ops[i] is a list containing the tensors from all # shards. This list is concatenated later. @@ -1061,8 +1192,9 @@ class _EvalMetrics(object): dequeue_ops.append([]) # Outfeed ops execute on each JF node. - for i in xrange(num_shards): - with ops.device('%s/task:%d/device:TPU:%d' % (job_device, i / 8, i % 8)): + tpu_device_placement_fn = self._ctx.tpu_device_placement_function + for i in xrange(num_cores): + with ops.device(tpu_device_placement_fn(i)): outfeed_tensors = tpu_ops.outfeed_dequeue_tuple( dtypes=self._tensor_dtypes, shapes=self._tensor_shapes) for j, item in enumerate(outfeed_tensors): @@ -1070,7 +1202,7 @@ class _EvalMetrics(object): # It is assumed evaluation always happends on single host TPU system. So, # place all ops on tpu host if possible. - with ops.device('{}/device:CPU:0'.format(job_device)): + with ops.device(self._ctx.tpu_host_placement_function(core_id=0)): for i, item in enumerate(dequeue_ops): if dequeue_ops[i][0].shape.ndims == 0: raise RuntimeError( @@ -1115,9 +1247,9 @@ class TPUEstimator(estimator_lib.Estimator): specify `train_batch_size` in constructor, and then get the batch size for each shard in `input_fn` and `model_fn` by `params['batch_size']`. If `TPUConfig.per_host_input_for_training` is `True`, `input_fn` is invoked per - host rather than per shard. In this case, a global batch size is transformed a + host rather than per core. In this case, a global batch size is transformed a per-host batch size in params for `input_fn`, but `model_fn` still gets - per-shard batch size. + per-core batch size. For evaluation, if `eval_batch_size` is None, it is executed on CPU, even if `use_tpu` is `True`. If `eval_batch_size` is not `None`, it is executed on @@ -1264,14 +1396,18 @@ class TPUEstimator(estimator_lib.Estimator): 'eval batch size {} must be divisible by number of shards {}' .format(eval_batch_size, config.tpu_config.num_shards)) + if (config.tpu_config.num_shards > 8 and + config.tpu_config.per_host_input_for_training): + # TODO(b/67051042): Support per_host input pipelines when num_shards > 8 + raise NotImplementedError( + 'Per-host input pipelines only available for num_shards <= 8') + # Verifies the model_fn signature according to Estimator framework. estimator_lib._verify_model_fn_args(model_fn, params) # pylint: disable=protected-access # We cannot store config and params in this constructor as parent # constructor might change them, such as assigning a temp dir for # config.model_dir. - model_function = _augment_model_fn(model_fn, train_batch_size, - eval_batch_size, use_tpu, - batch_axis) + model_function = self._augment_model_fn(model_fn, batch_axis) # Passing non-None params as wrapped model_fn has it. params = params or {} @@ -1280,12 +1416,13 @@ class TPUEstimator(estimator_lib.Estimator): model_dir=model_dir, config=config, params=params) - self._use_tpu = use_tpu - self._train_batch_size = train_batch_size - self._eval_batch_size = eval_batch_size self._iterations_per_training_loop = ( self._config.tpu_config.iterations_per_loop) + # All properties passed to _TPUContext are immutable. + self._ctx = _TPUContext(self._config, train_batch_size, eval_batch_size, + use_tpu) + def _create_global_step(self, graph): """Creates a global step suitable for TPUs. @@ -1301,10 +1438,10 @@ class TPUEstimator(estimator_lib.Estimator): return _create_global_step(graph) def _convert_train_steps_to_hooks(self, steps, max_steps): - if _is_running_on_cpu(self._use_tpu, model_fn_lib.ModeKeys.TRAIN, - self._eval_batch_size): - return super(TPUEstimator, self)._convert_train_steps_to_hooks( - steps, max_steps) + with self._ctx.with_mode(model_fn_lib.ModeKeys.TRAIN) as ctx: + if ctx.is_running_on_cpu(): + return super(TPUEstimator, self)._convert_train_steps_to_hooks( + steps, max_steps) # On TPU. if steps is None and max_steps is None: @@ -1312,18 +1449,24 @@ class TPUEstimator(estimator_lib.Estimator): 'For TPU training, one of `steps` or `max_steps` must be set. ' 'Cannot be both `None`.') + # Estimator.train has explicit positiveness check. + if steps is not None: + util_lib.check_positive_integer(steps, 'Train steps') + if max_steps is not None: + util_lib.check_positive_integer(max_steps, 'Train max_steps') + return [_TPUStopAtStepHook(self._iterations_per_training_loop, steps, max_steps)] def _convert_eval_steps_to_hooks(self, steps): - if _is_running_on_cpu(self._use_tpu, model_fn_lib.ModeKeys.EVAL, - self._eval_batch_size): - return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps) + with self._ctx.with_mode(model_fn_lib.ModeKeys.EVAL) as ctx: + if ctx.is_running_on_cpu(): + return super(TPUEstimator, self)._convert_eval_steps_to_hooks(steps) if steps is None: raise ValueError('Evaluate `steps` must be set on TPU. Cannot be `None`.') - if steps <= 0: - raise ValueError('Must specify steps > 0, given: {}'.format(steps)) + + util_lib.check_positive_integer(steps, 'Eval steps') hooks = [] hooks.append(evaluation._StopAfterNEvalsHook( # pylint: disable=protected-access @@ -1358,197 +1501,115 @@ class TPUEstimator(estimator_lib.Estimator): if 'config' in input_fn_args: kwargs['config'] = config - # Setting the batch size in params first. This helps user to have same - # input_fn for use_tpu=True/False. - if mode == model_fn_lib.ModeKeys.TRAIN: - kwargs['params'][_BATCH_SIZE_KEY] = ( - _per_shard_batch_size(self._train_batch_size, config, self._use_tpu) - if not config.tpu_config.per_host_input_for_training else - self._train_batch_size) - elif (mode == model_fn_lib.ModeKeys.EVAL and - self._eval_batch_size is not None): - # For TPU evaluation, input_fn is invoked for one host (instead of shard). - kwargs['params'][_BATCH_SIZE_KEY] = self._eval_batch_size - - if _is_running_on_cpu(self._use_tpu, mode, self._eval_batch_size): - with ops.device('/device:CPU:0'): - return input_fn(**kwargs) - - job = _tpu_job(config, mode) - def placement_function(index): - if job is None: - return '/replica:0/task:0/device:CPU:0' - else: - return '/job:%s/task:%d/device:CPU:0' % (job, index / 8) + with self._ctx.with_mode(mode) as ctx: + # Setting the batch size in params first. This helps user to have same + # input_fn for use_tpu=True/False. + batch_size_for_input_fn = ctx.batch_size_for_input_fn + if batch_size_for_input_fn is not None: + kwargs['params'][_BATCH_SIZE_KEY] = batch_size_for_input_fn - if mode == model_fn_lib.ModeKeys.TRAIN: - if not config.tpu_config.per_host_input_for_training: - # Now for TPU training. - num_shards = config.tpu_config.num_shards - inputs = _InputsHolder(num_shards=num_shards) - for i in range(config.tpu_config.num_shards): - with ops.device(placement_function(i)): - inputs.append_tuple(input_fn(**kwargs)) - return inputs.as_features_and_labels_tuple() - else: - # TODO(xiejw): Extend this to multi-host support. - with ops.device(placement_function(0)): + if ctx.is_running_on_cpu(): + with ops.device('/device:CPU:0'): return input_fn(**kwargs) - # Now for TPU evaluation. - with ops.device(placement_function(0)): - return input_fn(**kwargs) - - -# TODO(b/64607814): Ensure batch_axis works with nested structures. -def _create_infeed_enqueue_ops_and_dequeue_fn(inputs_holder, run_config, - batch_axis, mode): - """Utility to convert input_fn to enqueue and dequeue fns for TPU. - - Args: - inputs_holder: An `_InputsHolder` holding features and labels. - run_config: A `RunConfig` instance. - batch_axis: A python list of batch dimensions. - mode: ModeKeys - - Returns: - A tuple of (dequeue_fn, enqueue_fn) - """ - if inputs_holder.sharded: - sharded_inputs = inputs_holder.as_sharded_flattened_inputs() - - infeed_queue = tpu_feed.InfeedQueue( - number_of_tuple_elements=len(sharded_inputs[0])) - infeed_queue.set_configuration_from_sharded_input_tensors(sharded_inputs) - else: - unsharded_inputs = inputs_holder.as_flattened_inputs() - infeed_queue = tpu_feed.InfeedQueue( - tuple_types=[t.dtype for t in unsharded_inputs], - tuple_shapes=[t.shape for t in unsharded_inputs], - shard_dimensions=batch_axis) - infeed_queue.set_number_of_shards(inputs_holder.num_shards) - - def dequeue_fn(): - """dequeue_fn is used by the train_step in TPU to retrieve the tensors.""" - values = infeed_queue.generate_dequeue_op() - return inputs_holder.unflatten_features_and_labels(values) - - def tpu_ordinal_function(index): - """Return the TPU ordinal associated with a shard. - - Required because the enqueue ops are placed on CPU. - - Args: - index: the shard index - - Returns: - The ordinal of the TPU device the shard's infeed should be placed on. - """ - return index % 8 - - def enqueue_fn(): - """enqueue_fn is used to add ops to the graph to send tensors.""" - if inputs_holder.sharded: - return infeed_queue.generate_enqueue_ops( - sharded_inputs, tpu_ordinal_function=tpu_ordinal_function) - else: - job = _tpu_job(run_config, mode) - def placement_function(index): - if job is None: - return '/replica:0/task:0/device:CPU:0' - else: - # This assumes that if using more than 8 shards, - # the job configuration varies 'task'. - return '/job:%s/task:%d/device:CPU:0' % (job, index / 8) - return infeed_queue.split_inputs_and_generate_enqueue_ops( - unsharded_inputs, placement_function=placement_function) - - return (dequeue_fn, enqueue_fn) - - -def _augment_model_fn(model_fn, train_batch_size, eval_batch_size, use_tpu, - batch_axis): - """Returns a new model_fn, which wraps the TPU support.""" - - def _model_fn(features, labels, mode, config, params): - """A Estimator `model_fn` for TPUEstimator.""" - model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, mode, - train_batch_size, eval_batch_size) - - # TODO(jhseu): Move to PREDICT to TPU. - if _is_running_on_cpu(use_tpu, mode, eval_batch_size): - logging.info('Running %s on CPU', mode) - return model_fn_wrapper.call_without_tpu(features, labels) - - inputs = _InputsHolder(features=features, labels=labels, - num_shards=config.tpu_config.num_shards) - - dequeue_fn, enqueue_fn = _create_infeed_enqueue_ops_and_dequeue_fn( - inputs, config, batch_axis, mode) - - if mode == model_fn_lib.ModeKeys.TRAIN: - loss = _train_on_tpu_system(model_fn_wrapper, dequeue_fn) - hooks = [ - TPUInfeedOutfeedSessionHook(config, mode, enqueue_fn), - training.LoggingTensorHook( - {'loss': array_ops.identity(loss), - 'step': training.get_global_step()}, - every_n_secs=30) - ] - summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) - with ops.control_dependencies([loss]): - update_ops = _sync_variables_ops() - - # Validate the TPU training graph to catch basic errors - _validate_tpu_training_graph() - - return model_fn_lib.EstimatorSpec( - mode, - loss=loss, - training_hooks=hooks, - train_op=control_flow_ops.group(*update_ops)) - - # Now eval. - total_loss, eval_metric_ops = _eval_on_tpu_system( - model_fn_wrapper, dequeue_fn) - iterations_per_loop_var = _create_iterations_per_loop() - mean_loss = math_ops.div( - total_loss, - math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype)) - - # Creates a dummy metric update_op for all metrics. Estimator expects all - # metrics in eval_metric_ops have update_op and calls them one by one. The - # real metric update_ops are invoked in a separated thread. So, here give - # Estimator the dummy op for all metrics. - with ops.control_dependencies([mean_loss]): - # After TPU evaluation computation is done (the mean_loss tensor), reads - # all variables back from TPU and updates the eval step counter properly. - internal_ops_to_run = _sync_variables_ops() - internal_ops_to_run.append( - _increase_eval_step_op(iterations_per_loop_var)) - with ops.control_dependencies(internal_ops_to_run): - dummy_update_op = control_flow_ops.no_op() - - eval_metric_ops, eval_update_ops = ( - eval_metric_ops.to_metric_metric_ops_for_tpu( - config, dummy_update_op)) - hooks = [ - TPUInfeedOutfeedSessionHook(config, mode, enqueue_fn, eval_update_ops), - ] - - return model_fn_lib.EstimatorSpec( - mode, - loss=mean_loss, - evaluation_hooks=hooks, - eval_metric_ops=eval_metric_ops) - return _model_fn - - -def _eval_on_tpu_system(model_fn_wrapper, dequeue_fn): + # For TPU computation, input_fn should be invoked in a tf.while_loop for + # performance. While constructing the tf.while_loop, the structure of + # inputs returned by the `input_fn` needs to be recorded. The structure + # includes whether features or labels is dict or single Tensor, dict keys, + # tensor shapes, and dtypes. The recorded structure is used to create the + # infeed dequeue ops, which must be wrapped and passed as a Fn, called + # inside the TPU computation, as the TPU computation is wrapped inside a + # tf.while_loop also. So, we either pass input_fn to model_fn or pass + # dequeue_fn to model_fn. Here, `input_fn` is passed directly as + # `features` in `model_fn` signature. + def _input_fn(): + return input_fn(**kwargs) + return _input_fn + + def _augment_model_fn(self, model_fn, batch_axis): + """Returns a new model_fn, which wraps the TPU support.""" + + def _model_fn(features, labels, mode, config, params): + """A Estimator `model_fn` for TPUEstimator.""" + with self._ctx.with_mode(mode) as ctx: + model_fn_wrapper = _ModelFnWrapper(model_fn, config, params, ctx) + + # TODO(jhseu): Move to PREDICT to TPU. + if ctx.is_running_on_cpu(): + logging.info('Running %s on CPU', mode) + return model_fn_wrapper.call_without_tpu(features, labels) + + assert labels is None, '`labels` passed to `model_fn` must be `None`.' + # TPUEstimator._call_input_fn passes `input_fn` as features to here. + assert callable(features), '`input_fn` is not callable.' + input_fn = features + + input_holders = _InputPipeline(input_fn, batch_axis, ctx) + enqueue_ops, dequeue_fn = ( + input_holders.generate_infeed_enqueue_ops_and_dequeue_fn()) + + if mode == model_fn_lib.ModeKeys.TRAIN: + loss = _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn) + hooks = [ + TPUInfeedOutfeedSessionHook(ctx, enqueue_ops), + training.LoggingTensorHook( + {'loss': array_ops.identity(loss), + 'step': training.get_global_step()}, + every_n_secs=30) + ] + summary.scalar(model_fn_lib.LOSS_METRIC_KEY, loss) + with ops.control_dependencies([loss]): + update_ops = _sync_variables_ops() + + # Validate the TPU training graph to catch basic errors + _validate_tpu_training_graph() + + return model_fn_lib.EstimatorSpec( + mode, + loss=loss, + training_hooks=hooks, + train_op=control_flow_ops.group(*update_ops)) + + # Now eval. + total_loss, eval_metric_ops = _eval_on_tpu_system( + ctx, model_fn_wrapper, dequeue_fn) + iterations_per_loop_var = _create_or_get_iterations_per_loop() + mean_loss = math_ops.div( + total_loss, + math_ops.cast(iterations_per_loop_var, dtype=total_loss.dtype)) + + # Creates a dummy metric update_op for all metrics. Estimator expects + # all metrics in eval_metric_ops have update_op and calls them one by + # one. The real metric update_ops are invoked in a separated thread. So, + # here give Estimator the dummy op for all metrics. + with ops.control_dependencies([mean_loss]): + # After TPU evaluation computation is done (the mean_loss tensor), + # reads all variables back from TPU and updates the eval step counter + # properly + internal_ops_to_run = _sync_variables_ops() + internal_ops_to_run.append( + _increase_eval_step_op(iterations_per_loop_var)) + with ops.control_dependencies(internal_ops_to_run): + dummy_update_op = control_flow_ops.no_op() + + eval_metric_ops, eval_update_ops = ( + eval_metric_ops.to_metric_metric_ops_for_tpu(dummy_update_op)) + hooks = [ + TPUInfeedOutfeedSessionHook(ctx, enqueue_ops, eval_update_ops), + ] + + return model_fn_lib.EstimatorSpec( + mode, + loss=mean_loss, + evaluation_hooks=hooks, + eval_metric_ops=eval_metric_ops) + return _model_fn + + +def _eval_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - config = model_fn_wrapper.config.tpu_config - num_shards = config.num_shards - iterations_per_loop_var = _create_iterations_per_loop() + num_cores = ctx.num_cores + iterations_per_loop_var = _create_or_get_iterations_per_loop() single_tpu_eval_step, eval_metric_ops = ( model_fn_wrapper.convert_to_single_tpu_eval_step(dequeue_fn)) @@ -1561,15 +1622,15 @@ def _eval_on_tpu_system(model_fn_wrapper, dequeue_fn): (loss,) = tpu.shard(multi_tpu_eval_steps_on_single_shard, inputs=[], - num_shards=num_shards, + num_shards=num_cores, outputs_from_all_shards=False) return loss, eval_metric_ops -def _train_on_tpu_system(model_fn_wrapper, dequeue_fn): +def _train_on_tpu_system(ctx, model_fn_wrapper, dequeue_fn): """Executes `model_fn_wrapper` multiple times on all TPU shards.""" - num_shards = model_fn_wrapper.config.tpu_config.num_shards - iterations_per_loop_var = _create_iterations_per_loop() + num_cores = ctx.num_cores + iterations_per_loop_var = _create_or_get_iterations_per_loop() single_tpu_train_step = model_fn_wrapper.convert_to_single_tpu_train_step( dequeue_fn) @@ -1583,11 +1644,27 @@ def _train_on_tpu_system(model_fn_wrapper, dequeue_fn): (loss,) = tpu.shard(multi_tpu_train_steps_on_single_shard, inputs=[], - num_shards=num_shards, + num_shards=num_cores, outputs_from_all_shards=False) return loss +def _wrap_computation_in_while_loop(device, op_fn): + """Wraps the ops generated by `op_fn` in tf.while_loop.""" + def computation(i): + with ops.control_dependencies(op_fn()): + return i + 1 + + iterations_per_loop_var = _create_or_get_iterations_per_loop() + # By setting parallel_iterations=1, the parallel execution in while_loop is + # basically turned off. + with ops.device(device): + iterations = array_ops.identity(iterations_per_loop_var) + return control_flow_ops.while_loop( + lambda i: i < iterations, + computation, [constant_op.constant(0)], parallel_iterations=1) + + def _validate_tpu_training_graph(): """Validate graph before running distributed training. @@ -1603,3 +1680,5 @@ def _validate_tpu_training_graph(): if not cross_replica_sum_ops: raise ValueError( 'CrossShardOptimizer must be used for model training on TPUs.') + + diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py b/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py index d545a94ca6a2fdb3a9df2748b59300fd141dc55d..f8ba7d45e20b2f48e1409427665878df40a6db02 100644 --- a/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py +++ b/tensorflow/contrib/tpu/python/tpu/tpu_sharding.py @@ -177,6 +177,10 @@ class ShardingPolicy(object): raise ValueError("shape %s does not contain shard_dimension %d" % (shape.as_list(), self._shard_dimension)) dims = shape.as_list() + if dims[self._shard_dimension] is None: + raise ValueError("shape %s must have a fixed size for dimension %d " + "that is known at graph construction time." % + (shape.as_list(), self._shard_dimension)) if (dims[self._shard_dimension] % self._number_of_shards) != 0: raise ValueError("shape %s cannot be sharded %d ways along dimension %d" % (shape.as_list(), self._number_of_shards, diff --git a/tensorflow/contrib/tpu/python/tpu/util.py b/tensorflow/contrib/tpu/python/tpu/util.py new file mode 100644 index 0000000000000000000000000000000000000000..b8ea307d8900cf1b6d1e6e808d0b9ede26f86490 --- /dev/null +++ b/tensorflow/contrib/tpu/python/tpu/util.py @@ -0,0 +1,31 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =================================================================== + +"""Utilities for the functionalities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import six + + +def check_positive_integer(value, name): + """Checks whether `value` is a positive integer.""" + if not isinstance(value, six.integer_types): + raise TypeError('{} must be int, got {}'.format(name, type(value))) + + if value <= 0: + raise ValueError('{} must be positive, got {}'.format(name, value)) diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD index 8e3d869a51c440e00059851f05f6ed2fe5558416..6139c1d5838c24414549b4e2bc4722175f2d1925 100644 --- a/tensorflow/contrib/training/BUILD +++ b/tensorflow/contrib/training/BUILD @@ -26,6 +26,7 @@ py_library( "python/training/resample.py", "python/training/sampling_ops.py", "python/training/sequence_queueing_state_saver.py", + "python/training/sgdr_learning_rate_decay.py", "python/training/training.py", "python/training/tuner.py", ], @@ -41,6 +42,7 @@ py_library( "//tensorflow/python:data_flow_ops", "//tensorflow/python:errors", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:layers_base", "//tensorflow/python:logging_ops", "//tensorflow/python:math_ops", "//tensorflow/python:parsing_ops", @@ -111,6 +113,7 @@ py_test( "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:random_ops", + "//tensorflow/python:sparse_ops", "//tensorflow/python:sparse_tensor", "//tensorflow/python:string_ops", "//tensorflow/python:training", @@ -125,9 +128,12 @@ py_test( srcs = ["python/training/feeding_queue_runner_test.py"], srcs_version = "PY2AND3", deps = [ - ":training_py", "//tensorflow/python:client_testlib", - "//tensorflow/python:framework_test_lib", + "//tensorflow/python:framework_ops", + "//tensorflow/python:session", + "//tensorflow/python:training", + "//tensorflow/python/estimator:inputs_queues", + "//third_party/py/numpy", ], ) @@ -139,7 +145,6 @@ py_test( deps = [ ":training_py", "//tensorflow/python:client_testlib", - "@six_archive//:six", ], ) @@ -243,12 +248,12 @@ py_test( "//tensorflow/contrib/metrics:metrics_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", - "//tensorflow/python:client", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:random_seed", + "//tensorflow/python:session", "//tensorflow/python:state_ops", "//tensorflow/python:summary", "//tensorflow/python:training", @@ -263,12 +268,14 @@ py_test( srcs = ["python/training/training_test.py"], shard_count = 3, srcs_version = "PY2AND3", + tags = ["notsan"], deps = [ ":training_py", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/layers:layers_py", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:gradients", "//tensorflow/python:math_ops", "//tensorflow/python:platform", "//tensorflow/python:random_seed", diff --git a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py index f6237872cce9be57809b12f8f5067646f328cb96..2a0ef0e6b3750b4f0464f1f4390819e1fc2c7872 100644 --- a/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py +++ b/tensorflow/contrib/training/python/training/batch_sequences_with_states_test.py @@ -30,6 +30,7 @@ from tensorflow.python.framework import sparse_tensor from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops +from tensorflow.python.ops import sparse_ops from tensorflow.python.ops import string_ops from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -527,6 +528,50 @@ class PaddingTest(test.TestCase): self.assertTrue( math_ops.reduce_all(math_ops.equal(val, padded_seq[key])).eval()) + def testPaddingOnlySparse(self): + ind1 = np.array([[0], [2]]) + val1 = np.array([3, 4]) + shape1 = np.array([4]) + + ind2 = np.array([[1], [2]]) + val2 = np.array([9, 12]) + shape2 = np.array([5]) + + with ops.Graph().as_default() as g, self.test_session(graph=g): + sp_tensor1 = sparse_tensor.SparseTensor( + indices=array_ops.constant(ind1, dtypes.int64), + values=array_ops.constant(val1, dtypes.int64), + dense_shape=array_ops.constant(shape1, dtypes.int64)) + sp_tensor2 = sparse_tensor.SparseTensor( + indices=array_ops.constant(ind2, dtypes.int64), + values=array_ops.constant(val2, dtypes.int64), + dense_shape=array_ops.constant(shape2, dtypes.int64)) + + sp_tensor1_expected = sparse_tensor.SparseTensor( + indices=sp_tensor1.indices, + values=sp_tensor1.values, + dense_shape=[8]) + sp_tensor2_expected = sparse_tensor.SparseTensor( + indices=sp_tensor2.indices, + values=sp_tensor2.values, + dense_shape=[8]) + + sequences = { + "key_1": sp_tensor1, + "key_2": sp_tensor2, + } + _, padded_seq = sqss._padding(sequences, 4) + + expected_padded_seq = { + "key_1": sp_tensor1_expected, + "key_2": sp_tensor2_expected, + } + + for key, val in expected_padded_seq.items(): + self.assertAllEqual( + sparse_ops.sparse_tensor_to_dense(val).eval(), + sparse_ops.sparse_tensor_to_dense(padded_seq[key]).eval()) + class SparseTensorReConstructionTest(test.TestCase): diff --git a/tensorflow/contrib/training/python/training/bucket_ops.py b/tensorflow/contrib/training/python/training/bucket_ops.py index 5523cc375fc20dc167fee0eaa6f1682dc1892c3f..95fbc50cba73b25b748c31ecd443eb19c0b6fc8a 100644 --- a/tensorflow/contrib/training/python/training/bucket_ops.py +++ b/tensorflow/contrib/training/python/training/bucket_ops.py @@ -31,6 +31,7 @@ from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util +from tensorflow.python.layers import utils from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import data_flow_ops @@ -47,7 +48,6 @@ _dtypes = input_py._dtypes _store_sparse_tensors = input_py._store_sparse_tensors _validate_keep_input = input_py._validate_keep_input _shapes = input_py._shapes -_smart_cond = input_py._smart_cond _which_queue = input_py._which_queue # pylint: enable=protected-access @@ -239,7 +239,7 @@ def bucket(tensors, ] return control_flow_ops.group(*enqueues, name="group_enqueues") - maybe_enqueue = _smart_cond( + maybe_enqueue = utils.smart_cond( keep_input, enqueue_which, control_flow_ops.no_op) diff --git a/tensorflow/contrib/training/python/training/hparam.py b/tensorflow/contrib/training/python/training/hparam.py index 119fa3824bd77724471768980783e105d5595c4b..391899b34f90be25e10450ebf4e285ed2d39446f 100644 --- a/tensorflow/contrib/training/python/training/hparam.py +++ b/tensorflow/contrib/training/python/training/hparam.py @@ -25,6 +25,7 @@ import six from tensorflow.contrib.training.python.training import hparam_pb2 from tensorflow.python.framework import ops from tensorflow.python.util import compat +from tensorflow.python.util import deprecation # Define the regular expression for parsing a single clause of the input # (delimited by commas). A legal clause looks like: @@ -138,7 +139,7 @@ def _process_list_value(name, parse_fn, var_type, m_dict, values, def parse_values(values, type_map): - """Parses hyperparameter values from a string into a python map.. + """Parses hyperparameter values from a string into a python map. `values` is a string containing comma-separated `name=value` pairs. For each pair, the value of the hyperparameter named `name` is set to @@ -470,24 +471,29 @@ class HParams(object): type_map[name] = param_type values_map = parse_values(values, type_map) - return self.set_from_map(values_map) + return self.override_from_dict(values_map) - def set_from_map(self, values_map): + def override_from_dict(self, values_dict): """Override hyperparameter values, parsing new values from a dictionary. Args: - values_map: Dictionary of name:value pairs. + values_dict: Dictionary of name:value pairs. Returns: The `HParams` instance. Raises: - ValueError: If `values_map` cannot be parsed. + ValueError: If `values_dict` cannot be parsed. """ - for name, value in values_map.items(): + for name, value in values_dict.items(): self.set_hparam(name, value) return self + @deprecation.deprecated(None, 'Use `override_from_dict`.') + def set_from_map(self, values_map): + """DEPRECATED. Use override_from_dict.""" + return self.override_from_dict(values_dict=values_map) + def set_model_structure(self, model_structure): self._model_structure = model_structure @@ -515,7 +521,7 @@ class HParams(object): ValueError: If `values_json` cannot be parsed. """ values_map = json.loads(values_json) - return self.set_from_map(values_map) + return self.override_from_dict(values_map) def values(self): """Return the hyperparameter values as a Python dictionary. @@ -526,6 +532,9 @@ class HParams(object): """ return {n: getattr(self, n) for n in self._hparam_types.keys()} + def __contains__(self, key): + return key in self._hparam_types + def __str__(self): return str(sorted(self.values().items())) diff --git a/tensorflow/contrib/training/python/training/hparam_test.py b/tensorflow/contrib/training/python/training/hparam_test.py index b01116a2139f76bab2e6219048c7c1aec013e626..f54514cefd39cab93e5c3a34786a6bb751b97704 100644 --- a/tensorflow/contrib/training/python/training/hparam_test.py +++ b/tensorflow/contrib/training/python/training/hparam_test.py @@ -32,6 +32,11 @@ class HParamsTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'Unknown hyperparameter'): hparams.parse('xyz=123') + def testContains(self): + hparams = hparam.HParams(foo=1) + self.assertTrue('foo' in hparams) + self.assertFalse('bar' in hparams) + def testSomeValues(self): hparams = hparam.HParams(aaa=1, b=2.0, c_c='relu6') self.assertDictEqual({'aaa': 1, 'b': 2.0, 'c_c': 'relu6'}, hparams.values()) @@ -93,11 +98,11 @@ class HParamsTest(test.TestCase): def testSetFromMap(self): hparams = hparam.HParams(a=1, b=2.0, c='tanh') - hparams.set_from_map({'a': -2, 'c': 'identity'}) + hparams.override_from_dict({'a': -2, 'c': 'identity'}) self.assertDictEqual({'a': -2, 'c': 'identity', 'b': 2.0}, hparams.values()) hparams = hparam.HParams(x=1, b=2.0, d=[0.5]) - hparams.set_from_map({'d': [0.1, 0.2, 0.3]}) + hparams.override_from_dict({'d': [0.1, 0.2, 0.3]}) self.assertDictEqual({'d': [0.1, 0.2, 0.3], 'x': 1, 'b': 2.0}, hparams.values()) diff --git a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py index 778cf985cada74458ff8022b3af56f1047bf46b2..72231948856b38edd3d022a99a62e6d4c8c5649e 100644 --- a/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py +++ b/tensorflow/contrib/training/python/training/sequence_queueing_state_saver.py @@ -1596,7 +1596,7 @@ def _padding(sequences, num_unroll): else: # Only have SparseTensors sparse_lengths = [value.dense_shape[0] for value in sequences_dict.values() if isinstance(value, sparse_tensor.SparseTensor)] - length = math_ops.maximum(sparse_lengths) + length = math_ops.reduce_max(math_ops.to_int32(sparse_lengths)) unroll = array_ops.constant(num_unroll) padded_length = length + ((unroll - (length % unroll)) % unroll) diff --git a/tensorflow/contrib/verbs/README.md b/tensorflow/contrib/verbs/README.md index da5f2b0223bc6698e750ebbc3307d70ee1535478..dcb390b0a5e343157dfd04ef8b18b7f723da27e0 100644 --- a/tensorflow/contrib/verbs/README.md +++ b/tensorflow/contrib/verbs/README.md @@ -1,4 +1,4 @@ -## How to compile and use RDMA-enabled TensorFlow +## How to compile, use and configure RDMA-enabled TensorFlow 1. Follow the regular TF compilation instructions. During configure step, if you want ibverbs based RDMA support, answer yes to this question: ```Do you wish to build TensorFlow with VERBS-RDMA support [y/N]``` @@ -7,6 +7,18 @@ ```server = tf.train.Server(cluster, job_name="local", task_index=0, protocol='grpc+verbs') # default protocol is 'grpc'``` +3. RDMA configuration is done by setting the following environment variables: + * **RDMA_DEVICE**: The RDMA device name to be used. If not defined by user, a default device with an active port will be set if exists. + * **RDMA_DEVICE_PORT**: The port within the selected device. Not relevant if RDMA_DEVICE is not defined. If not defined by user, a default active port will be set if exists. + * **RDMA_GID_INDEX**: The GID index of the port. If not defined by user, a default suitable GID index will be set (RoCEV2 is favourable as default). + * **RDMA_QP_PKEY_INDEX**: The Pkey for the QP. If not defined by user, the default value is 0. + * **RDMA_QP_QUEUE_DEPTH**: TX/RX queue size for the QP. If not defined by user, the default value is 1024. + * **RDMA_QP_TIMEOUT**: The retransmission timeout for QPs. If not defined by user, the default value is 14. + * **RDMA_QP_RETRY_COUNT**: Number of retransmission for QPs. If not defined by user, the default value is 7. + * **RDMA_QP_SL**: Service level configuration for QOS and ECN, valid values are 0-7. If not defined by user, the default value is 0. + * **RDMA_QP_MTU**: MTU configuration for the QPs. If not defined by user, the default value is active MTU from query_port. + * **RDMA_TRAFFIC_CLASS**: Traffic class configuration for QP, in case of DSCP trust level QoS configuration. If not defined by user, the default value is 0. For more info see [HowTo Configure Trust state on Mellanox Adapters](https://community.mellanox.com/docs/DOC-2866). + ## Overview The design is based on TensorFlow r1.0. An RDMA path is added between servers for tensor transfer (weights, gradients, etc). The existing GRPC path remains and is responsible for "administrative" tasks, such as setting up the RDMA path, exchanging computation graphs, etc. diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc index a1fbea57dd1202c1a22e6b3570e9378555fe3498..cff765d1e832e5a593462283444d7c4ed7831636 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.cc @@ -43,21 +43,21 @@ VerbsService::Stub::Stub( const std::shared_ptr< ::grpc::ChannelInterface>& channel) : channel_(channel), rpcmethod_GetRemoteAddress_(grpcVerbsService_method_names[0], - ::grpc::internal::RpcMethod::NORMAL_RPC, + ::grpc::RpcMethod::NORMAL_RPC, channel) {} ::grpc::Status VerbsService::Stub::GetRemoteAddress( ::grpc::ClientContext* context, const GetRemoteAddressRequest& request, GetRemoteAddressResponse* response) { - return ::grpc::internal::BlockingUnaryCall( + return ::grpc::BlockingUnaryCall( channel_.get(), rpcmethod_GetRemoteAddress_, context, request, response); } VerbsService::AsyncService::AsyncService() { for (int i = 0; i < 1; ++i) { - AddMethod(new ::grpc::internal::RpcServiceMethod( + AddMethod(new ::grpc::RpcServiceMethod( grpcVerbsService_method_names[i], - ::grpc::internal::RpcMethod::NORMAL_RPC, + ::grpc::RpcMethod::NORMAL_RPC, nullptr)); ::grpc::Service::MarkMethodAsync(i); } diff --git a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h index 86431ca030c38c56155801202714ee4a49b764df..6e2bf86dac2aa84ff453aaefbfc57cd3ee8bc1fd 100644 --- a/tensorflow/contrib/verbs/grpc_verbs_service_impl.h +++ b/tensorflow/contrib/verbs/grpc_verbs_service_impl.h @@ -28,6 +28,15 @@ limitations under the License. #include "tensorflow/contrib/verbs/verbs_service.pb.h" namespace grpc { + +// ensure internal namespace exists +namespace internal { +// bring in contents of external namespace +using namespace ::grpc; +} // namespace internal +// bring in contents of internal namespace +using namespace internal; + class CompletionQueue; class Channel; class RpcService; @@ -61,7 +70,7 @@ class VerbsService GRPC_FINAL { private: std::shared_ptr< ::grpc::ChannelInterface> channel_; - const ::grpc::internal::RpcMethod rpcmethod_GetRemoteAddress_; + const ::grpc::RpcMethod rpcmethod_GetRemoteAddress_; }; static std::unique_ptr NewStub( const std::shared_ptr< ::grpc::ChannelInterface>& channel, diff --git a/tensorflow/contrib/verbs/rdma.cc b/tensorflow/contrib/verbs/rdma.cc index 26e18b28aabd0db6c3c7091fca96aa30f39c73a2..331943a3ef059329a28372edbfd2f2ffc0931f58 100644 --- a/tensorflow/contrib/verbs/rdma.cc +++ b/tensorflow/contrib/verbs/rdma.cc @@ -17,6 +17,7 @@ limitations under the License. #include "tensorflow/contrib/verbs/rdma.h" #include +#include #include "tensorflow/contrib/verbs/verbs_util.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/common_runtime/dma_helper.h" @@ -33,6 +34,8 @@ limitations under the License. namespace tensorflow { +#define RoCE_V2 "RoCE v2" + namespace { // hash name to 32-bit integer uint32_t NameHash(const string& name) { @@ -66,16 +69,337 @@ string MessageTypeToString(RdmaMessageType rmt) { } } // namespace -ibv_context* open_default_device() { +// Function to get environment variable +// Args: +// var_name - the name of the environmental variable +// Returns: +// string with it's value or empty string if not set +string get_env_var(char const* var_name) { + char const* var_temp = getenv(var_name); + + return (var_temp == NULL) ? string() : string(var_temp); +} + +// Function to open device +// Args: +// ibv_dev device to open +// Returns: +// context of the opened device +ibv_context* open_device(ibv_device* ibv_dev) { + ibv_context* context = ibv_open_device(ibv_dev); + + CHECK(context) << "Open context failed for " << ibv_get_device_name(ibv_dev); + return context; +} + +// Function to count the number of active ports for device +// Args: +// device - to check active ports +// Returns: +// number of active ports of the given device +int get_dev_active_port_count(ibv_device* device) { + ibv_device_attr device_att; + ibv_port_attr port_attr; + ibv_context* context = NULL; + int rc, port_index, active_ports = 0; + + context = ibv_open_device(device); + CHECK(context) << "Open context failed for " << ibv_get_device_name(device); + rc = ibv_query_device(context, &device_att); + CHECK(!rc) << "Failed to query the device"; + + for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) { + rc = ibv_query_port(context, port_index, &port_attr); + CHECK(!rc) << "Failed to query the port" << port_index; + if (port_attr.state == IBV_PORT_ACTIVE) { + active_ports++; + } + } + ibv_close_device(context); + return active_ports; +} + +// Function to set device. If RDMA_DEVICE not set, search for device with active +// port. +// Fails if more than one device with active port was found. +// Returns: +// device to use +ibv_device* set_device() { ibv_device** dev_list; - ibv_device* ib_dev; - dev_list = ibv_get_device_list(NULL); + int dev_num, device_index, device_to_open = 0; + int num_devs_with_active_port = 0; + string env_p_rdma_device, str_port_num; + + dev_list = ibv_get_device_list(&dev_num); CHECK(dev_list) << "No InfiniBand device found"; - ib_dev = dev_list[0]; - CHECK(ib_dev) << "No InfiniBand device found"; - ibv_context* context = ibv_open_device(ib_dev); - CHECK(context) << "Open context failed for " << ibv_get_device_name(ib_dev); - return context; + + env_p_rdma_device = get_env_var("RDMA_DEVICE"); + if (!env_p_rdma_device.empty()) { + for (device_index = 0; device_index < dev_num; device_index++) { + if (!env_p_rdma_device.compare( + ibv_get_device_name(dev_list[device_index]))) { + CHECK(get_dev_active_port_count(dev_list[device_index]) != 0) + << "Device " << ibv_get_device_name(dev_list[device_index]) + << " has no active ports"; + return dev_list[device_index]; + } + } + // check validity of input device + CHECK(false) << "The device " << env_p_rdma_device << " wasn't found"; + } else { + // set default device + str_port_num = get_env_var("RDMA_DEVICE_PORT"); + CHECK(str_port_num.empty()) + << "RDMA_DEVICE should be provided if RDMA_DEVICE_PORT is set by user"; + for (device_index = 0; device_index < dev_num; device_index++) { + // get port_num + if (get_dev_active_port_count(dev_list[device_index]) > 0) { + num_devs_with_active_port++; + CHECK(num_devs_with_active_port <= 1) << ". More than one device with " + "active port in the system. " + "Please enter RDMA_DEVICE"; + // found device with at least 1 active port + device_to_open = device_index; + } + } + CHECK(num_devs_with_active_port > 0) + << "There is no active port in the system"; + return dev_list[device_to_open]; + } + CHECK(false) << "No device was set!"; + return NULL; // never happens +} + +// Function to set port for device. +// If RDMA_DEVICE_PORT not set, first active port of the device will be set. +// Args: +// context of the device +// Returns: +// port to use +uint8_t set_port(ibv_context* context) { + uint8_t port_num = 0; //0 is illegal port number + string str_port_num; + ibv_device_attr device_att; + ibv_port_attr port_attr; + int rc, port_index; + + rc = ibv_query_device(context, &device_att); + CHECK(!rc) << "Failed to query the device\n"; + + str_port_num = get_env_var("RDMA_DEVICE_PORT"); + // user defined port + if (!str_port_num.empty()) { + port_num = stoi(str_port_num); + CHECK(port_num > 0) << "RDMA_DEVICE_PORT should be positive"; + CHECK(port_num <= device_att.phys_port_cnt) << "RDMA_DEVICE_PORT should be " + "less or equal to amount of " + "available ports"; + rc = ibv_query_port(context, port_num, &port_attr); + CHECK(!rc) << "Failed to query the port" << port_num; + // check if port id active + CHECK(port_attr.state == IBV_PORT_ACTIVE) + << "Selected RDMA_DEVICE_PORT is not active"; + } + // set default port + else { + for (port_index = 1; port_index <= device_att.phys_port_cnt; port_index++) { + rc = ibv_query_port(context, port_index, &port_attr); + CHECK(!rc) << "Failed to query the port" << port_index; + if (port_attr.state == IBV_PORT_ACTIVE) { + port_num = port_index; + break; + } + } + CHECK_GT(port_num, 0) << "No active ports"; + } + return port_num; +} + +// Function read from sysfs file +// Args: +// dir - directory +// file - file +// buff - buffer for the result +// size - buffer size +// Returns: +// number of bytes were read or -1 if failed +int read_sysfs_file(const char* dir, const char* file, char* buf, size_t size) { + char* path; + int fd; + int len; + + if (asprintf(&path, "%s/%s", dir, file) < 0) return -1; + + fd = open(path, O_RDONLY); + if (fd < 0) { + free(path); + return -1; + } + + len = read(fd, buf, size); + + close(fd); + free(path); + + if (len > 0 && buf[len - 1] == '\n') buf[--len] = '\0'; + + return len; +} + +// Function to check if GID index support RoCE V2 +// Args: +// context - device context +// port_num - port number +// index - GID index +// Returns: +// if GID supports RoCE V2 - true, otherwise - false. +bool is_gid_type_roce_v2(ibv_context* context, uint8_t port_num, + uint8_t index) { + char name[32]; + char buff[41]; + + snprintf(name, sizeof(name), "ports/%d/gid_attrs/types/%d", port_num, index); + if (read_sysfs_file(context->device->ibdev_path, name, buff, sizeof(buff)) <= + 0) { + return false; + } + return !strcmp(buff, RoCE_V2); +} + +// Function to set GID index. +// If the port link is IB, no GID index should be selected. +// If Ethernet but RDMA_GID_INDEX not set gid index that supports +// RoCE V2 will be chosen(fails if more then one IP is configured) +// Args: +// context - device context +// port_num - port number +// Returns: +// GID index to use +uint8_t set_gid(uint8_t port_num, ibv_context* context) { + ibv_port_attr port_attr; + string gid_str; + int rc, i, gids_num = 0, v2_ip_num = 0; + union ibv_gid gid; + uint8_t gid_index = 0; + + rc = ibv_query_port(context, port_num, &port_attr); + CHECK(!rc) << "Failed to query the port" << port_num; + + for (i = 0; i < port_attr.gid_tbl_len; i++) { + rc = ibv_query_gid(context, port_num, i, &gid); + CHECK(!rc) << "Failed to query gid to port " << (int)port_num << " index " + << i; + if (gid.global.interface_id) { + gids_num++; + if (gid.global.subnet_prefix == 0 && + is_gid_type_roce_v2(context, port_num, i)) { + if (v2_ip_num == 0) { + // can be overwritten by RDMA_GID_INDEX later + gid_index = i; + } + v2_ip_num++; + } + } + } + switch (port_attr.link_layer) { + case(IBV_LINK_LAYER_ETHERNET) : + gid_str = get_env_var("RDMA_GID_INDEX"); + if (!gid_str.empty()) { + gid_index = stoi(gid_str); + CHECK(gid_index < gids_num) + << "RDMA_GID_INDEX should be less than GIDs amount" << gids_num; + } else { + CHECK(v2_ip_num <= 1) + << "More than one IP is available, please specify GID_INDEX"; + } + break; + case(IBV_LINK_LAYER_INFINIBAND) : // no need in GID index + break; + default: + LOG(INFO) << "Unknown port link layer. Currently supporting Ethernet and " + "InfiniBand only. "; + } + if (!is_gid_type_roce_v2(context, port_num, gid_index)) { + LOG(INFO) << "RoCE v2 is not configured for GID_INDEX " << (int)gid_index; + } + return gid_index; +} + +// set the default or environment value to the configuration parameter. +// Args: +// default_val- the default value for this parameter +// env_param- the environment parameter's name +// Returns: +// 32-bit value +uint32_t set_param(uint32_t default_val, const char* env_param) { + uint32_t val = default_val; + string val_s; + + val_s = get_env_var(env_param); + + if (!val_s.empty()) { + val = stoi(val_s); + } + return val; +} + +enum ibv_mtu set_mtu(uint8_t port_num, ibv_context* context) { + ibv_port_attr port_attr; + enum ibv_mtu mtu; + string mtu_s; + int rc, mtu_i; + + rc = ibv_query_port(context, port_num, &port_attr); + CHECK(!rc) << "Failed to query the port" << port_num; + + mtu_s = get_env_var("RDMA_MTU"); + + if (!mtu_s.empty()) { + mtu_i = stoi(mtu_s); + switch (mtu_i) { + case 256: + mtu = IBV_MTU_256; + break; + case 512: + mtu = IBV_MTU_512; + break; + case 1024: + mtu = IBV_MTU_1024; + break; + case 2048: + mtu = IBV_MTU_2048; + break; + case 4096: + mtu = IBV_MTU_4096; + break; + default: + CHECK(0) << "Error: MTU input value must be one of the following: 256, " + "512, 1024, 2048, 4096. MTU " << mtu << " is invalid\n"; + break; + } + CHECK(mtu < port_attr.active_mtu) + << "MTU configuration for the QPs is larger than active MTU"; + } else { + mtu = port_attr.active_mtu; + } + return mtu; +} + +RdmaParams params_init(ibv_context* context) { + RdmaParams params; + + params.port_num = set_port(context); + params.sgid_index = set_gid(params.port_num, context); + params.pkey_index = (uint8_t)set_param(PKEY_DEFAULT, "RDMA_PKEY"); + params.queue_depth = set_param(QUEUE_DEPTH_DEFAULT, "RDMA_QUEUE_DEPTH"); + params.timeout = (uint8_t)set_param(TIMEOUT_DEFAULT, "RDMA_TIMEOUT"); + params.retry_cnt = (uint8_t)set_param(RETRY_CNT_DEFAULT, "RDMA_RETRY_CNT"); + params.sl = (uint8_t)set_param(SL_DEFAULT, "RDMA_SL"); + CHECK(params.sl <= 7) << "SL value is " << (int)params.sl + << ". Valid values are 0-7."; + params.mtu = set_mtu(params.port_num, context); + params.traffic_class = set_param(TRAFFIC_CLASS, "RDMA_TRAFFIC_CLASS"); + return params; } ibv_pd* alloc_protection_domain(ibv_context* context) { @@ -85,7 +409,8 @@ ibv_pd* alloc_protection_domain(ibv_context* context) { } RdmaAdapter::RdmaAdapter(const WorkerEnv* worker_env) - : context_(open_default_device()), + : context_(open_device(set_device())), + params_(params_init(context_)), pd_(alloc_protection_domain(context_)), worker_env_(worker_env) { event_channel_ = ibv_create_comp_channel(context_); @@ -128,9 +453,9 @@ void RdmaAdapter::Process_CQ() { CHECK_GE(ne, 0); for (int i = 0; i < ne; ++i) { CHECK(wc_[i].status == IBV_WC_SUCCESS) - << "Failed status \n" - << ibv_wc_status_str(wc_[i].status) << " " << wc_[i].status << " " - << static_cast(wc_[i].wr_id) << " " << wc_[i].vendor_err; + << "Failed status \n" << ibv_wc_status_str(wc_[i].status) << " " + << wc_[i].status << " " << static_cast(wc_[i].wr_id) << " " + << wc_[i].vendor_err; if (wc_[i].opcode == IBV_WC_RECV_RDMA_WITH_IMM) { RdmaChannel* rc = reinterpret_cast(wc_[i].wr_id); // put back a recv wr. @@ -242,8 +567,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, memset(&attr, 0, sizeof(ibv_qp_init_attr)); attr.send_cq = adapter_->cq_; attr.recv_cq = adapter_->cq_; - attr.cap.max_send_wr = RdmaAdapter::MAX_CONCURRENT_WRITES; - attr.cap.max_recv_wr = RdmaAdapter::MAX_CONCURRENT_WRITES; + attr.cap.max_send_wr = adapter_->params_.queue_depth; + attr.cap.max_recv_wr = adapter_->params_.queue_depth; attr.cap.max_send_sge = 1; attr.cap.max_recv_sge = 1; attr.qp_type = IBV_QPT_RC; @@ -257,8 +582,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, struct ibv_qp_attr attr; memset(&attr, 0, sizeof(ibv_qp_attr)); attr.qp_state = IBV_QPS_INIT; - attr.pkey_index = 0; - attr.port_num = 1; + attr.pkey_index = adapter_->params_.pkey_index; + attr.port_num = adapter_->params_.port_num; attr.qp_access_flags = IBV_ACCESS_LOCAL_WRITE | IBV_ACCESS_REMOTE_WRITE; int mask = @@ -269,13 +594,15 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, // Local address { struct ibv_port_attr attr; - CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &attr)) + CHECK( + !ibv_query_port(adapter_->context_, adapter_->params_.port_num, &attr)) << "Query port"; self_.lid = attr.lid; self_.qpn = qp_->qp_num; self_.psn = static_cast(random::New64()) & 0xffffff; union ibv_gid gid; - CHECK(!ibv_query_gid(adapter_->context_, (uint8_t)1, 0, &gid)) + CHECK(!ibv_query_gid(adapter_->context_, adapter_->params_.port_num, + adapter_->params_.sgid_index, &gid)) << "Query gid"; self_.snp = gid.global.subnet_prefix; self_.iid = gid.global.interface_id; @@ -284,7 +611,7 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name, // create message and ack buffers, then initialize the tables. { const string buffer_names[] = {"tx_message_buffer", "rx_message_buffer", - "tx_ack_buffer", "rx_ack_buffer"}; + "tx_ack_buffer", "rx_ack_buffer"}; tx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[0]); rx_message_buffer_ = new RdmaMessageBuffer(this, buffer_names[1]); tx_ack_buffer_ = new RdmaAckBuffer(this, buffer_names[2]); @@ -345,7 +672,7 @@ void RdmaChannel::SetRemoteAddress(const RdmaAddress& ra, bool override) { void RdmaChannel::Recv() { struct ibv_recv_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = (uint64_t)this; + wr.wr_id = (uint64_t) this; struct ibv_recv_wr* bad_wr; CHECK(!ibv_post_recv(qp_, &wr, &bad_wr)) << "Failed to post recv"; } @@ -479,11 +806,9 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { struct ibv_qp_attr attr; memset(&attr, 0, sizeof(ibv_qp_attr)); attr.qp_state = IBV_QPS_RTR; - struct ibv_port_attr port_attr; - CHECK(!ibv_query_port(adapter_->context_, (uint8_t)1, &port_attr)) - << "Query port failed"; + // This assumes both QP's ports are configured with the same MTU - attr.path_mtu = port_attr.active_mtu; + attr.path_mtu = adapter_->params_.mtu; attr.dest_qp_num = remoteAddr.qpn; attr.rq_psn = remoteAddr.psn; attr.max_dest_rd_atomic = 1; @@ -494,30 +819,32 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) { attr.ah_attr.grh.flow_label = 0; attr.ah_attr.grh.hop_limit = 255; attr.ah_attr.dlid = remoteAddr.lid; - attr.ah_attr.sl = 0; + attr.ah_attr.sl = adapter_->params_.sl; attr.ah_attr.src_path_bits = 0; - attr.ah_attr.port_num = 1; + attr.ah_attr.port_num = adapter_->params_.port_num; + attr.ah_attr.grh.sgid_index = adapter_->params_.sgid_index; + attr.ah_attr.grh.traffic_class = adapter_->params_.traffic_class; int r; - CHECK(!(r = ibv_modify_qp(qp_, &attr, - IBV_QP_STATE | IBV_QP_AV | IBV_QP_PATH_MTU | - IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | - IBV_QP_MAX_DEST_RD_ATOMIC | - IBV_QP_MIN_RNR_TIMER))) + CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_AV | + IBV_QP_PATH_MTU | + IBV_QP_DEST_QPN | IBV_QP_RQ_PSN | + IBV_QP_MAX_DEST_RD_ATOMIC | + IBV_QP_MIN_RNR_TIMER))) << "QP to Ready to Receive " << r; memset(&attr, 0, sizeof(ibv_qp_attr)); attr.qp_state = IBV_QPS_RTS; attr.sq_psn = self_.psn; - attr.timeout = 14; - attr.retry_cnt = 7; + attr.timeout = adapter_->params_.timeout; + attr.retry_cnt = adapter_->params_.retry_cnt; attr.rnr_retry = 7; /* infinite */ attr.max_rd_atomic = 1; - CHECK(!(r = ibv_modify_qp(qp_, &attr, - IBV_QP_STATE | IBV_QP_TIMEOUT | IBV_QP_RETRY_CNT | - IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | - IBV_QP_MAX_QP_RD_ATOMIC))) + CHECK(!(r = ibv_modify_qp(qp_, &attr, IBV_QP_STATE | IBV_QP_TIMEOUT | + IBV_QP_RETRY_CNT | + IBV_QP_RNR_RETRY | IBV_QP_SQ_PSN | + IBV_QP_MAX_QP_RD_ATOMIC))) << "QP to Ready to Send " << r; connected_ = true; @@ -604,7 +931,7 @@ void RdmaBuffer::Write(uint32_t imm_data, size_t buffer_size) { struct ibv_send_wr wr; memset(&wr, 0, sizeof(wr)); - wr.wr_id = (uint64_t)this; + wr.wr_id = (uint64_t) this; wr.sg_list = &list; wr.num_sge = 1; wr.opcode = IBV_WR_RDMA_WRITE_WITH_IMM; @@ -699,9 +1026,9 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback( TensorProto proto; if (src_dev->tensorflow_gpu_device_info() && (!send_args.alloc_attrs.on_host())) { - CHECK(send_args.device_context) - << "send dev name: " << src_dev->name() - << " gpu_info: " << src_dev->tensorflow_gpu_device_info(); + CHECK(send_args.device_context) << "send dev name: " << src_dev->name() + << " gpu_info: " + << src_dev->tensorflow_gpu_device_info(); if (can_memcpy) { AllocatorAttributes host_alloc_attrs; @@ -727,8 +1054,8 @@ Rendezvous::DoneCallback RdmaTensorBuffer::getRecvTensorCallback( // aync instead GPUUtil::SetProtoFromGPU( in, src_dev, send_args.device_context, &proto, is_dead, - [this, proto, buffer_size, key, in, step_id, key_with_step_id, - is_dead, send_args, recv_args](const Status& s) mutable { + [this, proto, buffer_size, key, in, step_id, key_with_step_id, + is_dead, send_args, recv_args](const Status& s) mutable { CHECK(s.ok()) << "copy proto from gpu sync"; auto tensor_bytes = proto.ByteSize(); buffer_size += tensor_bytes; diff --git a/tensorflow/contrib/verbs/rdma.h b/tensorflow/contrib/verbs/rdma.h index e1e07db776467c5b604f610bbc907d363edae139..52d92a7c5bb6f21e2449e06792d8d40c9bcbf9bd 100644 --- a/tensorflow/contrib/verbs/rdma.h +++ b/tensorflow/contrib/verbs/rdma.h @@ -36,7 +36,24 @@ limitations under the License. #include "tensorflow/core/platform/mutex.h" namespace tensorflow { - +#define PKEY_DEFAULT 0 +#define QUEUE_DEPTH_DEFAULT 1024 +#define TIMEOUT_DEFAULT 14 +#define RETRY_CNT_DEFAULT 7 +#define SL_DEFAULT 0 +#define TRAFFIC_CLASS 0 + +struct RdmaParams { + uint8_t port_num; + uint8_t sgid_index; + uint8_t pkey_index; + uint32_t queue_depth; + uint8_t timeout; + uint8_t retry_cnt; + uint8_t sl; + enum ibv_mtu mtu; + uint8_t traffic_class; +}; // structure to save the address of remote channels. struct RdmaAddress { uint32_t lid; @@ -50,9 +67,20 @@ struct RemoteMR { uint64_t remote_addr; uint32_t rkey; }; -enum BufferStatus { none, idle, busy }; -enum Location { local, remote }; -enum BufferType { ACK, MESSAGE, TENSOR }; +enum BufferStatus { + none, + idle, + busy +}; +enum Location { + local, + remote +}; +enum BufferType { + ACK, + MESSAGE, + TENSOR +}; enum RdmaMessageType { RDMA_MESSAGE_ACK, RDMA_MESSAGE_BUFFER_IDLE, @@ -84,6 +112,8 @@ class RdmaAdapter { protected: static const int MAX_CONCURRENT_WRITES = 1000; ibv_context* context_; + // RDMA configuration parameters + RdmaParams params_; // ibverbs protection domain ibv_pd* pd_; // Completion event channel, to wait for work completions @@ -183,7 +213,7 @@ class RdmaBuffer { } void FreeBuffer(); void EnqueueItem(string Item); - virtual void SendNextItem(){}; + virtual void SendNextItem() {}; void CreateCPUBuffer(size_t size, bool lock = true); void SetRemoteMR(RemoteMR rmi, bool override); uint32_t LookupBufferIndex(const string& buffer_name) { diff --git a/tensorflow/contrib/xla_tf_graph/BUILD b/tensorflow/contrib/xla_tf_graph/BUILD deleted file mode 100644 index 4a3a2de9b5e58cfab2e6f8de5c6789f1cbcebde7..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/xla_tf_graph/BUILD +++ /dev/null @@ -1,67 +0,0 @@ -# Description: -# contains parts of TensorFlow that are experimental or unstable and which are not supported. - -package( - default_visibility = ["//visibility:public"], -) - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -load("//tensorflow:tensorflow.bzl", "tf_cc_test") - -filegroup( - name = "all_files", - srcs = glob( - ["**/*"], - exclude = [ - "**/METADATA", - "**/OWNERS", - ], - ), -) - -cc_library( - name = "xla_tf_graph_util", - srcs = [ - "xla_tf_graph_util.cc", - ], - hdrs = [ - "xla_tf_graph_util.h", - ], - deps = [ - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:status_macros", - "//tensorflow/compiler/xla/client", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/core:core_cpu", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - ], -) - -tf_cc_test( - name = "xla_tf_graph_util_test", - srcs = ["xla_tf_graph_util_test.cc"], - linkstatic = 1, - tags = ["nomac"], # b/63908145 - deps = [ - ":xla_tf_graph_util", - "//tensorflow/cc:cc_ops", - "//tensorflow/cc:function_ops", - "//tensorflow/cc:scope", - "//tensorflow/compiler/jit:xla_cpu_jit", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/xla:shape_util", - "//tensorflow/compiler/xla/client:client_library", - "//tensorflow/compiler/xla/service:hlo_module_config", - "//tensorflow/core:core_cpu_internal", - "//tensorflow/core:framework_internal", - "//tensorflow/core:ops", - "//tensorflow/core:tensorflow", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/kernels:cwise_op", - ], -) diff --git a/tensorflow/contrib/xla_tf_graph/README.md b/tensorflow/contrib/xla_tf_graph/README.md deleted file mode 100644 index a374189e813107bcf3fe71032d4baf16b3d164a2..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/xla_tf_graph/README.md +++ /dev/null @@ -1,8 +0,0 @@ -# Xla Tf Graph - -## Description - -This module contains utilities to treat xla representation as tf graph to support mobile SOC experiments and leverage tf tools. - -Maintainers: -- Satoshi Kataoka (satok@google.com, github.com/satok16) diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc deleted file mode 100644 index 302aa6457ab08a30bca9c28a5f162331111c4b77..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.cc +++ /dev/null @@ -1,247 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h" - -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/core/platform/protobuf.h" - -namespace tensorflow { -namespace xla_tf_graph { - -namespace { - -constexpr const char* const GRAPH_NAME = "xla_tf_graph"; -constexpr const char* const NODE_NAME_PREFIX = "xla"; - -Status ConvertPrimitiveTypeToDataType(const xla::PrimitiveType p_type, - DataType* d_type) { - switch (p_type) { - case xla::PRED: - *d_type = DT_BOOL; - return Status::OK(); - case xla::S8: - *d_type = DT_INT8; - return Status::OK(); - case xla::S16: - *d_type = DT_INT16; - return Status::OK(); - case xla::S32: - *d_type = DT_INT32; - return Status::OK(); - case xla::S64: - *d_type = DT_INT64; - return Status::OK(); - case xla::U8: - *d_type = DT_UINT8; - return Status::OK(); - case xla::U16: - *d_type = DT_UINT16; - return Status::OK(); - case xla::F16: - *d_type = DT_HALF; - return Status::OK(); - case xla::F32: - *d_type = DT_FLOAT; - return Status::OK(); - case xla::F64: - *d_type = DT_DOUBLE; - return Status::OK(); - default: - return errors::InvalidArgument( - "Unsupported PrimitiveType in ConvertPrimitiveTypeToDataType ", - xla::PrimitiveType_Name(p_type)); - } -} - -Status ConvertXlaShapeToTensorShapeType(const xla::Shape& xla_shape, - std::vector* tensor_shapes, - std::vector* data_types) { - switch (xla_shape.element_type()) { - case xla::TUPLE: { - for (const xla::Shape& element_shape : xla_shape.tuple_shapes()) { - if (element_shape.element_type() == xla::TUPLE) { - return errors::InvalidArgument("Nested tuple is not allowed."); - } - TF_RETURN_IF_ERROR(ConvertXlaShapeToTensorShapeType( - element_shape, tensor_shapes, data_types)); - } - return Status::OK(); - } - case xla::PRED: - case xla::S8: - case xla::S16: - case xla::S32: - case xla::S64: - case xla::U8: - case xla::U16: - case xla::U32: - case xla::U64: - case xla::F16: - case xla::F32: - case xla::F64: { - TensorShape shape; - DataType type; - TF_RETURN_IF_ERROR( - ConvertPrimitiveTypeToDataType(xla_shape.element_type(), &type)); - for (const int64& dim : xla_shape.dimensions()) { - shape.AddDim(dim); - } - tensor_shapes->emplace_back(shape); - data_types->emplace_back(type); - return Status::OK(); - } - default: - return errors::InvalidArgument( - "Unsupported PrimitiveType in ConvertXlaShapeToTensorShapeType ", - xla::PrimitiveType_Name(xla_shape.element_type())); - } -} - -string BuildXlaNodeName(const xla::OperationRequest& operation_request, - const string& xla_op_type, const string& suffix) { - const string name = strings::StrCat( - NODE_NAME_PREFIX, "/", operation_request.output_handle().handle(), "/", - xla_op_type); - if (suffix.empty()) { - return name; - } else { - return strings::StrCat(name, "/", suffix); - } -} - -string BuildXlaNodeName(const xla::OperationRequest& operation_request, - const string& xla_op_type) { - return BuildXlaNodeName(operation_request, xla_op_type, ""); -} - -string BuildXlaNodeOp(const protobuf::Message& msg, const string& suffix) { - return strings::StrCat(msg.GetDescriptor()->name(), "/", suffix); -} - -string BuildXlaNodeOp(const protobuf::Message& msg) { - return BuildXlaNodeOp(msg, ""); -} - -Status ConvertOpRequestToXlaNode(const xla::OperationRequest& operation_request, - XlaNode* xla_node) { - const xla::OpRequest& op_request = operation_request.request(); - switch (op_request.op_case()) { - case xla::OpRequest::kBinaryOpRequest: { - const xla::BinaryOpRequest& op = op_request.binary_op_request(); - xla_node->op_type = - BuildXlaNodeOp(op, xla::BinaryOperation_Name(op.binop())); - xla_node->name = BuildXlaNodeName(operation_request, xla_node->op_type); - xla_node->input_ids.emplace_back(std::make_tuple(op.lhs().handle(), 0)); - xla_node->input_ids.emplace_back(std::make_tuple(op.rhs().handle(), 0)); - for (const int64& dim : op.broadcast_dimensions()) { - xla_node->broadcast_dimensions.emplace_back(dim); - } - break; - } - case xla::OpRequest::kParameterRequest: { - const xla::ParameterRequest& op = op_request.parameter_request(); - xla_node->op_type = BuildXlaNodeOp(op, ""); - xla_node->name = - BuildXlaNodeName(operation_request, xla_node->op_type, op.name()); - break; - } - case xla::OpRequest::kVariadicOpRequest: { - const xla::VariadicOpRequest& op = op_request.variadic_op_request(); - xla_node->op_type = - BuildXlaNodeOp(op, xla::VariadicOperation_Name(op.varop())); - xla_node->name = BuildXlaNodeName(operation_request, xla_node->op_type); - for (const xla::ComputationDataHandle& handle : op.operands()) { - xla_node->input_ids.emplace_back(std::make_tuple(handle.handle(), 0)); - } - break; - } - case xla::OpRequest::kGetTupleElementRequest: { - const xla::GetTupleElementRequest& op = - op_request.get_tuple_element_request(); - xla_node->op_type = BuildXlaNodeOp(op); - xla_node->name = BuildXlaNodeName(operation_request, xla_node->op_type); - xla_node->input_ids.emplace_back( - std::make_tuple(op.operand().handle(), op.index())); - break; - } - default: - // TODO(satok): Implement all possible cases. - LOG(FATAL) << "Op request: " << op_request.op_case() - << " is not supported yet."; - break; - } - - CHECK(!xla_node->name.empty()); - CHECK(!xla_node->op_type.empty()); - - TF_RETURN_IF_ERROR(ConvertXlaShapeToTensorShapeType( - operation_request.output_shape(), &xla_node->output_shapes, - &xla_node->output_data_types)); - return Status::OK(); -} - -void SetupXlaCpuClient(std::unique_ptr* flib_def, - std::unique_ptr* compiler) { - xla::Client* client = xla::ClientLibrary::LocalClientOrDie(); - XlaOpRegistry::RegisterCompilationKernels(); - - FunctionDefLibrary flib; - flib_def->reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); - - // Setup compiler options - XlaCompiler::Options options; - DeviceType device_type(DEVICE_CPU_XLA_JIT); - options.device_type = &device_type; - options.flib_def = flib_def->get(); - options.client = client; - compiler->reset(new XlaCompiler(options)); -} - -} // namespace - -xla::StatusOr> -ConvertTfGraphToXlaSessionModule(const std::vector& args, - std::unique_ptr graph) { - CHECK(graph); - - std::unique_ptr flib_def; - std::unique_ptr compiler; - - SetupXlaCpuClient(&flib_def, &compiler); - - // Compile graph and build computation - XlaCompiler::CompilationResult result; - TF_CHECK_OK(compiler->CompileGraph(XlaCompiler::CompileOptions(), GRAPH_NAME, - std::move(graph), args, &result)); - - return result.computation->Snapshot(); -} - -xla::StatusOr> -ConvertXlaSessionModuleToXlaNodes(const xla::SessionModule& session_module) { - std::unordered_map xla_nodes; - for (const auto& operation_request : session_module.entry().requests()) { - XlaNode xla_node; - TF_RETURN_IF_ERROR( - ConvertOpRequestToXlaNode(operation_request.second, &xla_node)); - xla_nodes.emplace(operation_request.first, xla_node); - } - return std::move(xla_nodes); -} - -} // namespace xla_tf_graph -} // namespace tensorflow diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h deleted file mode 100644 index e635290851f7e5d078d98d845e7488fc3cd94049..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h +++ /dev/null @@ -1,72 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_ -#define TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_ - -#include - -#include "tensorflow/compiler/tf2xla/xla_compiler.h" -#include "tensorflow/compiler/xla/client/client.h" -#include "tensorflow/compiler/xla/status_macros.h" -#include "tensorflow/core/framework/function.h" -#include "tensorflow/core/graph/graph.h" -#include "tensorflow/core/platform/macros.h" - -namespace tensorflow { -namespace xla_tf_graph { - -// A set of utilities to handle xla computation requests. -// These utilities help developers leverage existing tools to work with -// xla computations, also provide a way to support TensorFlow ops by -// implementing xla computations so that they can do experiments on their -// specialized environments. - -// A structure to represent typed attributes of TensorFlow graph node. -// This structure contains op specific attributes as members so that -// we can treat them explicitly. -struct XlaNode { - // Unique node name - string name; - // Op type of xla computation - string op_type; - // List of pair of unique id and port of input node. - // We store this value instead - // of node name in order not to wait for all XlaNodes to be constructed. - std::vector> input_ids; - // Oputput shapes - std::vector output_shapes; - // Output data types - std::vector output_data_types; - - //--------------------------- - // Op specific attributes - // #xla::OpRequest::kBinaryOpRequest - std::vector broadcast_dimensions; -}; - -// Convert a tf graph to a xla session module -xla::StatusOr> -ConvertTfGraphToXlaSessionModule(const std::vector& args, - std::unique_ptr graph); - -// Convert a xla session module to a map to XlaNode from unique id -xla::StatusOr> -ConvertXlaSessionModuleToXlaNodes(const xla::SessionModule& session_module); - -} // namespace xla_tf_graph -} // namespace tensorflow - -#endif // TENSORFLOW_CONTRIB_XLA_TF_GRAPH_XLA_TF_GRAPH_UTIL_H_ diff --git a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc b/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc deleted file mode 100644 index 144269303ee140bb7a9a30133a5d88b41b4f4273..0000000000000000000000000000000000000000 --- a/tensorflow/contrib/xla_tf_graph/xla_tf_graph_util_test.cc +++ /dev/null @@ -1,134 +0,0 @@ -/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/contrib/xla_tf_graph/xla_tf_graph_util.h" -#include "tensorflow/cc/framework/scope.h" -#include "tensorflow/cc/ops/function_ops.h" -#include "tensorflow/cc/ops/standard_ops.h" -#include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "tensorflow/compiler/xla/client/client_library.h" -#include "tensorflow/compiler/xla/service/hlo_module_config.h" -#include "tensorflow/compiler/xla/shape_util.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace xla_tf_graph { - -static std::unique_ptr BuildAddGraph() { - Scope scope = Scope::NewRootScope().ExitOnError(); - auto a = ops::_Arg(scope.WithOpName("A"), DT_INT32, 0); - auto b = ops::_Arg(scope.WithOpName("B"), DT_INT32, 1); - // See tf2xla/kernels/binary_ops.cc - auto c = ops::Add(scope.WithOpName("C"), a, b); - auto d = ops::_Retval(scope.WithOpName("D"), c, 0); - std::unique_ptr graph(new Graph(OpRegistry::Global())); - TF_CHECK_OK(scope.ToGraph(graph.get())); - return graph; -} - -static std::vector BuildAddGraphArguments() { - // Builds a description of the arguments. - std::vector args(2); - args[0].kind = XlaCompiler::Argument::kParameter; - args[0].type = DT_INT32; - // Difference of dimension will add extra broadcast_dimensions. - // broadcast_dimension generates an additional HloInstruction - // in user_computation.cc - args[0].shape = xla::ShapeUtil::MakeShape(xla::S32, {2, 2}); - args[1].kind = XlaCompiler::Argument::kParameter; - args[1].type = DT_INT32; - args[1].shape = xla::ShapeUtil::MakeShape(xla::S32, {2}); - return args; -} - -// CAVEAT: Debug purpose only. -// This function dumps a protobuf string format of HloModule. -static void DumpHloGraphForDebug(const std::vector& args, - std::unique_ptr graph) { - std::unique_ptr flib_def; - std::unique_ptr flr; - std::unique_ptr compiler; - - xla::Client* client = xla::ClientLibrary::LocalClientOrDie(); - XlaOpRegistry::RegisterCompilationKernels(); - - FunctionDefLibrary flib; - flib_def.reset(new FunctionLibraryDefinition(OpRegistry::Global(), flib)); - - // Compiles the graph. - XlaCompiler::Options options; - DeviceType device_type("XLA_CPU_JIT"); - options.device_type = &device_type; - options.client = client; - options.flib_def = flib_def.get(); - compiler.reset(new XlaCompiler(options)); - - // Compile graph - XlaCompiler::CompilationResult result; - TF_CHECK_OK(compiler->CompileGraph(XlaCompiler::CompileOptions(), "dump", - std::move(graph), args, &result)); - - // Convert to hlo - xla::Computation& computation = *result.computation; - - xla::Service* service( - static_cast(xla::ClientLibrary::GetXlaService( - static_cast(client)->platform()))); - const xla::ComputationTracker& computation_tracker = - service->computation_tracker(); - - auto user_computation_status = - computation_tracker.Resolve(computation.handle()); - TF_CHECK_OK(user_computation_status.status()); - auto user_computation = user_computation_status.ConsumeValueOrDie(); - xla::VersionedComputationHandle versioned_handle = - user_computation->GetVersionedHandle(); - std::unique_ptr hlo_module = - std::move(computation_tracker - .BuildHloModule(versioned_handle, xla::HloModuleConfig()) - .ValueOrDie()); - VLOG(1) << "--- DUMP HLO ---"; - VLOG(1) << hlo_module->ToString(); -} - -TEST(XlaTfGraphUtil, ConvertTfGraphToSessionModule) { - // Builds a description of the arguments. - std::vector args = BuildAddGraphArguments(); - std::unique_ptr graph = BuildAddGraph(); - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr session_module, - ConvertTfGraphToXlaSessionModule(args, std::move(graph))); - - ASSERT_EQ(4, session_module->entry().requests_size()); - - VLOG(1) << "--- DUMP ---"; - VLOG(1) << session_module->DebugString(); - DumpHloGraphForDebug(args, BuildAddGraph()); -} - -TEST(XlaTfGraphUtil, ConvertXlaSessionModuleToXlaNodes) { - std::vector args = BuildAddGraphArguments(); - std::unique_ptr graph = BuildAddGraph(); - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr session_module, - ConvertTfGraphToXlaSessionModule(args, std::move(graph))); - TF_ASSERT_OK_AND_ASSIGN(auto xla_nodes, - ConvertXlaSessionModuleToXlaNodes(*session_module)); - EXPECT_EQ(session_module->entry().requests_size(), xla_nodes.size()); -} - -} // namespace xla_tf_graph -} // namespace tensorflow diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index b18b3cb123c4e056a38c5751b76bf04a9490e187..9530af637ef953c293472d926281de77cf626752 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -163,6 +163,7 @@ CORE_PROTO_SRCS = [ "framework/function.proto", "framework/graph.proto", "framework/graph_transfer_info.proto", + "framework/iterator.proto", "framework/kernel_def.proto", "framework/log_memory.proto", "framework/node_def.proto", @@ -248,6 +249,14 @@ tf_proto_library( visibility = ["//visibility:public"], ) +# Minimal lib to detect plafrom +cc_library( + name = "lib_platform", + hdrs = [ + "platform/platform.h", + ], +) + # Minimal lib so that tools used for mobile compilation # don't have to depend on lib/platformlib. cc_library( @@ -445,6 +454,7 @@ tf_cuda_library( "util/mirror_pad_mode.h", "util/padding.h", "util/port.h", + "util/reffed_status_callback.h", "util/saved_tensor_slice_util.h", "util/sparse/group_iterator.h", "util/sparse/sparse_tensor.h", @@ -509,6 +519,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":lib", + ":lib_internal", ":op_gen_overrides_proto_cc", ":protos_all_cc", ], @@ -650,14 +661,15 @@ cc_library( ":image_ops_op_lib", ":io_ops_op_lib", ":linalg_ops_op_lib", - ":lookup_ops_op_lib", ":logging_ops_op_lib", + ":lookup_ops_op_lib", ":math_ops_op_lib", ":nn_ops_op_lib", ":no_op_op_lib", ":parsing_ops_op_lib", ":random_ops_op_lib", ":remote_fused_graph_ops_op_lib", + ":resource_variable_ops_op_lib", ":script_ops_op_lib", ":sdca_ops_op_lib", ":sendrecv_ops_op_lib", @@ -779,6 +791,7 @@ cc_library( "//tensorflow/core/kernels:dataset_ops", "//tensorflow/core/kernels:fake_quant_ops", "//tensorflow/core/kernels:function_ops", + "//tensorflow/core/kernels:histogram_op", "//tensorflow/core/kernels:image", "//tensorflow/core/kernels:io", "//tensorflow/core/kernels:linalg", @@ -887,6 +900,7 @@ cc_library( ":test", "//tensorflow/cc:scope", "//tensorflow/core/kernels:constant_op", + "//tensorflow/core/kernels:ops_testutil", "//tensorflow/core/kernels:ops_util", ], ) @@ -1279,6 +1293,13 @@ tf_pyclif_proto_library( visibility = ["//visibility:public"], ) +tf_pyclif_proto_library( + name = "protobuf/meta_graph_pyclif", + proto_lib = ":protos_all_cc", + proto_srcfile = "protobuf/meta_graph.proto", + visibility = ["//visibility:public"], +) + # ----------------------------------------------------------------------------- # Internal targets @@ -1388,20 +1409,24 @@ LIB_INTERNAL_PUBLIC_HEADERS = tf_additional_lib_hdrs() + [ "platform/platform.h", "platform/protobuf_internal.h", "platform/setround.h", + "platform/snappy.h", "platform/tensor_coding.h", "platform/tracing.h", ] +# Replicated for lib_internal and lib_internal_impl. +LIB_INTERNAL_DEFINES = (tf_additional_lib_defines() + [ + "TF_USE_SNAPPY", + ] + tf_additional_verbs_lib_defines() + + tf_additional_mpi_lib_defines() + + tf_additional_gdr_lib_defines()) + cc_library( name = "lib_internal", srcs = LIB_INTERNAL_PRIVATE_HEADERS, hdrs = LIB_INTERNAL_PUBLIC_HEADERS, copts = tf_copts(), - defines = tf_additional_lib_defines() + [ - "SNAPPY", - ] + tf_additional_verbs_lib_defines() + - tf_additional_mpi_lib_defines() + - tf_additional_gdr_lib_defines(), + defines = LIB_INTERNAL_DEFINES, linkopts = select({ "//tensorflow:freebsd": [], "//tensorflow:windows": [], @@ -1455,6 +1480,7 @@ cc_library( ), hdrs = LIB_INTERNAL_PUBLIC_HEADERS, copts = tf_copts(), + defines = LIB_INTERNAL_DEFINES, deps = tf_additional_lib_deps() + [ ":lib_hash_crc32c_accelerate_internal", ":lib_proto_parsing", @@ -1764,6 +1790,7 @@ tf_cuda_library( ) + if_mkl( [ "//third_party/mkl:intel_binary_blob", + "@mkl_dnn//:mkl_dnn", ], ), alwayslink = 1, @@ -1924,11 +1951,12 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [ "common_runtime/visitable_allocator.h", "graph/gradients.h", "graph/quantize_training.h", -] +] + if_mkl(["graph/mkl_graph_util.h"]) tf_cuda_library( name = "core_cpu_impl", srcs = [ + "common_runtime/accumulate_n_optimizer.cc", "common_runtime/allocator_retry.cc", "common_runtime/bfc_allocator.cc", "common_runtime/build_graph_options.cc", @@ -2025,7 +2053,10 @@ tf_cuda_library( "//third_party/eigen3", "//tensorflow/core/kernels:required", ] + if_mkl( - ["//third_party/mkl:intel_binary_blob"], + [ + "//third_party/mkl:intel_binary_blob", + "@mkl_dnn//:mkl_dnn", + ], ) + tf_additional_core_deps() + if_static([":core_cpu_impl"]), alwayslink = 1, ) @@ -2109,6 +2140,7 @@ GPU_RUNTIME_HEADERS = [ "common_runtime/gpu/gpu_debug_allocator.h", "common_runtime/gpu/gpu_device.h", "common_runtime/gpu/gpu_init.h", + "common_runtime/gpu/gpu_managed_allocator.h", "common_runtime/gpu/gpu_stream_util.h", "common_runtime/gpu/gpu_util.h", "common_runtime/gpu/pool_allocator.h", @@ -2123,6 +2155,7 @@ tf_cuda_library( "common_runtime/gpu/gpu_debug_allocator.cc", "common_runtime/gpu/gpu_device.cc", "common_runtime/gpu/gpu_device_factory.cc", + "common_runtime/gpu/gpu_managed_allocator.cc", "common_runtime/gpu/gpu_stream_util.cc", "common_runtime/gpu/gpu_util.cc", "common_runtime/gpu/gpu_util_platform_specific.cc", @@ -2159,6 +2192,7 @@ tf_cuda_library( ":lib", ":lib_internal", ":protos_all_cc", + ":stream_executor", "//third_party/eigen3", ] + if_static([":gpu_runtime_impl"]), ) @@ -2241,7 +2275,6 @@ cc_library( "lib/io/block_builder.h", "lib/io/format.h", "lib/random/philox_random_test_utils.h", - "platform/snappy.h", ], deps = [ ":lib", @@ -2488,6 +2521,7 @@ tf_cc_test( srcs = ["framework/op_gen_lib_test.cc"], deps = [ ":op_gen_lib", + ":protos_all_cc", ":test", ":test_main", ], @@ -2575,6 +2609,7 @@ tf_cc_tests( "util/example_proto_helper_test.cc", "util/memmapped_file_system_test.cc", "util/presized_cuckoo_map_test.cc", + "util/reffed_status_callback_test.cc", "util/reporter_test.cc", "util/saved_tensor_slice_util_test.cc", "util/semver_test.cc", @@ -2611,6 +2646,7 @@ tf_cc_tests( "//tensorflow/cc:cc_ops_internal", "//tensorflow/cc:scope", "//tensorflow/cc:sendrecv_ops", + "//tensorflow/cc:while_loop", "//tensorflow/core/kernels:ops_util", "//third_party/eigen3", ], @@ -2652,14 +2688,31 @@ tf_cc_tests( ], ) +tf_cc_test_mkl( + name = "mkl_runtime_tests", + size = "small", + srcs = ["common_runtime/mkl_cpu_allocator_test.cc"], + linkstatic = 1, + deps = [ + ":core", + ":core_cpu", + ":framework", + ":framework_internal", + ":test", + ":test_main", + ":testlib", + ], +) + tf_cc_test_mkl( name = "mkl_related_tests", size = "small", srcs = [ "graph/mkl_layout_pass_test.cc", "graph/mkl_tfconversion_pass_test.cc", + "util/mkl_util_test.cc", ], - linkstatic = tf_kernel_tests_linkstatic(), + linkstatic = 1, deps = [ ":core", ":core_cpu", @@ -2677,6 +2730,9 @@ tf_cc_test_mkl( "//tensorflow/cc:cc_ops", "//tensorflow/cc:scope", "//tensorflow/cc:sendrecv_ops", + "//tensorflow/core/kernels:ops_util", + "//third_party/eigen3", + ] + if_mkl([ "//tensorflow/core/kernels:mkl_aggregate_ops", "//tensorflow/core/kernels:mkl_concat_op", "//tensorflow/core/kernels:mkl_conv_op", @@ -2689,9 +2745,7 @@ tf_cc_test_mkl( "//tensorflow/core/kernels:mkl_relu_op", "//tensorflow/core/kernels:mkl_reshape_op", "//tensorflow/core/kernels:mkl_tfconv_op", - "//tensorflow/core/kernels:ops_util", - "//third_party/eigen3", - ], + ]), ) tf_cc_tests_gpu( @@ -2859,9 +2913,11 @@ tf_cc_test( ":test_main", ":testlib", "//tensorflow/cc:cc_ops", + "//tensorflow/cc:resource_variable_ops", "//tensorflow/cc:scope", "//tensorflow/core/kernels:array", "//tensorflow/core/kernels:math", + "//tensorflow/core/kernels:resource_variable_ops", "//third_party/eigen3", ], ) @@ -3311,6 +3367,41 @@ tf_cc_test( ], ) +filegroup( + name = "base_api_def", + data = glob(["api_def/base_api/*"]), +) + +filegroup( + name = "python_api_def", + data = glob(["api_def/python_api/*"]), +) + +tf_cc_test( + name = "api_test", + srcs = ["api_def/api_test.cc"], + data = [ + ":base_api_def", + "//tensorflow/cc:ops/op_gen_overrides.pbtxt", + ], + tags = [ + "manual", + "notap", + ], + deps = [ + ":framework", + ":framework_internal", + ":lib", + ":lib_internal", + ":lib_test_internal", + ":op_gen_lib", + ":op_gen_overrides_proto_cc", + ":ops", + ":protos_all_cc", + ":test", + ], +) + tf_cc_test_gpu( name = "gpu_tracer_test", size = "small", diff --git a/tensorflow/core/api_def/api_test.cc b/tensorflow/core/api_def/api_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..d95d958d5afaad58bdec82183be3d3a09cf4605d --- /dev/null +++ b/tensorflow/core/api_def/api_test.cc @@ -0,0 +1,340 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Test that verifies tensorflow/core/api_def/base_api/api_def*.pbtxt files +// are correct. If api_def*.pbtxt do not match expected contents, run +// tensorflow/core/api_def/base_api/update_api_def.sh script to update them. + +#include +#include +#include +#include +#include + +#include "tensorflow/core/framework/api_def.pb.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/op_gen_lib.h" +#include "tensorflow/core/framework/op_gen_overrides.pb.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/lib/strings/stringprintf.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/init_main.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/util/command_line_flags.h" + +namespace tensorflow { +namespace { +constexpr char kDefaultApiDefDir[] = + "tensorflow/core/api_def/base_api"; +constexpr char kOverridesFilePath[] = + "tensorflow/cc/ops/op_gen_overrides.pbtxt"; +constexpr char kApiDefFileFormat[] = "api_def_%s.pbtxt"; +constexpr char kApiDefFilePattern[] = "api_def_*.pbtxt"; + +void FillBaseApiDef(ApiDef* api_def, const OpDef& op) { + api_def->set_graph_op_name(op.name()); + // Add arg docs + for (auto& input_arg : op.input_arg()) { + if (!input_arg.description().empty()) { + auto* api_def_in_arg = api_def->add_in_arg(); + api_def_in_arg->set_name(input_arg.name()); + api_def_in_arg->set_description(input_arg.description()); + } + } + for (auto& output_arg : op.output_arg()) { + if (!output_arg.description().empty()) { + auto* api_def_out_arg = api_def->add_out_arg(); + api_def_out_arg->set_name(output_arg.name()); + api_def_out_arg->set_description(output_arg.description()); + } + } + // Add attr docs + for (auto& attr : op.attr()) { + if (!attr.description().empty()) { + auto* api_def_attr = api_def->add_attr(); + api_def_attr->set_name(attr.name()); + api_def_attr->set_description(attr.description()); + } + } + // Add docs + api_def->set_summary(op.summary()); + api_def->set_description(op.description()); +} + +// Checks if arg1 should be before arg2 according to ordering in args. +bool CheckArgBefore(const ApiDef::Arg* arg1, const ApiDef::Arg* arg2, + const protobuf::RepeatedPtrField& args) { + for (auto& arg : args) { + if (arg.name() == arg2->name()) { + return false; + } else if (arg.name() == arg1->name()) { + return true; + } + } + return false; +} + +// Checks if attr1 should be before attr2 according to ordering in op_def. +bool CheckAttrBefore(const ApiDef::Attr* attr1, const ApiDef::Attr* attr2, + const OpDef& op_def) { + for (auto& attr : op_def.attr()) { + if (attr.name() == attr2->name()) { + return false; + } else if (attr.name() == attr1->name()) { + return true; + } + } + return false; +} + +// Applies renames to args. +void ApplyArgOverrides( + protobuf::RepeatedPtrField* args, + const protobuf::RepeatedPtrField& renames, + const protobuf::RepeatedPtrField& op_args, + const string& op_name) { + for (auto& rename : renames) { + // First check if rename is valid. + bool valid = false; + for (const auto& op_arg : op_args) { + if (op_arg.name() == rename.from()) { + valid = true; + } + } + QCHECK(valid) << rename.from() << " is not a valid argument for " + << op_name; + bool found_arg = false; + // If Arg is already in ApiDef, just update it. + for (int i = 0; i < args->size(); ++i) { + auto* arg = args->Mutable(i); + if (arg->name() == rename.from()) { + arg->set_rename_to(rename.to()); + found_arg = true; + break; + } + } + if (!found_arg) { // not in ApiDef, add a new arg. + auto* new_arg = args->Add(); + new_arg->set_name(rename.from()); + new_arg->set_rename_to(rename.to()); + } + } + // We don't really need a specific order here right now. + // However, it is clearer if order follows OpDef. + std::sort(args->pointer_begin(), args->pointer_end(), + [&](ApiDef::Arg* arg1, ApiDef::Arg* arg2) { + return CheckArgBefore(arg1, arg2, op_args); + }); +} + +// Returns existing attribute with the given name if such +// attribute exists. Otherwise, adds a new attribute and returns it. +ApiDef::Attr* FindOrAddAttr(ApiDef* api_def, const string attr_name) { + // If Attr is already in ApiDef, just update it. + for (int i = 0; i < api_def->attr_size(); ++i) { + auto* attr = api_def->mutable_attr(i); + if (attr->name() == attr_name) { + return attr; + } + } + // Add a new Attr. + auto* new_attr = api_def->add_attr(); + new_attr->set_name(attr_name); + return new_attr; +} + +// Applies renames and default values to attributes. +void ApplyAttrOverrides(ApiDef* api_def, const OpGenOverride& op_override, + const OpDef& op_def) { + for (auto& attr_rename : op_override.attr_rename()) { + auto* attr = FindOrAddAttr(api_def, attr_rename.from()); + attr->set_rename_to(attr_rename.to()); + } + + for (auto& attr_default : op_override.attr_default()) { + auto* attr = FindOrAddAttr(api_def, attr_default.name()); + *(attr->mutable_default_value()) = attr_default.value(); + } + // We don't really need a specific order here right now. + // However, it is clearer if order follows OpDef. + std::sort(api_def->mutable_attr()->pointer_begin(), + api_def->mutable_attr()->pointer_end(), + [&](ApiDef::Attr* attr1, ApiDef::Attr* attr2) { + return CheckAttrBefore(attr1, attr2, op_def); + }); +} + +void ApplyOverridesToApiDef(ApiDef* api_def, const OpDef& op, + const OpGenOverride& op_override) { + // Fill ApiDef with data based on op and op_override. + // Set visibility + if (op_override.skip()) { + api_def->set_visibility(ApiDef_Visibility_SKIP); + } else if (op_override.hide()) { + api_def->set_visibility(ApiDef_Visibility_HIDDEN); + } + // Add endpoints + if (!op_override.rename_to().empty()) { + api_def->add_endpoint()->set_name(op_override.rename_to()); + } else if (!op_override.alias().empty()) { + api_def->add_endpoint()->set_name(op.name()); + } + + for (auto& alias : op_override.alias()) { + auto* endpoint = api_def->add_endpoint(); + endpoint->set_name(alias); + } + + ApplyArgOverrides(api_def->mutable_in_arg(), op_override.input_rename(), + op.input_arg(), api_def->graph_op_name()); + ApplyArgOverrides(api_def->mutable_out_arg(), op_override.output_rename(), + op.output_arg(), api_def->graph_op_name()); + ApplyAttrOverrides(api_def, op_override, op); +} + +// Get map from ApiDef file path to corresponding ApiDefs proto. +std::unordered_map GenerateApiDef( + const string& api_def_dir, const OpList& ops, + const OpGenOverrides& overrides) { + std::unordered_map name_to_override; + for (const auto& op_override : overrides.op()) { + name_to_override[op_override.name()] = op_override; + } + + std::unordered_map api_defs_map; + + for (const auto& op : ops.op()) { + CHECK(!op.name().empty()) + << "Encountered empty op name: %s" << op.DebugString(); + string file_path = io::JoinPath(api_def_dir, kApiDefFileFormat); + file_path = strings::Printf(file_path.c_str(), op.name().c_str()); + ApiDef* api_def = api_defs_map[file_path].add_op(); + FillBaseApiDef(api_def, op); + + if (name_to_override.find(op.name()) != name_to_override.end()) { + ApplyOverridesToApiDef(api_def, op, name_to_override[op.name()]); + } + } + return api_defs_map; +} + +// Reads golden ApiDef files and returns a map from file name to ApiDef file +// contents. +std::unordered_map GetGoldenApiDefs( + Env* env, const string& api_files_dir) { + std::vector matching_paths; + TF_CHECK_OK(env->GetMatchingPaths( + io::JoinPath(api_files_dir, kApiDefFilePattern), &matching_paths)); + + std::unordered_map file_path_to_api_def; + for (auto& file_path : matching_paths) { + string file_contents; + TF_CHECK_OK(ReadFileToString(env, file_path, &file_contents)); + file_path_to_api_def[file_path] = file_contents; + } + return file_path_to_api_def; +} + +void RunApiTest(bool update_api_def, const string& api_files_dir) { + // Read C++ overrides file + OpGenOverrides overrides; + Env* env = Env::Default(); + TF_EXPECT_OK(ReadTextProto(env, kOverridesFilePath, &overrides)); + + // Read all ops + OpList ops; + OpRegistry::Global()->Export(false, &ops); + const std::vector multi_line_fields = {"description"}; + + // Get expected ApiDefs + const auto new_api_defs_map = GenerateApiDef(api_files_dir, ops, overrides); + + bool updated_at_least_one_file = false; + const auto golden_api_defs_map = GetGoldenApiDefs(env, api_files_dir); + + for (auto new_api_entry : new_api_defs_map) { + const auto& file_path = new_api_entry.first; + const auto& golden_api_defs_str = golden_api_defs_map.at(file_path); + string new_api_defs_str = new_api_entry.second.DebugString(); + new_api_defs_str = PBTxtToMultiline(new_api_defs_str, multi_line_fields); + if (golden_api_defs_str == new_api_defs_str) { + continue; + } + if (update_api_def) { + std::cout << "Updating " << file_path << "..." << std::endl; + TF_EXPECT_OK(WriteStringToFile(env, file_path, new_api_defs_str)); + updated_at_least_one_file = true; + } else { + EXPECT_EQ(golden_api_defs_str, new_api_defs_str) + << "To update golden API files, run " + << "tensorflow/core/api_def/update_api_def.sh."; + } + } + + for (const auto& golden_api_entry : golden_api_defs_map) { + const auto& file_path = golden_api_entry.first; + if (new_api_defs_map.find(file_path) == new_api_defs_map.end()) { + if (update_api_def) { + std::cout << "Deleting " << file_path << "..." << std::endl; + TF_EXPECT_OK(env->DeleteFile(file_path)); + updated_at_least_one_file = true; + } else { + EXPECT_EQ("", golden_api_entry.second) + << "To update golden API files, run " + << "tensorflow/core/api_def/update_api_def.sh."; + } + } + } + + if (update_api_def && !updated_at_least_one_file) { + std::cout << "Api def files are already up to date." << std::endl; + } +} + +TEST(ApiTest, GenerateBaseAPIDef) { RunApiTest(false, kDefaultApiDefDir); } +} // namespace +} // namespace tensorflow + +int main(int argc, char** argv) { + bool update_api_def = false; + tensorflow::string api_files_dir = tensorflow::kDefaultApiDefDir; + std::vector flag_list = { + tensorflow::Flag( + "update_api_def", &update_api_def, + "Whether to update tensorflow/core/api_def/base_api/api_def*.pbtxt " + "files if they differ from expected API."), + tensorflow::Flag("api_def_dir", &api_files_dir, + "Base directory of api_def*.pbtxt files.")}; + std::string usage = tensorflow::Flags::Usage(argv[0], flag_list); + bool parsed_values_ok = tensorflow::Flags::Parse(&argc, argv, flag_list); + if (!parsed_values_ok) { + std::cerr << usage << std::endl; + return 2; + } + if (update_api_def) { + tensorflow::port::InitMain(argv[0], &argc, &argv); + tensorflow::RunApiTest(update_api_def, api_files_dir); + return 0; + } + testing::InitGoogleTest(&argc, argv); + // Run tests + return RUN_ALL_TESTS(); +} diff --git a/tensorflow/core/api_def/base_api/api_def_Abort.pbtxt b/tensorflow/core/api_def/base_api/api_def_Abort.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..6dd923c512af8d38ec04ec1116cc5da1e97d7e92 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_Abort.pbtxt @@ -0,0 +1,16 @@ +op { + graph_op_name: "Abort" + attr { + name: "error_msg" + description: <= 2." +} diff --git a/tensorflow/core/api_def/base_api/api_def_AdjustContrastv2.pbtxt b/tensorflow/core/api_def/base_api/api_def_AdjustContrastv2.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..429a5e4434e011d1ba43847b9abf8877b4d41e7a --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_AdjustContrastv2.pbtxt @@ -0,0 +1,36 @@ +op { + graph_op_name: "AdjustContrastv2" + endpoint { + name: "AdjustContrast" + } + in_arg { + name: "images" + description: < [2.0132, 1.056] +``` + +@compatibility(numpy) +Equivalent to np.angle. +@end_compatibility +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_Any.pbtxt b/tensorflow/core/api_def/base_api/api_def_Any.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..09fd4e0b6036447dfe355ff56da29e276de62f2b --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_Any.pbtxt @@ -0,0 +1,42 @@ +op { + graph_op_name: "Any" + endpoint { + name: "Any" + } + endpoint { + name: "ReduceAny" + } + in_arg { + name: "input" + description: < l1 else 0.0 +accum = accum_new +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..974f3adc196129f9fe83d098c22dc3cd237263d6 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ApplyFtrlV2.pbtxt @@ -0,0 +1,75 @@ +op { + graph_op_name: "ApplyFtrlV2" + in_arg { + name: "var" + description: < l1 else 0.0 +accum = accum_new +END +} diff --git a/tensorflow/core/api_def/base_api/api_def_ApplyGradientDescent.pbtxt b/tensorflow/core/api_def/base_api/api_def_ApplyGradientDescent.pbtxt new file mode 100644 index 0000000000000000000000000000000000000000..2f38ebd1b8c89a1a65368d3da38cead73225ada5 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ApplyGradientDescent.pbtxt @@ -0,0 +1,35 @@ +op { + graph_op_name: "ApplyGradientDescent" + in_arg { + name: "var" + description: < -1. +END + } + attr { + name: "scientific" + description: < -1. +END + } + attr { + name: "fill" + description: < -1. If empty, pads with spaces. +Another typical value is '0'. String cannot be longer than 1 character. +END + } + summary: "Converts each entry in the given tensor to strings. Supports many numeric" + description: <